def testWithSosReportAlreadyInstalled(self):
        """Sosreport already on the VM: it is run and collected, not cloned."""
        # SETUP
        ssh_user = ssh.GetDefaultSshUsername()
        sos_dir = SOSREPORT_INSTALL_PATH
        out_dir = REPORTS_PATH

        # EXECUTE
        command_line = """
             compute diagnose sosreport {name} --zone={zone}
             """.format(name=INSTANCE.name, zone=INSTANCE.zone)
        self.Run(command_line)

        # VERIFY
        # Exactly these SSH commands are expected; in particular there is no
        # mkdir/git clone of the install path.
        sos_binary = os.path.join(sos_dir, "sosreport")
        expected_commands = [
            ["ls", sos_binary],
            ["mkdir", "-p", out_dir],
            ["sudo", sos_binary, "--batch",
             "--compression-type", "gzip",
             "--config-file", os.path.join(sos_dir, "sos.conf"),
             "--tmp-dir", out_dir],
            ["sudo", "chown", ssh_user, os.path.join(out_dir, "*")],
            ["ls", "-t", os.path.join(out_dir, "*.tar.gz"),
             "|", "head", "-n", "1"],
        ]
        self.AssertSSHCallCount(len(expected_commands))
        for command in expected_commands:
            self.AssertSSHCall(command)
    def testWithCustomPath(self):
        """Custom install and report paths show up in every SSH command."""
        # SETUP
        ssh_user = ssh.GetDefaultSshUsername()
        sos_dir = "/custom/install/path"
        out_dir = "/custom/report/path"

        # EXECUTE
        command_line = """
             compute diagnose sosreport {name}
             --zone={zone}
             --sosreport-install-path="{install_path}"
             --reports-path="{reports_path}"
             """.format(name=INSTANCE.name,
                        zone=INSTANCE.zone,
                        install_path=sos_dir,
                        reports_path=out_dir)
        self.Run(command_line)

        # VERIFY
        sos_binary = os.path.join(sos_dir, "sosreport")
        expected_commands = [
            ["ls", sos_binary],
            ["mkdir", "-p", out_dir],
            ["sudo", sos_binary, "--batch",
             "--compression-type", "gzip",
             "--config-file", os.path.join(sos_dir, "sos.conf"),
             "--tmp-dir", out_dir],
            ["sudo", "chown", ssh_user, os.path.join(out_dir, "*")],
            ["ls", "-t", os.path.join(out_dir, "*.tar.gz"),
             "|", "head", "-n", "1"],
        ]
        self.AssertSSHCallCount(len(expected_commands))
        for command in expected_commands:
            self.AssertSSHCall(command)
def GetUserAndInstance(user_host):
    """Returns a (user, instance) pair parsed from a [USER@]INSTANCE string.

    When no user is given, the default SSH username is used.

    Args:
      user_host: str, an argument of the form INSTANCE or USER@INSTANCE.

    Returns:
      (str, str), the user name and the instance name.

    Raises:
      exceptions.ToolException: if user_host contains more than one '@'.
    """
    parts = user_host.split('@')
    if len(parts) == 1:
        user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
        return user, parts[0]
    if len(parts) == 2:
        # Unpack so both branches return the same type (a tuple).
        user, instance = parts
        return user, instance
    raise exceptions.ToolException(
        'Expected argument of the form [USER@]INSTANCE; received [{0}].'.
        format(user_host))
# Example #4
def GetUserAndInstance(user_host, use_account_service, http):
    """Returns a (user, instance) pair parsed from a [USER@]INSTANCE string.

    Args:
      user_host: str, an argument of the form INSTANCE or USER@INSTANCE.
      use_account_service: bool, when no user is given, take the default
        account name from gaia instead of the default SSH username (the
        latter is used when keys are uploaded through metadata).
      http: HTTP client handed to gaia.GetDefaultAccountName; only used when
        use_account_service is true and no user is given.

    Returns:
      (str, str), the user name and the instance name.

    Raises:
      exceptions.ToolException: if user_host contains more than one '@'.
    """
    parts = user_host.split('@')
    if len(parts) == 1:
        if use_account_service:  # Using Account Service.
            user = gaia.GetDefaultAccountName(http)
        else:  # Uploading keys through metadata.
            user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
        return user, parts[0]
    if len(parts) == 2:
        # Unpack so both branches return the same type (a tuple).
        user, instance = parts
        return user, instance
    raise exceptions.ToolException(
        'Expected argument of the form [USER@]INSTANCE; received [{0}].'.
        format(user_host))
# Example #5
    def Run(self, args):
        """Runs sosreport on the target VM and optionally downloads the result.

        Args:
          args: argparse.Namespace, the parsed command-line arguments.
        """
        super(SosReport, self).Run(args)
        # NOTE(review): spelled `_use_accounts_service`; other SSH commands in
        # this file read `_use_account_service` -- confirm which one the base
        # class expects.
        self._use_accounts_service = False

        # Gather everything needed to reach the instance over SSH.
        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        target_instance = self.GetInstance(holder, args)
        ssh_user = args.user or ssh.GetDefaultSshUsername()
        helper = ssh_utils.BaseSSHCLIHelper()
        helper.Run(args)

        # Shared state handed to every soshelper call.
        context = dict(
            args=args,
            instance=target_instance,
            ssh_helper=helper,
            user=ssh_user,
            python_path=args.python_path,
        )
        sos_dir = args.sosreport_install_path
        out_dir = args.reports_path

        # Download sosreport onto the VM if needed (normally the first run).
        soshelper.ObtainSosreport(context, sos_dir)

        # Make sure the report output directory exists.
        log.out.Print(
            "Creating the path where reports will be written if needed.")
        soshelper.CreatePath(context, out_dir)

        soshelper.RunSosreport(context, sos_dir, out_dir)

        # Locate and announce the generated report on the VM.
        remote_report = soshelper.ObtainReportFilename(context, out_dir)
        log.status.Print('Report generated into "{report_path}".'.format(
            report_path=remote_report))

        # Copying the report locally is opt-in via --download-dir.
        if not args.download_dir:
            return
        local_report = soshelper.CopyReportFile(context, args.download_dir,
                                                remote_report)
        log.status.Print(
            'Successfully downloaded report to "{report_path}"'.format(
                report_path=local_report))
# Example #6
    def SSHToInstance(self, args, instance):
        """Authenticates, waits for key propagation, then SSHes to the VM.

        Args:
          args: argparse.Namespace, the command arguments; SSH defaults are
            filled in via self._DefaultArgsForSSH.
          instance: the compute instance to connect to.

        Exits the process with the remote command's status when it is
        non-zero; failures of SSH itself raise ssh.CommandError.
        """
        args = self._DefaultArgsForSSH(args)

        nat_ip = ssh_utils.GetExternalIPAddress(instance)
        log.status.Print('Trying to SSH to VM with NAT IP:{}'.format(nat_ip))
        remote = ssh.Remote(nat_ip, ssh.GetDefaultSshUsername())
        # Set on args before the helper consumes them below.
        args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE

        helper = ssh_utils.BaseSSHCLIHelper()
        helper.Run(args)
        key_file = helper.keys.key_file

        user, _ = ssh_utils.GetUserAndInstance(args.name)
        host_keys = self._GetHostKeyFromInstance(args.zone, helper, instance)
        options = self._GetSSHOptions(args.name, helper, instance, host_keys)
        self._WaitForSSHKeysToPropagate(helper, remote, key_file, user,
                                        instance, options)

        # Only the TensorBoard/Jupyter ports are forwarded, and only on
        # request: ctpu seems to forward further ports on the TPU node, which
        # is not yet understood well enough to enable.
        if args.forward_ports:
            forwarded = ['-A', '-L', '6006:localhost:6006',
                         '-L', '8888:localhost:8888']
        else:
            forwarded = []

        command = ssh.SSHCommand(remote=remote,
                                 identity_file=key_file,
                                 options=options,
                                 extra_flags=forwarded)
        # Errors from SSH itself surface as ssh.CommandError from Run().
        exit_code = command.Run(
            helper.env,
            force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
        if exit_code:
            # A non-zero value here is the remote command's own exit status.
            sys.exit(exit_code)
    def testCustomPythonPath(self):
        """--python-path makes sosreport run through the given interpreter."""
        # SETUP
        ssh_user = ssh.GetDefaultSshUsername()
        sos_dir = SOSREPORT_INSTALL_PATH
        out_dir = REPORTS_PATH
        interpreter = PYTHON_PATH

        # The first SSH command (the `ls` probe) fails so the install path is
        # created and cloned; every later command succeeds.
        self.mock_ssh_command.side_effect = [1] + [0] * 6

        # EXECUTE
        command_line = """
             compute diagnose sosreport {name}
             --zone={zone}
             --python-path={python_path}
             """.format(name=INSTANCE.name,
                        zone=INSTANCE.zone,
                        python_path=interpreter)
        self.Run(command_line)

        # VERIFY
        sos_binary = os.path.join(sos_dir, "sosreport")
        expected_commands = [
            ["ls", sos_binary],
            ["mkdir", "-p", sos_dir],
            ["git", "clone", "https://github.com/sosreport/sos.git", sos_dir],
            ["mkdir", "-p", out_dir],
            ["sudo", interpreter, sos_binary, "--batch",
             "--compression-type", "gzip",
             "--config-file", os.path.join(sos_dir, "sos.conf"),
             "--tmp-dir", out_dir],
            ["sudo", "chown", ssh_user, os.path.join(out_dir, "*")],
            ["ls", "-t", os.path.join(out_dir, "*.tar.gz"),
             "|", "head", "-n", "1"],
        ]
        self.AssertSSHCallCount(len(expected_commands))
        for command in expected_commands:
            self.AssertSSHCall(command)
# Example #8
    def testSSHToInstance(self):
        """SSHToInstance polls for key propagation and then runs SSH."""
        ssh_util = tpu_utils.SSH(self.track)
        args = SSHTest.Args('fake-name', 'central2-a', True)
        instance = self._makeFakeInstance('fake-instance')

        # Queue one fake project response per expected make_requests call.
        project_responses = [[self._makeFakeProjectResource()]
                             for _ in range(3)]
        self.make_requests.side_effect = iter(project_responses)

        ssh_util.SSHToInstance(args, instance)

        self.ensure_keys.assert_called_once_with(
            self.keys, None, allow_passphrase=True)

        self.poller_poll.assert_called_once_with(
            mock_matchers.TypeMatcher(ssh.SSHPoller),
            self.env,
            force_connect=True)

        # SSH Command
        forwarded_ports = ['-A', '-L', '6006:localhost:6006',
                           '-L', '8888:localhost:8888']
        expected_options = dict(self.options,
                                HostKeyAlias='compute.1111',
                                SendEnv='TPU_NAME')
        self.ssh_init.assert_called_once_with(
            mock_matchers.TypeMatcher(ssh.SSHCommand),
            remote=ssh.Remote('23.251.133.75', ssh.GetDefaultSshUsername()),
            identity_file=self.private_key_file,
            extra_flags=forwarded_ports,
            options=expected_options)

        self.ssh_run.assert_called_once_with(
            mock_matchers.TypeMatcher(ssh.SSHCommand),
            self.env,
            force_connect=True)
    def testDryRun(self):
        """--dry-run records each SSH command flagged as a dry run."""
        # SETUP
        ssh_user = ssh.GetDefaultSshUsername()
        sos_dir = SOSREPORT_INSTALL_PATH
        out_dir = REPORTS_PATH

        # The first SSH command (the `ls` probe) fails; the rest succeed.
        self.mock_ssh_command.side_effect = [1] + [0] * 6

        # EXECUTE
        command_line = """
             compute diagnose sosreport {name}
             --zone={zone}
             --dry-run
             """.format(name=INSTANCE.name, zone=INSTANCE.zone)
        self.Run(command_line)

        # VERIFY
        # Six commands are recorded, each checked with is_dry_run=True; the
        # final `ls -t` report lookup seen in non-dry-run tests is absent.
        sos_binary = os.path.join(sos_dir, "sosreport")
        expected_commands = [
            ["ls", sos_binary],
            ["mkdir", "-p", sos_dir],
            ["git", "clone", "https://github.com/sosreport/sos.git", sos_dir],
            ["mkdir", "-p", out_dir],
            ["sudo", sos_binary, "--batch",
             "--compression-type", "gzip",
             "--config-file", os.path.join(sos_dir, "sos.conf"),
             "--tmp-dir", out_dir],
            ["sudo", "chown", ssh_user, os.path.join(out_dir, "*")],
        ]
        self.AssertSSHCallCount(len(expected_commands))
        for command in expected_commands:
            self.AssertSSHCall(command, is_dry_run=True)
# Example #10
    def Run(self, args):
        """See ssh_utils.BaseSSHCommand.Run.

        Adds (or, with --remove, strips) the Compute Engine section of the
        SSH config file so running instances can be reached by alias.

        Args:
          args: argparse.Namespace, the parsed command-line arguments.

        Raises:
          MultipleComputeSectionsError: the existing config file contains more
            than one Compute Engine section.
        """
        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        client = holder.client

        ssh_helper = ssh_utils.BaseSSHHelper()
        ssh_helper.Run(args)
        ssh_helper.keys.EnsureKeysExist(args.force_key_file_overwrite,
                                        allow_passphrase=True)

        ssh_config_file = files.ExpandHomeDir(args.ssh_config_file
                                              or ssh.PER_USER_SSH_CONFIG_FILE)

        instances = None
        try:
            existing_content = files.ReadFileContents(ssh_config_file)
        except files.Error as e:
            # An unreadable or missing config file is treated as empty.
            existing_content = ''
            log.debug('SSH Config File [{0}] could not be opened: {1}'.format(
                ssh_config_file, e))

        if args.remove:
            compute_section = ''
            try:
                new_content = _RemoveComputeSection(existing_content)
            except MultipleComputeSectionsError:
                # Re-raise with the config file path attached.
                raise MultipleComputeSectionsError(ssh_config_file)
        else:
            ssh_helper.EnsureSSHKeyIsInProject(
                client, ssh.GetDefaultSshUsername(warn_on_account_user=True),
                None)
            instances = list(self.GetRunningInstances(client))
            if instances:
                compute_section = _BuildComputeSection(
                    instances, ssh_helper.keys.key_file,
                    ssh.KnownHosts.DEFAULT_PATH)
            else:
                compute_section = ''

        if existing_content and not args.remove:
            try:
                new_content = _MergeComputeSections(existing_content,
                                                    compute_section)
            except MultipleComputeSectionsError:
                raise MultipleComputeSectionsError(ssh_config_file)
        elif not existing_content:
            new_content = compute_section

        if args.dry_run:
            log.out.write(new_content or '')
            return

        if new_content != existing_content:
            if (os.path.exists(ssh_config_file)
                    and platforms.OperatingSystem.Current()
                    is not platforms.OperatingSystem.WINDOWS):
                ssh_config_perms = os.stat(ssh_config_file).st_mode
                # From `man 5 ssh_config`:
                #    this file must have strict permissions: read/write for the user,
                #    and not accessible by others.
                # We check that here (explicit parentheses; the original relied
                # on `&`/`|` binding tighter than `==`):
                owner_is_rw = ((ssh_config_perms & stat.S_IRWXU) ==
                               (stat.S_IWUSR | stat.S_IRUSR))
                group_cannot_write = (ssh_config_perms & stat.S_IWGRP) == 0
                others_cannot_write = (ssh_config_perms & stat.S_IWOTH) == 0
                if not (owner_is_rw and group_cannot_write
                        and others_cannot_write):
                    # The path placeholder was previously never filled in.
                    log.warning(
                        'Invalid permissions on [{0}]. Please change to match ssh '
                        'requirements (see man 5 ssh_config).'.format(
                            ssh_config_file))
            # TODO(b/36050483): This write will not work very well if there is
            # a lot of write contention for the SSH config file. We should
            # add a function to do a better job at "atomic file writes".
            files.WriteFileContents(ssh_config_file, new_content, private=True)

        if compute_section:
            log.out.write(
                textwrap.dedent("""\
          You should now be able to use ssh/scp with your instances.
          For example, try running:

            $ ssh {alias}

          """.format(alias=_CreateAlias(instances[0]))))

        elif not instances and not args.remove:
            log.warning(
                'No host aliases were added to your SSH configs because you do not '
                'have any running instances. Try running this command again after '
                'running some instances.')
# Example #11
  def Run(self, args, port=None, recursive=False, extra_flags=None):
    """SCP files between local and remote GCE instance.

    Run this method from subclasses' Run methods.

    With --dry-run the scp command line is printed and nothing is executed.
    A non-zero scp exit status terminates the process via sys.exit.

    Args:
      args: argparse.Namespace, the args the command was invoked with.
      port: str, int or None, Port number to use for SSH connection.
      recursive: bool, Whether to use recursive copying using -R flag.
      extra_flags: [str] or None, extra flags to add to command invocation.

    Raises:
      ssh_utils.NetworkError: Network issue which likely is due to failure
        of SSH key propagation.
      ssh.CommandError: The SSH command exited with SSH exit code, which
        usually implies that a connection problem occurred.
    """
    super(BaseScpCommand, self).Run(args)

    dst = ssh.FileReference.FromPath(args.destination)
    srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

    # Make sure we have a unique remote
    ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

    remote = dst.remote or srcs[0].remote
    if not dst.remote:  # Make sure all remotes point to the same ref
      for src in srcs:
        src.remote = remote

    # Resolve the instance name held in remote.host to a zonal resource.
    instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [remote.host], compute_scope.ScopeEnum.ZONE, args.zone, self.resources,
        scope_lister=flags.GetDefaultScopeLister(self.compute_client))[0]
    instance = self.GetInstance(instance_ref)

    # Now replace the instance name with the actual IP/hostname
    remote.host = ssh_utils.GetExternalIPAddress(instance)
    if not remote.user:
      remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

    # With --plain no managed identity file or SSH options are passed.
    identity_file = None
    options = None
    if not args.plain:
      identity_file = self.keys.key_file
      options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                               args.strict_host_key_checking)

    cmd = ssh.SCPCommand(
        srcs, dst, identity_file=identity_file, options=options,
        recursive=recursive, port=port, extra_flags=extra_flags)

    if args.dry_run:
      log.out.Print(' '.join(cmd.Build(self.env)))
      return

    if args.plain:
      keys_newly_added = False
    else:
      keys_newly_added = self.EnsureSSHKeyExists(
          remote.user, instance, instance_ref.project,
          use_account_service=self._use_account_service)

    # A newly added key may not be usable immediately; poll until it is.
    if keys_newly_added:
      # Judging by its _SEC suffix the timeout constant is in seconds, while
      # the poller's max_wait_ms expects milliseconds, so convert explicitly.
      poller = ssh.SSHPoller(
          remote, identity_file=identity_file, options=options,
          max_wait_ms=ssh_utils.SSH_KEY_PROPAGATION_TIMEOUT_SEC * 1000)
      log.status.Print('Waiting for SSH key to propagate.')
      # TODO(b/35355795): Don't force_connect
      try:
        poller.Poll(self.env, force_connect=True)
      except retry.WaitException:
        raise ssh_utils.NetworkError()
    return_code = cmd.Run(self.env, force_connect=True)
    if return_code:
      # Can't raise an exception because we don't want any "ERROR" message
      # printed; the output from `ssh` will be enough.
      sys.exit(return_code)
# Example #12
    def Run(self, args):
        """Default run method implementation.

        Traceroutes to each selected instance and, with --reverse-traceroute,
        also traceroutes from the instance back to an external IP over SSH.
        Per-instance errors are logged and the loop moves on.

        Args:
          args: argparse.Namespace, the parsed command-line arguments.
        """
        super(Routes, self).Run(args)

        # NOTE(review): spelled `_use_accounts_service`; confirm this matches
        # the attribute the base class actually reads.
        self._use_accounts_service = False

        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        resource_registry = holder.resources
        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)

        # We store always needed commands non-changing fields
        self._args = args
        self._ssh_helper = ssh_helper

        # We obtain generic parameters of the call
        project = properties.VALUES.core.project.GetOrFail()
        filters = _RoutesArgs.GetFilters(args)
        instances = _RoutesQueries.ObtainInstances(
            args.names,
            service=self.compute.instances,
            project=project,
            zones=args.zones,
            filters=filters,
            http=self.http,
            batch_url=self.batch_url)

        # Fall back to the default SSH username when --user is not given.
        user = args.user
        if not user:
            user = ssh.GetDefaultSshUsername()

        # We unpack the flags
        dry_run = args.dry_run
        reverse_traceroute = args.reverse_traceroute
        traceroute_args = args.traceroute_args
        external_route_ip = args.external_route_ip

        internal_helpers.PrintHeader(instances)
        prompt = 'The following VMs will be tracerouted.'
        # Confirmation is skipped for dry runs and when no instances matched.
        if instances and not dry_run and not console_io.PromptContinue(prompt):
            return
        # Sometimes the prompt would appear after the instance data
        log.out.flush()

        for instance in instances:
            header = 'Checking instance %s' % instance.name
            log.out.Print(header)
            log.out.Print('-' * len(header))

            try:
                self.TracerouteInstance(instance, traceroute_args, dry_run,
                                        resource_registry)
            except exceptions.ToolException as e:
                log.error('Error routing to instance')
                log.error(str(e))
                continue

            if reverse_traceroute:
                try:
                    has_traceroute = self.CheckTraceroute(
                        instance, user, dry_run, resource_registry)
                    if has_traceroute:
                        # We obtain the self ip
                        # NOTE(review): this assigns the loop-outer variable,
                        # so the IP obtained from the first instance is reused
                        # for every later instance. Confirm that is intended.
                        if not external_route_ip:
                            external_route_ip = self.ObtainSelfIp(
                                instance, user, dry_run, resource_registry)
                        if external_route_ip:
                            self.ReverseTracerouteInstance(
                                instance, user, external_route_ip,
                                traceroute_args, dry_run, resource_registry)
                        else:
                            log.out.Print(
                                'Unable to obtain self ip. Aborting.')
                    else:
                        log.out.Print(
                            'Please make sure traceroute is installed in PATH to move on.'
                        )
                except ssh.CommandError as e:
                    # SSH failures on one instance are logged, not fatal.
                    log.error(str(e))
            log.out.Print('')  # Separator
# Example #13
    def Run(self, args):
        """Connect to a running flex instance.

        Args:
          args: argparse.Namespace, the args the command was invoked with.

        Raises:
          InvalidInstanceTypeError: The instance is not supported for SSH.
          MissingVersionError: The version specified does not exist.
          MissingInstanceError: The instance specified does not exist.
          UnattendedPromptError: Not running in a tty.
          OperationCancelledError: User cancelled the operation.
          ssh.CommandError: The SSH command exited with SSH exit code, which
            usually implies that a connection problem occurred.

        Returns:
          int, The exit code of the SSH command.
        """
        api_client = appengine_api_client.GetApiClient()
        env = ssh.Environment.Current()
        env.RequireSSH()
        keys = ssh.Keys.FromFilename()
        keys.EnsureKeysExist(overwrite=False)

        # Only Flexible-environment versions can be SSHed into.
        try:
            version = api_client.GetVersionResource(service=args.service,
                                                    version=args.version)
        except api_exceptions.NotFoundError:
            raise command_exceptions.MissingVersionError('{}/{}'.format(
                args.service, args.version))
        version = version_util.Version.FromVersionResource(version, None)
        if version.environment is not util.Environment.FLEX:
            if version.environment is util.Environment.MANAGED_VMS:
                environment = 'Managed VMs'
                msg = 'Use `gcloud compute ssh` for Managed VMs instances.'
            else:
                environment = 'Standard'
                msg = None
            raise command_exceptions.InvalidInstanceTypeError(environment, msg)
        # NOTE(review): 'appsId' is given the GetOrFail method itself (not its
        # result); presumably the registry calls it lazily -- confirm.
        res = resources.REGISTRY.Parse(
            args.instance,
            params={
                'appsId': properties.VALUES.core.project.GetOrFail,
                'versionsId': args.version,
                'instancesId': args.instance,
                'servicesId': args.service,
            },
            collection='appengine.apps.services.versions.instances')
        rel_name = res.RelativeName()
        try:
            instance = api_client.GetInstanceResource(res)
        except api_exceptions.NotFoundError:
            raise command_exceptions.MissingInstanceError(rel_name)

        # Debug mode is not on yet: warn and require explicit confirmation.
        if not instance.vmDebugEnabled:
            log.warn(ENABLE_DEBUG_WARNING)
            console_io.PromptContinue(cancel_on_no=True,
                                      throw_if_unattended=True)
        user = ssh.GetDefaultSshUsername()
        remote = ssh.Remote(instance.vmIp, user=user)
        public_key = keys.GetPublicKey().ToEntry()
        # Key entry has the form "<user>:<public key> <user>".
        ssh_key = '{user}:{key} {user}'.format(user=user, key=public_key)
        log.status.Print(
            'Sending public key to instance [{}].'.format(rel_name))
        api_client.DebugInstance(res, ssh_key)
        options = {
            'IdentitiesOnly':
            'yes',  # No ssh-agent as of yet
            'UserKnownHostsFile':
            ssh.KnownHosts.DEFAULT_PATH,
            'CheckHostIP':
            'no',
            'HostKeyAlias':
            HOST_KEY_ALIAS.format(project=api_client.project,
                                  instance_id=args.instance)
        }
        cmd = ssh.SSHCommand(remote,
                             identity_file=keys.key_file,
                             options=options)
        # With --container, open an interactive shell inside that container.
        if args.container:
            cmd.tty = True
            cmd.remote_command = ['container_exec', args.container, '/bin/sh']
        return cmd.Run(env)
# Example #14
    def RunScp(self,
               compute_holder,
               args,
               port=None,
               recursive=False,
               compress=False,
               extra_flags=None,
               release_track=None,
               ip_type=ip.IpTypeEnum.EXTERNAL):
        """SCP files between local and remote GCE instance.

        Run this method from subclasses' Run methods.

        When tunneling through IAP, the tunnel listener is stopped on every
        exit path (dry run, key-propagation failure, any other error, or
        normal completion).

        Args:
          compute_holder: The ComputeApiHolder.
          args: argparse.Namespace, the args the command was invoked with.
          port: str or None, Port number to use for SSH connection.
          recursive: bool, Whether to use recursive copying using -R flag.
          compress: bool, Whether to use compression.
          extra_flags: [str] or None, extra flags to add to command invocation.
          release_track: obj, The current release track.
          ip_type: IpTypeEnum, Specify using internal ip or external ip address.

        Raises:
          ssh_utils.NetworkError: Network issue which likely is due to failure
            of SSH key propagation.
          ssh.CommandError: The SSH command exited with SSH exit code, which
            usually implies that a connection problem occurred.
        """
        if release_track is None:
            release_track = base.ReleaseTrack.GA
        super(BaseScpHelper, self).Run(args)

        dst = ssh.FileReference.FromPath(args.destination)
        srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

        # Make sure we have a unique remote
        ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

        remote = dst.remote or srcs[0].remote
        if not dst.remote:  # Make sure all remotes point to the same ref
            for src in srcs:
                src.remote = remote

        instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [remote.host],
            compute_scope.ScopeEnum.ZONE,
            args.zone,
            compute_holder.resources,
            scope_lister=instance_flags.GetInstanceZoneScopeLister(
                compute_holder.client))[0]
        instance = self.GetInstance(compute_holder.client, instance_ref)
        project = self.GetProject(compute_holder.client, instance_ref.project)

        # Now replace the instance name with the actual IP/hostname
        if ip_type is ip.IpTypeEnum.INTERNAL:
            remote.host = ssh_utils.GetInternalIPAddress(instance)
        else:
            remote.host = ssh_utils.GetExternalIPAddress(instance)
        if not remote.user:
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
        if args.plain:
            use_oslogin = False
        else:
            public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
            remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
                instance, project, remote.user, public_key, release_track)

        identity_file = None
        options = None
        if not args.plain:
            identity_file = self.keys.key_file
            options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                                     args.strict_host_key_checking)

        tunnel_helper = None
        interface = None
        if hasattr(args, 'tunnel_through_iap') and args.tunnel_through_iap:
            tunnel_helper, interface = ssh_utils.CreateIapTunnelHelper(
                args, instance_ref, instance, ip_type, port=port)
            tunnel_helper.StartListener()

        # From here on a single try/finally owns the tunnel's lifetime, so an
        # exception anywhere below can no longer leak the listener.
        try:
            cmd_port = port
            if tunnel_helper:
                cmd_port = str(tunnel_helper.GetLocalPort())
                # Point the remote side of the copy at the local tunnel end.
                if dst.remote:
                    dst.remote.host = 'localhost'
                else:
                    for src in srcs:
                        src.remote.host = 'localhost'

            cmd = ssh.SCPCommand(srcs,
                                 dst,
                                 identity_file=identity_file,
                                 options=options,
                                 recursive=recursive,
                                 compress=compress,
                                 port=cmd_port,
                                 extra_flags=extra_flags)

            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(self.env)))
                return

            if args.plain or use_oslogin:
                keys_newly_added = False
            else:
                keys_newly_added = self.EnsureSSHKeyExists(
                    compute_holder.client, remote.user, instance, project)

            if keys_newly_added:
                poller_tunnel_helper = None
                try:
                    if tunnel_helper:
                        # The poller gets its own tunnel listener.
                        new_helper, _ = ssh_utils.CreateIapTunnelHelper(
                            args,
                            instance_ref,
                            instance,
                            ip_type,
                            port=port,
                            interface=interface)
                        new_helper.StartListener(
                            accept_multiple_connections=True)
                        poller_tunnel_helper = new_helper
                    poller = ssh_utils.CreateSSHPoller(remote,
                                                       identity_file,
                                                       options,
                                                       poller_tunnel_helper,
                                                       port=port)

                    log.status.Print('Waiting for SSH key to propagate.')
                    # TODO(b/35355795): Don't force_connect
                    try:
                        poller.Poll(self.env, force_connect=True)
                    except retry.WaitException:
                        raise ssh_utils.NetworkError()
                finally:
                    if poller_tunnel_helper:
                        poller_tunnel_helper.StopListener()

            if ip_type is ip.IpTypeEnum.INTERNAL and not tunnel_helper:
                # The IAP Tunnel connection uses instance name and network
                # interface name, so do not need to additionally verify the
                # instance. Also, the SSHCommand used within the function does
                # not support IAP Tunnels.
                self.PreliminarilyVerifyInstance(instance.id, remote,
                                                 identity_file, options)

            # Errors from the SCP command result in an ssh.CommandError being
            # raised.
            cmd.Run(self.env, force_connect=True)
        finally:
            if tunnel_helper:
                tunnel_helper.StopListener()
    def testWithCopyOver(self):
        """--download-dir copies the generated report back with gcloud scp."""
        # SETUP
        ssh_user = ssh.GetDefaultSshUsername()
        sos_dir = SOSREPORT_INSTALL_PATH
        out_dir = REPORTS_PATH
        remote_report = "/path/to/sosreport/file"
        local_dir = "/local/dir"

        # Build a side effect that returns (and runs the expected side effect
        # for) each SSH invocation in order.
        ssh_mock = MockSSHCalls()
        ssh_mock.SetReturnValue(remote_report)
        counter = [0]

        def SideEffectGenerator(*args, **kwargs):
            index = counter[0]
            counter[0] = index + 1
            if index == 0:
                # The initial `ls` probe fails, forcing an install.
                return 1
            if index <= 3:
                return 0
            # Later calls go to the mock, which yields remote_report.
            return ssh_mock(*args, **kwargs)

        self.mock_ssh_command.side_effect = SideEffectGenerator

        # EXECUTE
        command_line = """
             compute diagnose sosreport {name}
             --zone={zone}
             --download-dir="{download_dir}"
             """.format(name=INSTANCE.name,
                        zone=INSTANCE.zone,
                        download_dir=local_dir)
        self.Run(command_line)

        # VERIFY
        sos_binary = os.path.join(sos_dir, "sosreport")
        expected_commands = [
            ["ls", sos_binary],
            ["mkdir", "-p", sos_dir],
            ["git", "clone", "https://github.com/sosreport/sos.git", sos_dir],
            ["mkdir", "-p", out_dir],
            ["sudo", sos_binary, "--batch",
             "--compression-type", "gzip",
             "--config-file", os.path.join(sos_dir, "sos.conf"),
             "--tmp-dir", out_dir],
            ["sudo", "chown", ssh_user, os.path.join(out_dir, "*")],
            ["ls", "-t", os.path.join(out_dir, "*.tar.gz"),
             "|", "head", "-n", "1"],
        ]
        self.AssertSSHCallCount(len(expected_commands))
        for command in expected_commands:
            self.AssertSSHCall(command)
        self.AssertSubProcessCalls([
            "gcloud", "compute", "scp", "--zone", INSTANCE.zone,
            INSTANCE.name + ":" + remote_report,
            os.path.join(local_dir, "file")
        ])
# Example #16
def PopulatePublicKey(api_client, service_id, version_id, instance_id,
                      public_key, release_track):
    """Enable debug mode on a flex instance and send it an SSH public key.

    Shared helper for the SSH-like commands. In order, it:
      - checks that the service/version exists and that the version runs in
        the Flexible environment (Managed VMs and Standard are rejected);
      - checks that the instance exists;
      - if debug mode is not enabled yet, logs a warning and asks the user to
        confirm before continuing;
      - resolves the SSH user (possibly via OS Login) and, when OS Login is
        not in use, sends the public key with `api_client.DebugInstance`.

    Args:
      api_client: An appengine_api_client.AppEngineApiClient.
      service_id: str, The service ID.
      version_id: str, The version ID.
      instance_id: str, The instance ID.
      public_key: ssh.Keys.PublicKey, Public key to send.
      release_track: calliope.base.ReleaseTrack, The current release track.

    Raises:
      InvalidInstanceTypeError: The instance is not supported for SSH.
      MissingVersionError: The version specified does not exist.
      MissingInstanceError: The instance specified does not exist.
      UnattendedPromptError: Not running in a tty.
      OperationCancelledError: User cancelled the operation.

    Returns:
      ConnectionDetails, the remote (instance VM IP plus resolved user) and
      the ssh options to use for the SSH/SCP connection.
    """
    try:
        version_resource = api_client.GetVersionResource(service=service_id,
                                                         version=version_id)
    except apitools_exceptions.HttpNotFoundError:
        raise command_exceptions.MissingVersionError(
            '{}/{}'.format(service_id, version_id))

    version = version_util.Version.FromVersionResource(version_resource, None)
    environment_kind = version.environment
    # Only Flexible-environment instances can be reached this way.
    if environment_kind is not env.FLEX:
        if environment_kind is env.MANAGED_VMS:
            raise command_exceptions.InvalidInstanceTypeError(
                'Managed VMs',
                'Use `gcloud compute ssh` for Managed VMs instances.')
        raise command_exceptions.InvalidInstanceTypeError('Standard', None)

    instance_ref = resources.REGISTRY.Parse(
        instance_id,
        params={
            # Passed as a callable so the project is resolved lazily.
            'appsId': properties.VALUES.core.project.GetOrFail,
            'versionsId': version_id,
            'instancesId': instance_id,
            'servicesId': service_id,
        },
        collection='appengine.apps.services.versions.instances')
    relative_name = instance_ref.RelativeName()
    try:
        instance = api_client.GetInstanceResource(instance_ref)
    except apitools_exceptions.HttpNotFoundError:
        raise command_exceptions.MissingInstanceError(relative_name)

    if not instance.vmDebugEnabled:
        # Debug mode is off: warn and require explicit confirmation (fails
        # when no tty is attached).
        log.warning(_ENABLE_DEBUG_WARNING)
        console_io.PromptContinue(cancel_on_no=True, throw_if_unattended=True)

    default_user = ssh.GetDefaultSshUsername()
    project = _GetComputeProject(release_track)
    user, use_oslogin = ssh.CheckForOsloginAndGetUser(None, project,
                                                      default_user,
                                                      public_key.ToEntry(),
                                                      release_track)
    remote = ssh.Remote(instance.vmIp, user=user)

    if not use_oslogin:
        # Without OS Login the key is pushed directly to the instance.
        ssh_key = '{user}:{key} {user}'.format(user=user,
                                               key=public_key.ToEntry())
        log.status.Print(
            'Sending public key to instance [{}].'.format(relative_name))
        api_client.DebugInstance(instance_ref, ssh_key)

    host_key_alias = _HOST_KEY_ALIAS.format(project=api_client.project,
                                            instance_id=instance_id)
    options = {
        'IdentitiesOnly': 'yes',  # No ssh-agent as of yet
        'UserKnownHostsFile': ssh.KnownHosts.DEFAULT_PATH,
        'CheckHostIP': 'no',
        'HostKeyAlias': host_key_alias,
    }
    return ConnectionDetails(remote, options)
Example #17
0
    def Run(self, args):
        """Copy files between the local machine and Cloud TPU VM worker(s).

        Resolves the TPU node named by the remote side of the SCP arguments
        and checks that it is a TPU VM node in a valid state. Unless --plain
        is set, it resolves the OS Login user and adds the SSH key to the
        node when needed. It then runs one SCP command per selected worker.
        When several workers are targeted, each runs on its own thread.

        Args:
          args: argparse.Namespace, the parsed command-line arguments.

        Raises:
          exceptions.BadArgumentException: the node is not a TPU VM node.
          exceptions.InvalidArgumentException: several workers were targeted
            while copying files from the remote side to the client.
          tpu_exceptions.IapTunnelingUnavailable: IAP tunnelling was requested
            but a worker's hostname is missing from its GuestAttributes.
        """
        dst = ssh.FileReference.FromPath(args.destination)
        srcs = [ssh.FileReference.FromPath(src) for src in args.sources]
        ssh.SCPCommand.Verify(srcs, dst, single_remote=True)
        # The host part of whichever side is remote names the TPU.
        if dst.remote:
            tpu_name = dst.remote.host
        else:
            tpu_name = srcs[0].remote.host

        # If zone is not set, retrieve the one from the config.
        if args.zone is None:
            args.zone = properties.VALUES.compute.zone.Get(required=True)

        # Retrieve the node.
        tpu = tpu_utils.TPUNode(self.ReleaseTrack())
        node = tpu.Get(tpu_name, args.zone)
        if not tpu_utils.IsTPUVMNode(node):
            raise exceptions.BadArgumentException(
                'TPU',
                'this command is only available for Cloud TPU VM nodes. To access '
                'this node, please see '
                'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

        worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                                   node.networkEndpoints,
                                                   args.internal_ip)
        multiple_workers = len(worker_ips) > 1

        if multiple_workers and srcs[0].remote:
            raise exceptions.InvalidArgumentException(
                '--worker',
                'cannot target multiple workers while copying files to '
                'client.')

        tpu_ssh_utils.ValidateTPUState(node.state,
                                       tpu.messages.Node.StateValueValuesEnum)

        # Retrieve GuestAttributes.
        single_pod_worker = (len(node.networkEndpoints) > 1
                             and len(worker_ips) == 1)
        if single_pod_worker:
            # Retrieve only that worker's GuestAttributes.
            worker_id = next(iter(worker_ips))
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone, six.text_type(worker_id))
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes,
                len(node.networkEndpoints), worker_id)
        else:
            # Retrieve the GuestAttributes for all workers in that TPU.
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone)
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes)

        # Generate the public key.
        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)
        public_key = ssh_helper.keys.GetPublicKey().ToEntry()

        remote = dst.remote or srcs[0].remote
        if not dst.remote:  # Make sure all remotes point to the same ref.
            for src in srcs:
                src.remote = remote

        # A user given as USER@HOST counts as an explicit request; otherwise
        # fall back to the default SSH username.
        username_requested = bool(remote.user)
        if not username_requested:
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

        project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

        if not args.plain:
            # If there is an '@' symbol in the user_host arg, the user is requesting
            # to connect as a specific user. This may get overridden by OS Login.
            _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
            oslogin_state = ssh.GetOsloginState(
                None,
                project,
                remote.user,
                public_key,
                expiration_micros,
                self.ReleaseTrack(),
                username_requested=username_requested,
                instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(
                    node))
            remote.user = oslogin_state.user

        # Format the key correctly.
        public_key = '{1}:{0} {1}'.format(public_key, remote.user)
        if not args.plain and not args.dry_run:
            tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name,
                                            args.zone, public_key)

        identity_file = None
        if not args.plain:
            identity_file = ssh_helper.keys.key_file
            # If the user's key is not in the SSH agent, the command will stall. We
            # want to verify it is added before proceeding, and raise an error if it
            # is not.
            if not args.dry_run and multiple_workers:
                tpu_ssh_utils.VerifyKeyInAgent(identity_file)

        extra_flags = []

        if args.scp_flag:
            extra_flags.extend(args.scp_flag)

        # Evaluated once; consulted both when collecting hostnames and when
        # building each worker's command.
        tunnel_through_iap = (args.IsKnownAndSpecified('tunnel_through_iap')
                              and args.tunnel_through_iap)

        instance_names = {}
        if tunnel_through_iap:
            # Retrieve the instance names from the GuestAttributes.
            for worker in worker_ips:
                # The GuestAttributes will only have one entry if we're targeting a
                # single worker.
                index = 0 if single_pod_worker else worker
                instance_name = tpu_ssh_utils.GetFromGuestAttributes(
                    guest_attributes_response.guestAttributes, index,
                    'hostname')
                if instance_name is None:
                    log.status.Print('Failed to connect to TPU.')
                    log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
                    raise tpu_exceptions.IapTunnelingUnavailable()
                instance_names[worker] = instance_name

        ssh_threads = []
        # NOTE(review): sized to cover every endpoint index, not just
        # len(worker_ips). The per-worker `worker` argument handed to
        # AttemptRunWithRetries alongside this list suggests statuses are
        # stored by worker index. In that case a subset like --worker=1,3
        # would otherwise index past the end. Extra None slots are ignored by
        # the truthiness check at the bottom.
        exit_statuses = [None] * max(len(worker_ips),
                                     len(node.networkEndpoints))
        for worker, ips in worker_ips.items():
            options = None
            if not args.plain:
                options = ssh_helper.GetConfig(
                    tpu_ssh_utils.GetInstanceID(node.id, worker,
                                                host_key_suffixes),
                    args.strict_host_key_checking, None)

            iap_tunnel_args = None
            if tunnel_through_iap:
                # Retrieve the instance name from the GuestAttributes.
                instance_name = instance_names[worker]
                iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
                    args, self.ReleaseTrack(), project, args.zone,
                    instance_name)

            # srcs/dst share this Remote object, so retargeting it here
            # retargets the command built below.
            remote.host = ips.ip_address
            cmd = ssh.SCPCommand(srcs,
                                 dst,
                                 identity_file=identity_file,
                                 options=options,
                                 recursive=args.recurse,
                                 compress=args.compress,
                                 extra_flags=extra_flags,
                                 iap_tunnel_args=iap_tunnel_args)

            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
                continue

            if multiple_workers:
                # Run the command on multiple workers concurrently.
                ssh_threads.append(
                    threading.Thread(
                        target=tpu_ssh_utils.AttemptRunWithRetries,
                        args=('SCP', worker, exit_statuses, cmd,
                              ssh_helper.env, None, True, SCPRunCmd)))
                ssh_threads[-1].start()
            else:
                # Run on a single worker.
                tpu_ssh_utils.AttemptRunWithRetries('SCP', worker,
                                                    exit_statuses, cmd,
                                                    ssh_helper.env, None,
                                                    False, SCPRunCmd)

        if multiple_workers:
            # Wait for all the threads to complete.
            for thread in ssh_threads:
                thread.join()

            # Exit with a nonzero status, if any.
            # This ensures that if any command failed on a worker, we don't end up
            # returning 0 for a value.
            for status in exit_statuses:
                if status:
                    sys.exit(status)
Example #18
0
    def Run(self, args):
        """Connect to an instance's serial port through the serial port gateway.

        See ssh_utils.BaseSSHCommand.Run.

        When the default gateway is used, refreshes its host key in the
        known-hosts file. Builds the dot-separated gateway username
        (PROJECT.ZONE.INSTANCE.USER.port=N[.key=value...]) and makes sure the
        user's SSH key exists for the instance. It then runs ssh against the
        gateway.

        Args:
          args: argparse.Namespace, the parsed command-line arguments.

        Raises:
          ssh_utils.ArgumentError: args.user_host is not of the form
            [USER@]INSTANCE.
        """
        holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
        cua_holder = base_classes.ComputeUserAccountsApiHolder(
            self.ReleaseTrack())
        client = holder.client

        ssh_helper = ssh_utils.BaseSSHHelper()
        ssh_helper.Run(args)
        ssh_helper.keys.EnsureKeysExist(args.force_key_file_overwrite,
                                        allow_passphrase=True)

        remote = ssh.Remote.FromArg(args.user_host)
        if not remote:
            raise ssh_utils.ArgumentError(
                'Expected argument of the form [USER@]INSTANCE. Received [{0}].'
                .format(args.user_host))
        if not remote.user:
            remote.user = ssh.GetDefaultSshUsername()

        hostname = '[{0}]:{1}'.format(args.serial_port_gateway,
                                      CONNECTION_PORT)
        # Update google_compute_known_hosts file with published host key
        if args.serial_port_gateway == SERIAL_PORT_GATEWAY:
            http_client = http.Http()
            http_response = http_client.request(HOST_KEY_URL)
            known_hosts = ssh.KnownHosts.FromDefaultFile()
            if http_response[0]['status'] == '200':
                host_key = http_response[1].strip()
                known_hosts.Add(hostname, host_key, overwrite=True)
                known_hosts.Write()
            elif known_hosts.ContainsAlias(hostname):
                # Download failed but a key is already on file; keep using it.
                log.warning(
                    'Unable to download and update Host Key for [{0}] from [{1}]. '
                    'Attempting to connect using existing Host Key in [{2}]. If '
                    'the connection fails, please try again to update the Host '
                    'Key.'.format(SERIAL_PORT_GATEWAY, HOST_KEY_URL,
                                  known_hosts.file_path))
            else:
                # No published key and none on file: fall back to the built-in
                # DEFAULT_HOST_KEY.
                known_hosts.Add(hostname, DEFAULT_HOST_KEY)
                known_hosts.Write()
                log.warning(
                    'Unable to download Host Key for [{0}] from [{1}]. To ensure '
                    'the security of the SSH connection, gcloud will attempt to '
                    'connect using a hard-coded Host Key value. If the connection '
                    'fails, please try again. If the problem persists, try '
                    'updating gcloud and connecting again.'.format(
                        SERIAL_PORT_GATEWAY, HOST_KEY_URL))
        instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [remote.host],
            compute_scope.ScopeEnum.ZONE,
            args.zone,
            holder.resources,
            scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
        instance = ssh_helper.GetInstance(client, instance_ref)
        project = ssh_helper.GetProject(client, instance_ref.project)

        # Determine the serial user, host tuple (remote). The gateway reads
        # the target from the username: PROJECT.ZONE.INSTANCE.USER.port=N,
        # followed by any extra key=value options, all separated by '.'.
        port = 'port={0}'.format(args.port)
        constructed_username_list = [
            instance_ref.project, instance_ref.zone,
            instance_ref.Name(), remote.user, port
        ]
        if args.extra_args:
            for k, v in args.extra_args.items():
                constructed_username_list.append('{0}={1}'.format(k, v))
        serial_user = '.'.join(constructed_username_list)
        serial_remote = ssh.Remote(args.serial_port_gateway, user=serial_user)

        identity_file = ssh_helper.keys.key_file
        options = ssh_helper.GetConfig(hostname,
                                       strict_host_key_checking='yes')
        # Drop HostKeyAlias (presumably so ssh matches the gateway's
        # [host]:port known-hosts entry written above -- confirm). pop() avoids
        # a KeyError if GetConfig ever omits it.
        options.pop('HostKeyAlias', None)
        cmd = ssh.SSHCommand(serial_remote,
                             identity_file=identity_file,
                             port=CONNECTION_PORT,
                             options=options)
        if args.dry_run:
            log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
            return
        ssh_helper.EnsureSSHKeyExists(client, cua_holder.client, remote.user,
                                      instance, project)

        # Don't wait for the instance to become SSHable. We are not connecting to
        # the instance itself through SSH, so the instance doesn't need to have
        # fully booted to connect to the serial port. Also, ignore exit code 255,
        # since the normal way to terminate the serial port connection is ~. and
        # that causes ssh to exit with 255.
        try:
            return_code = cmd.Run(ssh_helper.env, force_connect=True)
        except ssh.CommandError:
            return_code = 255
        if return_code:
            sys.exit(return_code)
Example #19
0
    def RunScp(self,
               compute_holder,
               args,
               port=None,
               recursive=False,
               compress=False,
               extra_flags=None,
               release_track=None,
               ip_type=ip.IpTypeEnum.EXTERNAL):
        """Copy files between the local machine and a remote GCE instance.

        Meant to be invoked from the Run methods of subclasses.

        Args:
          compute_holder: The ComputeApiHolder.
          args: argparse.Namespace, the args the command was invoked with.
          port: str or None, Port number to use for SSH connection.
          recursive: bool, Whether to use recursive copying using -R flag.
          compress: bool, Whether to use compression.
          extra_flags: [str] or None, extra flags to add to command invocation.
          release_track: obj, The current release track; GA when None.
          ip_type: IpTypeEnum, Specify using internal ip or external ip address.

        Raises:
          ssh_utils.NetworkError: Network issue which likely is due to failure
            of SSH key propagation.
          ssh.CommandError: The SSH command exited with SSH exit code, which
            usually implies that a connection problem occurred.
        """
        track = base.ReleaseTrack.GA if release_track is None else release_track
        super(BaseScpHelper, self).Run(args)

        destination = ssh.FileReference.FromPath(args.destination)
        sources = [ssh.FileReference.FromPath(path) for path in args.sources]

        # Only a single remote endpoint is allowed across all arguments.
        ssh.SCPCommand.Verify(sources, destination, single_remote=True)

        if destination.remote:
            remote = destination.remote
        else:
            remote = sources[0].remote
            # Share one Remote object so later user/host updates reach every
            # source reference.
            for source in sources:
                source.remote = remote

        instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [remote.host],
            compute_scope.ScopeEnum.ZONE,
            args.zone,
            compute_holder.resources,
            scope_lister=instance_flags.GetInstanceZoneScopeLister(
                compute_holder.client))[0]
        instance = self.GetInstance(compute_holder.client, instance_ref)
        project = self.GetProject(compute_holder.client, instance_ref.project)

        if not remote.user:
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

        # --plain means: no OS Login lookup, no managed key file, no options.
        use_oslogin = False
        identity_file = None
        options = None
        if not args.plain:
            entry = self.keys.GetPublicKey().ToEntry(include_comment=True)
            remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
                instance, project, remote.user, entry, track)
            identity_file = self.keys.key_file
            options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                                     args.strict_host_key_checking)

        iap_tunnel_args = iap_tunnel.SshTunnelArgs.FromArgs(
            args, track, instance_ref,
            ssh_utils.GetInternalInterface(instance),
            ssh_utils.GetExternalInterface(instance, no_raise=True))

        # Choose what ssh connects to: the alias when tunnelling through IAP,
        # otherwise the internal or external IP depending on ip_type.
        if iap_tunnel_args:
            target_host = ssh_utils.HostKeyAlias(instance)
        elif ip_type is ip.IpTypeEnum.INTERNAL:
            target_host = ssh_utils.GetInternalIPAddress(instance)
        else:
            target_host = ssh_utils.GetExternalIPAddress(instance)
        remote.host = target_host

        scp_command = ssh.SCPCommand(sources,
                                     destination,
                                     identity_file=identity_file,
                                     options=options,
                                     recursive=recursive,
                                     compress=compress,
                                     port=port,
                                     extra_flags=extra_flags,
                                     iap_tunnel_args=iap_tunnel_args)

        if args.dry_run:
            log.out.Print(' '.join(scp_command.Build(self.env)))
            return

        keys_newly_added = False
        if not (args.plain or use_oslogin):
            keys_newly_added = self.EnsureSSHKeyExists(compute_holder.client,
                                                       remote.user, instance,
                                                       project)

        if keys_newly_added:
            poller = ssh_utils.CreateSSHPoller(remote,
                                               identity_file,
                                               options,
                                               iap_tunnel_args,
                                               port=port)

            log.status.Print('Waiting for SSH key to propagate.')
            # TODO(b/35355795): Don't force_connect
            try:
                poller.Poll(self.env, force_connect=True)
            except retry.WaitException:
                raise ssh_utils.NetworkError()

        if ip_type is ip.IpTypeEnum.INTERNAL:
            # Never reached with IAP tunnelling: ip_type is always EXTERNAL in
            # that case, even if the instance has no external IP, and the
            # tunnel addresses the instance unambiguously anyway.
            self.PreliminarilyVerifyInstance(instance.id, remote,
                                             identity_file, options)

        # Errors from the SCP command result in an ssh.CommandError being raised
        scp_command.Run(self.env, force_connect=True)