def testWithSosReportAlreadyInstalled(self):
    """Verifies the remote command sequence when sosreport already exists."""
    # SETUP
    default_user = ssh.GetDefaultSshUsername()
    sosreport_dir = SOSREPORT_INSTALL_PATH
    report_dir = REPORTS_PATH

    # EXECUTE
    self.Run("""
        compute diagnose sosreport {name} --zone={zone}
        """.format(name=INSTANCE.name, zone=INSTANCE.zone))

    # VERIFY
    # Exactly five remote commands are expected when no install is needed.
    self.AssertSSHCallCount(5)
    sosreport_bin = os.path.join(sosreport_dir, "sosreport")
    self.AssertSSHCall(["ls", sosreport_bin])
    self.AssertSSHCall(["mkdir", "-p", report_dir])
    self.AssertSSHCall([
        "sudo", sosreport_bin,
        "--batch",
        "--compression-type", "gzip",
        "--config-file", os.path.join(sosreport_dir, "sos.conf"),
        "--tmp-dir", report_dir,
    ])
    self.AssertSSHCall(["sudo", "chown", default_user,
                        os.path.join(report_dir, "*")])
    self.AssertSSHCall(["ls", "-t", os.path.join(report_dir, "*.tar.gz"),
                        "|", "head", "-n", "1"])
def testWithCustomPath(self):
    """Verifies custom install/report paths are threaded into every command."""
    # SETUP
    default_user = ssh.GetDefaultSshUsername()
    sosreport_dir = "/custom/install/path"
    report_dir = "/custom/report/path"

    # EXECUTE
    self.Run("""
        compute diagnose sosreport {name} --zone={zone}
        --sosreport-install-path="{install_path}"
        --reports-path="{reports_path}"
        """.format(name=INSTANCE.name, zone=INSTANCE.zone,
                   install_path=sosreport_dir, reports_path=report_dir))

    # VERIFY
    # Same happy-path sequence as the default-path test: five calls.
    self.AssertSSHCallCount(5)
    sosreport_bin = os.path.join(sosreport_dir, "sosreport")
    self.AssertSSHCall(["ls", sosreport_bin])
    self.AssertSSHCall(["mkdir", "-p", report_dir])
    self.AssertSSHCall([
        "sudo", sosreport_bin,
        "--batch",
        "--compression-type", "gzip",
        "--config-file", os.path.join(sosreport_dir, "sos.conf"),
        "--tmp-dir", report_dir,
    ])
    self.AssertSSHCall(["sudo", "chown", default_user,
                        os.path.join(report_dir, "*")])
    self.AssertSSHCall(["ls", "-t", os.path.join(report_dir, "*.tar.gz"),
                        "|", "head", "-n", "1"])
def GetUserAndInstance(user_host):
    """Returns a (user, instance) pair parsed from a [USER@]INSTANCE argument.

    Args:
      user_host: str, an argument of the form [USER@]INSTANCE.

    Returns:
      A (user, instance) tuple. When no user is given, the default SSH
      username is used.

    Raises:
      exceptions.ToolException: If the argument contains more than one '@'.
    """
    parts = user_host.split('@')
    if len(parts) == 1:
        # No user supplied; fall back to the default SSH username.
        user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
        instance = parts[0]
        return user, instance
    if len(parts) == 2:
        # Fix: previously returned the raw list from split(); return a tuple
        # for consistency with the single-part branch above.
        user, instance = parts
        return user, instance
    raise exceptions.ToolException(
        'Expected argument of the form [USER@]INSTANCE; received [{0}].'.
        format(user_host))
def GetUserAndInstance(user_host, use_account_service, http):
    """Returns a (user, instance) pair parsed from a [USER@]INSTANCE argument.

    Args:
      user_host: str, an argument of the form [USER@]INSTANCE.
      use_account_service: bool, True to resolve the default user through the
        Account Service (gaia) instead of the local SSH username.
      http: An http client used by the Account Service lookup.

    Returns:
      A (user, instance) tuple. When no user is given, the default user is
      resolved via the account service or the local SSH username.

    Raises:
      exceptions.ToolException: If the argument contains more than one '@'.
    """
    parts = user_host.split('@')
    if len(parts) == 1:
        if use_account_service:  # Using Account Service.
            user = gaia.GetDefaultAccountName(http)
        else:  # Uploading keys through metadata.
            user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
        instance = parts[0]
        return user, instance
    if len(parts) == 2:
        # Fix: previously returned the raw list from split(); return a tuple
        # for consistency with the single-part branch above.
        user, instance = parts
        return user, instance
    raise exceptions.ToolException(
        'Expected argument of the form [USER@]INSTANCE; received [{0}].'.
        format(user_host))
def Run(self, args):
    """Default run method implementation.

    Drives the whole sosreport flow: locate (or install) sosreport on the
    VM, run it, report the generated filename and optionally download it.
    """
    super(SosReport, self).Run(args)
    # Keys are uploaded through instance metadata, not the accounts service.
    self._use_accounts_service = False

    # Obtain the gcloud variables
    holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
    instance = self.GetInstance(holder, args)
    # Fall back to the default local SSH username when --user is not given.
    user = args.user if args.user else ssh.GetDefaultSshUsername()
    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)

    # Create the context variables shared by all soshelper calls.
    context = {
        "args": args,
        "instance": instance,
        "ssh_helper": ssh_helper,
        "user": user,
        "python_path": args.python_path,
    }

    install_path = args.sosreport_install_path
    reports_path = args.reports_path

    # We download Sosreport into the VM if needed (normally first time).
    soshelper.ObtainSosreport(context, install_path)

    # (If needed) We create the directory where the reports will be created.
    log.out.Print("Creating the path where reports will be written if needed.")
    soshelper.CreatePath(context, reports_path)

    # We run the report.
    soshelper.RunSosreport(context, install_path, reports_path)

    # Obtain and report the filename of the generated report.
    report_path = soshelper.ObtainReportFilename(context, reports_path)
    msg = 'Report generated into "{report_path}".'
    log.status.Print(msg.format(report_path=report_path))

    # If download_dir is set, we download the report over.
    if args.download_dir:
        report_path = soshelper.CopyReportFile(context, args.download_dir,
                                               report_path)
        msg = 'Successfully downloaded report to "{report_path}"'
        log.status.Print(msg.format(report_path=report_path))
def SSHToInstance(self, args, instance):
    """Helper to manage authentication followed by SSH to the instance.

    Resolves the instance's external NAT IP, waits for the SSH key to
    propagate, then runs an interactive SSH command (optionally with port
    forwarding). Exits the process with the remote command's return code
    on failure.
    """
    args = self._DefaultArgsForSSH(args)

    external_nat = ssh_utils.GetExternalIPAddress(instance)
    log.status.Print(
        'Trying to SSH to VM with NAT IP:{}'.format(external_nat))
    remote = ssh.Remote(external_nat, ssh.GetDefaultSshUsername())
    args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE

    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)
    identity_file = ssh_helper.keys.key_file

    user, _ = ssh_utils.GetUserAndInstance(args.name)
    host_keys = self._GetHostKeyFromInstance(args.zone, ssh_helper, instance)
    options = self._GetSSHOptions(args.name, ssh_helper, instance, host_keys)
    # Block until a connection succeeds so the real SSH below doesn't fail
    # on a key that hasn't reached the VM yet.
    self._WaitForSSHKeysToPropagate(ssh_helper, remote, identity_file, user,
                                    instance, options)

    extra_flags = []
    # Ctpu seems to be forwarding some other ports on what
    # seems like the TPU node. Need to understand better before enabling.
    if args.forward_ports:
        # -A: agent forwarding; -L: forward TensorBoard (6006) and
        # Jupyter (8888) ports to the local machine.
        extra_flags.extend(
            ['-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888'])
    ssh_cmd_args = {
        'remote': remote,
        'identity_file': identity_file,
        'options': options,
        'extra_flags': extra_flags
    }

    cmd = ssh.SSHCommand(**ssh_cmd_args)

    # Errors from SSH itself result in an ssh.CommandError being raised
    return_code = cmd.Run(
        ssh_helper.env,
        force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
    if return_code:
        # This is the return code of the remote command. Problems with SSH
        # itself will result in ssh.CommandError being raised above.
        sys.exit(return_code)
def testCustomPythonPath(self):
    """Verifies that --python-path is forwarded to the sosreport invocation."""
    # SETUP
    default_user = ssh.GetDefaultSshUsername()
    sosreport_dir = SOSREPORT_INSTALL_PATH
    report_dir = REPORTS_PATH
    python_binary = PYTHON_PATH
    # We need the first command to fail
    self.mock_ssh_command.side_effect = [1, 0, 0, 0, 0, 0, 0]  # Error

    # EXECUTE
    self.Run("""
        compute diagnose sosreport {name} --zone={zone}
        --python-path={python_path}
        """.format(name=INSTANCE.name, zone=INSTANCE.zone,
                   python_path=python_binary))

    # VERIFY
    # Seven calls: the failed existence check triggers the install commands.
    self.AssertSSHCallCount(7)
    sosreport_bin = os.path.join(sosreport_dir, "sosreport")
    self.AssertSSHCall(["ls", sosreport_bin])
    self.AssertSSHCall(["mkdir", "-p", sosreport_dir])
    self.AssertSSHCall(["git", "clone",
                        "https://github.com/sosreport/sos.git",
                        sosreport_dir])
    self.AssertSSHCall(["mkdir", "-p", report_dir])
    self.AssertSSHCall([
        "sudo", python_binary, sosreport_bin,
        "--batch",
        "--compression-type", "gzip",
        "--config-file", os.path.join(sosreport_dir, "sos.conf"),
        "--tmp-dir", report_dir,
    ])
    self.AssertSSHCall(["sudo", "chown", default_user,
                        os.path.join(report_dir, "*")])
    self.AssertSSHCall(["ls", "-t", os.path.join(report_dir, "*.tar.gz"),
                        "|", "head", "-n", "1"])
def testSSHToInstance(self):
    """Checks the SSH flow: key setup, propagation poll, and SSH invocation."""
    ssh_util = tpu_utils.SSH(self.track)
    # NOTE(review): Args appears to be (name, zone, forward_ports) — the
    # trailing True matches SSHToInstance's args.forward_ports; confirm
    # against the Args helper definition.
    args = SSHTest.Args('fake-name', 'central2-a', True)
    instance = self._makeFakeInstance('fake-instance')
    # Three API round-trips each resolve the fake project resource.
    self.make_requests.side_effect = iter([
        [self._makeFakeProjectResource()],
        [self._makeFakeProjectResource()],
        [self._makeFakeProjectResource()],
    ])
    ssh_util.SSHToInstance(args, instance)
    self.ensure_keys.assert_called_once_with(self.keys, None,
                                             allow_passphrase=True)
    # Key propagation is polled exactly once before connecting.
    self.poller_poll.assert_called_once_with(
        mock_matchers.TypeMatcher(ssh.SSHPoller), self.env,
        force_connect=True)
    # SSH Command
    self.ssh_init.assert_called_once_with(
        mock_matchers.TypeMatcher(ssh.SSHCommand),
        remote=ssh.Remote('23.251.133.75', ssh.GetDefaultSshUsername()),
        identity_file=self.private_key_file,
        # forward_ports=True adds agent forwarding plus the TensorBoard
        # (6006) and Jupyter (8888) tunnels.
        extra_flags=[
            '-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888'
        ],
        options=dict(
            self.options,
            HostKeyAlias='compute.1111',
            SendEnv='TPU_NAME',
        ))
    self.ssh_run.assert_called_once_with(
        mock_matchers.TypeMatcher(ssh.SSHCommand), self.env,
        force_connect=True)
def testDryRun(self):
    """Checks that --dry-run records the SSH commands instead of running them."""
    # SETUP
    default_user = ssh.GetDefaultSshUsername()
    sosreport_dir = SOSREPORT_INSTALL_PATH
    report_dir = REPORTS_PATH
    # We need the first command to fail
    self.mock_ssh_command.side_effect = [1, 0, 0, 0, 0, 0, 0]  # Error

    # EXECUTE
    self.Run("""
        compute diagnose sosreport {name} --zone={zone} --dry-run
        """.format(name=INSTANCE.name, zone=INSTANCE.zone))

    # Dry run does not call SSH
    self.AssertSSHCallCount(6)
    sosreport_bin = os.path.join(sosreport_dir, "sosreport")
    self.AssertSSHCall(["ls", sosreport_bin], is_dry_run=True)
    self.AssertSSHCall(["mkdir", "-p", sosreport_dir], is_dry_run=True)
    self.AssertSSHCall(["git", "clone",
                        "https://github.com/sosreport/sos.git",
                        sosreport_dir],
                       is_dry_run=True)
    self.AssertSSHCall(["mkdir", "-p", report_dir], is_dry_run=True)
    self.AssertSSHCall([
        "sudo", sosreport_bin,
        "--batch",
        "--compression-type", "gzip",
        "--config-file", os.path.join(sosreport_dir, "sos.conf"),
        "--tmp-dir", report_dir,
    ], is_dry_run=True)
    self.AssertSSHCall(["sudo", "chown", default_user,
                        os.path.join(report_dir, "*")],
                       is_dry_run=True)
def Run(self, args):
    """See ssh_utils.BaseSSHCommand.Run.

    Rewrites (or removes, with --remove) the gcloud-managed section of the
    user's SSH config file so that running instances can be reached by
    alias. Refuses nothing, but warns when the config file's permissions
    are looser than `man 5 ssh_config` requires.
    """
    holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
    client = holder.client

    ssh_helper = ssh_utils.BaseSSHHelper()
    ssh_helper.Run(args)
    ssh_helper.keys.EnsureKeysExist(args.force_key_file_overwrite,
                                    allow_passphrase=True)

    ssh_config_file = files.ExpandHomeDir(
        args.ssh_config_file or ssh.PER_USER_SSH_CONFIG_FILE)

    instances = None
    try:
        existing_content = files.ReadFileContents(ssh_config_file)
    except files.Error as e:
        existing_content = ''
        log.debug('SSH Config File [{0}] could not be opened: {1}'.format(
            ssh_config_file, e))

    if args.remove:
        compute_section = ''
        try:
            new_content = _RemoveComputeSection(existing_content)
        except MultipleComputeSectionsError:
            # Re-raise with the file path so the user knows which file is bad.
            raise MultipleComputeSectionsError(ssh_config_file)
    else:
        ssh_helper.EnsureSSHKeyIsInProject(
            client, ssh.GetDefaultSshUsername(warn_on_account_user=True),
            None)
        instances = list(self.GetRunningInstances(client))
        if instances:
            compute_section = _BuildComputeSection(
                instances, ssh_helper.keys.key_file,
                ssh.KnownHosts.DEFAULT_PATH)
        else:
            compute_section = ''

    if existing_content and not args.remove:
        try:
            new_content = _MergeComputeSections(existing_content,
                                                compute_section)
        except MultipleComputeSectionsError:
            raise MultipleComputeSectionsError(ssh_config_file)
    elif not existing_content:
        new_content = compute_section

    if args.dry_run:
        log.out.write(new_content or '')
        return

    if new_content != existing_content:
        if (os.path.exists(ssh_config_file) and
                platforms.OperatingSystem.Current() is not
                platforms.OperatingSystem.WINDOWS):
            ssh_config_perms = os.stat(ssh_config_file).st_mode
            # From `man 5 ssh_config`:
            # this file must have strict permissions: read/write for the user,
            # and not accessible by others.
            # We check that here:
            if not (ssh_config_perms & stat.S_IRWXU ==
                    stat.S_IWUSR | stat.S_IRUSR and
                    ssh_config_perms & stat.S_IWGRP == 0 and
                    ssh_config_perms & stat.S_IWOTH == 0):
                # Bug fix: the message contained a {0} placeholder but
                # .format() was never called, so the literal '{0}' was
                # printed instead of the config file path.
                log.warning(
                    'Invalid permissions on [{0}]. Please change to match ssh '
                    'requirements (see man 5 ssh).'.format(ssh_config_file))
        # TODO(b/36050483): This write will not work very well if there is
        # a lot of write contention for the SSH config file. We should
        # add a function to do a better job at "atomic file writes".
        files.WriteFileContents(ssh_config_file, new_content, private=True)

    if compute_section:
        log.out.write(
            textwrap.dedent("""\
                You should now be able to use ssh/scp with your instances.
                For example, try running:

                  $ ssh {alias}
                """.format(alias=_CreateAlias(instances[0]))))
    elif not instances and not args.remove:
        log.warning(
            'No host aliases were added to your SSH configs because you do '
            'not '
            'have any running instances. Try running this command again after '
            'running some instances.')
def Run(self, args, port=None, recursive=False, extra_flags=None):
    """SCP files between local and remote GCE instance.

    Run this method from subclasses' Run methods.

    Args:
      args: argparse.Namespace, the args the command was invoked with.
      port: str, int or None, Port number to use for SSH connection.
      recursive: bool, Whether to use recursive copying using -R flag.
      extra_flags: [str] or None, extra flags to add to command invocation.

    Raises:
      ssh_utils.NetworkError: Network issue which likely is due to failure
        of SSH key propagation.
      ssh.CommandError: The SSH command exited with SSH exit code, which
        usually implies that a connection problem occurred.
    """
    super(BaseScpCommand, self).Run(args)
    dst = ssh.FileReference.FromPath(args.destination)
    srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

    # Make sure we have a unique remote
    ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

    remote = dst.remote or srcs[0].remote
    if not dst.remote:  # Make sure all remotes point to the same ref
        for src in srcs:
            src.remote = remote

    instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [remote.host],
        compute_scope.ScopeEnum.ZONE,
        args.zone,
        self.resources,
        scope_lister=flags.GetDefaultScopeLister(self.compute_client))[0]
    instance = self.GetInstance(instance_ref)

    # Now replace the instance name with the actual IP/hostname
    remote.host = ssh_utils.GetExternalIPAddress(instance)
    if not remote.user:
        remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

    # --plain skips key-based config entirely (no identity file, no options).
    identity_file = None
    options = None
    if not args.plain:
        identity_file = self.keys.key_file
        options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                                 args.strict_host_key_checking)

    cmd = ssh.SCPCommand(srcs, dst, identity_file=identity_file,
                         options=options, recursive=recursive, port=port,
                         extra_flags=extra_flags)

    if args.dry_run:
        # Only print the command that would run; do not execute it.
        log.out.Print(' '.join(cmd.Build(self.env)))
        return

    if args.plain:
        keys_newly_added = False
    else:
        keys_newly_added = self.EnsureSSHKeyExists(
            remote.user, instance, instance_ref.project,
            use_account_service=self._use_account_service)

    if keys_newly_added:
        # A freshly added key takes time to reach the VM; poll until an SSH
        # connection succeeds or the propagation timeout elapses.
        poller = ssh.SSHPoller(
            remote, identity_file=identity_file, options=options,
            max_wait_ms=ssh_utils.SSH_KEY_PROPAGATION_TIMEOUT_SEC)
        log.status.Print('Waiting for SSH key to propagate.')
        # TODO(b/35355795): Don't force_connect
        try:
            poller.Poll(self.env, force_connect=True)
        except retry.WaitException:
            raise ssh_utils.NetworkError()

    return_code = cmd.Run(self.env, force_connect=True)
    if return_code:
        # Can't raise an exception because we don't want any "ERROR" message
        # printed; the output from `ssh` will be enough.
        sys.exit(return_code)
def Run(self, args):
    """Default run method implementation.

    Traceroutes each selected instance and, with --reverse-traceroute, also
    runs a traceroute from the VM back to this machine's external IP.
    Failures on one instance are logged and do not stop the others.
    """
    super(Routes, self).Run(args)
    # Keys are uploaded through instance metadata, not the accounts service.
    self._use_accounts_service = False

    holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
    resource_registry = holder.resources
    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)

    # We store always needed commands non-changing fields
    self._args = args
    self._ssh_helper = ssh_helper

    # We obtain generic parameters of the call
    project = properties.VALUES.core.project.GetOrFail()
    filters = _RoutesArgs.GetFilters(args)
    instances = _RoutesQueries.ObtainInstances(
        args.names,
        service=self.compute.instances,
        project=project,
        zones=args.zones,
        filters=filters,
        http=self.http,
        batch_url=self.batch_url)

    user = args.user
    if not user:
        user = ssh.GetDefaultSshUsername()

    # We unpack the flags
    dry_run = args.dry_run
    reverse_traceroute = args.reverse_traceroute
    traceroute_args = args.traceroute_args
    external_route_ip = args.external_route_ip

    internal_helpers.PrintHeader(instances)
    prompt = 'The following VMs will be tracerouted.'
    if instances and not dry_run and not console_io.PromptContinue(prompt):
        return
    # Sometimes the prompt would appear after the instance data
    log.out.flush()

    for instance in instances:
        header = 'Checking instance %s' % instance.name
        log.out.Print(header)
        log.out.Print('-' * len(header))

        try:
            self.TracerouteInstance(instance, traceroute_args, dry_run,
                                    resource_registry)
        except exceptions.ToolException as e:
            # Keep going: one unreachable instance shouldn't stop the rest.
            log.error('Error routing to instance')
            log.error(str(e))
            continue

        if reverse_traceroute:
            try:
                # Reverse traceroute needs the tool installed on the VM.
                has_traceroute = self.CheckTraceroute(
                    instance, user, dry_run, resource_registry)
                if has_traceroute:
                    # We obtain the self ip
                    if not external_route_ip:
                        external_route_ip = self.ObtainSelfIp(
                            instance, user, dry_run, resource_registry)
                    if external_route_ip:
                        self.ReverseTracerouteInstance(
                            instance, user, external_route_ip,
                            traceroute_args, dry_run, resource_registry)
                    else:
                        log.out.Print('Unable to obtain self ip. Aborting.')
                else:
                    log.out.Print(
                        'Please make sure traceroute is installed in PATH to move on.'
                    )
            except ssh.CommandError as e:
                log.error(str(e))

        log.out.Print('')  # Separator
def Run(self, args):
    """Connect to a running flex instance.

    Args:
      args: argparse.Namespace, the args the command was invoked with.

    Raises:
      InvalidInstanceTypeError: The instance is not supported for SSH.
      MissingVersionError: The version specified does not exist.
      MissingInstanceError: The instance specified does not exist.
      UnattendedPromptError: Not running in a tty.
      OperationCancelledError: User cancelled the operation.
      ssh.CommandError: The SSH command exited with SSH exit code, which
        usually implies that a connection problem occurred.

    Returns:
      int, The exit code of the SSH command.
    """
    api_client = appengine_api_client.GetApiClient()
    env = ssh.Environment.Current()
    env.RequireSSH()
    keys = ssh.Keys.FromFilename()
    keys.EnsureKeysExist(overwrite=False)

    try:
        version = api_client.GetVersionResource(service=args.service,
                                                version=args.version)
    except api_exceptions.NotFoundError:
        raise command_exceptions.MissingVersionError('{}/{}'.format(
            args.service, args.version))
    version = version_util.Version.FromVersionResource(version, None)

    if version.environment is not util.Environment.FLEX:
        # Only Flexible environment instances support this SSH flow.
        if version.environment is util.Environment.MANAGED_VMS:
            environment = 'Managed VMs'
            msg = 'Use `gcloud compute ssh` for Managed VMs instances.'
        else:
            environment = 'Standard'
            msg = None
        raise command_exceptions.InvalidInstanceTypeError(environment, msg)

    res = resources.REGISTRY.Parse(
        args.instance,
        params={
            'appsId': properties.VALUES.core.project.GetOrFail,
            'versionsId': args.version,
            'instancesId': args.instance,
            'servicesId': args.service,
        },
        collection='appengine.apps.services.versions.instances')
    rel_name = res.RelativeName()
    try:
        instance = api_client.GetInstanceResource(res)
    except api_exceptions.NotFoundError:
        raise command_exceptions.MissingInstanceError(rel_name)

    if not instance.vmDebugEnabled:
        # Debug mode must be enabled before SSH; confirm with the user first.
        log.warn(ENABLE_DEBUG_WARNING)
        console_io.PromptContinue(cancel_on_no=True, throw_if_unattended=True)

    user = ssh.GetDefaultSshUsername()
    remote = ssh.Remote(instance.vmIp, user=user)
    public_key = keys.GetPublicKey().ToEntry()
    ssh_key = '{user}:{key} {user}'.format(user=user, key=public_key)
    log.status.Print('Sending public key to instance [{}].'.format(rel_name))
    # DebugInstance both enables debug mode and installs the key.
    api_client.DebugInstance(res, ssh_key)
    options = {
        'IdentitiesOnly': 'yes',  # No ssh-agent as of yet
        'UserKnownHostsFile': ssh.KnownHosts.DEFAULT_PATH,
        'CheckHostIP': 'no',
        'HostKeyAlias': HOST_KEY_ALIAS.format(project=api_client.project,
                                              instance_id=args.instance)
    }
    cmd = ssh.SSHCommand(remote, identity_file=keys.key_file, options=options)
    if args.container:
        # Drop into a shell inside the requested container.
        cmd.tty = True
        cmd.remote_command = ['container_exec', args.container, '/bin/sh']
    return cmd.Run(env)
def RunScp(self,
           compute_holder,
           args,
           port=None,
           recursive=False,
           compress=False,
           extra_flags=None,
           release_track=None,
           ip_type=ip.IpTypeEnum.EXTERNAL):
    """SCP files between local and remote GCE instance.

    Run this method from subclasses' Run methods.

    Args:
      compute_holder: The ComputeApiHolder.
      args: argparse.Namespace, the args the command was invoked with.
      port: str or None, Port number to use for SSH connection.
      recursive: bool, Whether to use recursive copying using -R flag.
      compress: bool, Whether to use compression.
      extra_flags: [str] or None, extra flags to add to command invocation.
      release_track: obj, The current release track.
      ip_type: IpTypeEnum, Specify using internal ip or external ip address.

    Raises:
      ssh_utils.NetworkError: Network issue which likely is due to failure
        of SSH key propagation.
      ssh.CommandError: The SSH command exited with SSH exit code, which
        usually implies that a connection problem occurred.
    """
    if release_track is None:
        release_track = base.ReleaseTrack.GA
    super(BaseScpHelper, self).Run(args)
    dst = ssh.FileReference.FromPath(args.destination)
    srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

    # Make sure we have a unique remote
    ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

    remote = dst.remote or srcs[0].remote
    if not dst.remote:  # Make sure all remotes point to the same ref
        for src in srcs:
            src.remote = remote

    instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [remote.host],
        compute_scope.ScopeEnum.ZONE,
        args.zone,
        compute_holder.resources,
        scope_lister=instance_flags.GetInstanceZoneScopeLister(
            compute_holder.client))[0]
    instance = self.GetInstance(compute_holder.client, instance_ref)
    project = self.GetProject(compute_holder.client, instance_ref.project)

    # Now replace the instance name with the actual IP/hostname
    if ip_type is ip.IpTypeEnum.INTERNAL:
        remote.host = ssh_utils.GetInternalIPAddress(instance)
    else:
        remote.host = ssh_utils.GetExternalIPAddress(instance)

    if not remote.user:
        remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
    if args.plain:
        use_oslogin = False
    else:
        public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
        # OS Login may override which remote username must be used.
        remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
            instance, project, remote.user, public_key, release_track)

    identity_file = None
    options = None
    if not args.plain:
        identity_file = self.keys.key_file
        options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                                 args.strict_host_key_checking)

    tunnel_helper = None
    cmd_port = port
    interface = None
    if hasattr(args, 'tunnel_through_iap') and args.tunnel_through_iap:
        # With IAP, SCP talks to a local listener that tunnels to the VM,
        # so the port and hosts are rewritten to point at localhost.
        tunnel_helper, interface = ssh_utils.CreateIapTunnelHelper(
            args, instance_ref, instance, ip_type, port=port)
        tunnel_helper.StartListener()
        cmd_port = str(tunnel_helper.GetLocalPort())
        if dst.remote:
            dst.remote.host = 'localhost'
        else:
            for src in srcs:
                src.remote.host = 'localhost'

    cmd = ssh.SCPCommand(
        srcs,
        dst,
        identity_file=identity_file,
        options=options,
        recursive=recursive,
        compress=compress,
        port=cmd_port,
        extra_flags=extra_flags)

    if args.dry_run:
        log.out.Print(' '.join(cmd.Build(self.env)))
        if tunnel_helper:
            tunnel_helper.StopListener()
        return

    if args.plain or use_oslogin:
        keys_newly_added = False
    else:
        keys_newly_added = self.EnsureSSHKeyExists(compute_holder.client,
                                                   remote.user, instance,
                                                   project)

    if keys_newly_added:
        # A freshly added key takes time to reach the VM; poll (through a
        # dedicated tunnel when using IAP) until SSH succeeds or times out.
        poller_tunnel_helper = None
        if tunnel_helper:
            poller_tunnel_helper, _ = ssh_utils.CreateIapTunnelHelper(
                args, instance_ref, instance, ip_type, port=port,
                interface=interface)
            poller_tunnel_helper.StartListener(
                accept_multiple_connections=True)
        poller = ssh_utils.CreateSSHPoller(
            remote, identity_file, options, poller_tunnel_helper, port=port)

        log.status.Print('Waiting for SSH key to propagate.')
        # TODO(b/35355795): Don't force_connect
        try:
            poller.Poll(self.env, force_connect=True)
        except retry.WaitException:
            if tunnel_helper:
                tunnel_helper.StopListener()
            raise ssh_utils.NetworkError()
        finally:
            if poller_tunnel_helper:
                poller_tunnel_helper.StopListener()

    if ip_type is ip.IpTypeEnum.INTERNAL and not tunnel_helper:
        # The IAP Tunnel connection uses instance name and network interface
        # name, so do not need to additionally verify the instance. Also, the
        # SSHCommand used within the function does not support IAP Tunnels.
        self.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                         options)

    try:
        # Errors from the SCP command result in an ssh.CommandError being
        # raised
        cmd.Run(self.env, force_connect=True)
    finally:
        if tunnel_helper:
            tunnel_helper.StopListener()
def testWithCopyOver(self):
    """Checks that --download-dir scp-copies the generated report locally."""
    # SETUP
    user = ssh.GetDefaultSshUsername()
    install_path = SOSREPORT_INSTALL_PATH
    reports_path = REPORTS_PATH
    report_filepath = "/path/to/sosreport/file"
    download_dir = "/local/dir"

    # Create a side effect generator that will return (and run the
    # expected side effect!) for each moment
    ssh_mock = MockSSHCalls()
    ssh_mock.SetReturnValue(report_filepath)
    # Mutable closure state: counts how many SSH calls have happened.
    context = {"call_times": 0}

    def SideEffectGenerator(*args, **kwargs):
        # Call 0 fails (sosreport not installed), calls 1-3 succeed
        # (install + mkdir), later calls delegate to the mock so the
        # filename lookup returns report_filepath.
        call_times = context["call_times"]
        context["call_times"] += 1
        if call_times == 0:
            return 1
        if call_times in [1, 2, 3]:
            return 0
        return ssh_mock(*args, **kwargs)

    self.mock_ssh_command.side_effect = SideEffectGenerator

    # EXECUTE
    self.Run("""
        compute diagnose sosreport {name} --zone={zone}
        --download-dir="{download_dir}"
        """.format(name=INSTANCE.name, zone=INSTANCE.zone,
                   download_dir=download_dir))

    # VERIFY
    # We assert on the amount of calls
    self.AssertSSHCallCount(7)
    self.AssertSSHCall(["ls", os.path.join(install_path, "sosreport")])
    self.AssertSSHCall(["mkdir", "-p", install_path])
    self.AssertSSHCall([
        "git", "clone", "https://github.com/sosreport/sos.git", install_path
    ])
    self.AssertSSHCall(["mkdir", "-p", reports_path])
    self.AssertSSHCall([
        "sudo",
        os.path.join(install_path, "sosreport"), "--batch",
        "--compression-type", "gzip", "--config-file",
        os.path.join(install_path, "sos.conf"), "--tmp-dir", reports_path
    ])
    self.AssertSSHCall(
        ["sudo", "chown", user, os.path.join(reports_path, "*")])
    self.AssertSSHCall([
        "ls", "-t",
        os.path.join(reports_path, "*.tar.gz"), "|", "head", "-n", "1"
    ])
    # The report is copied down via `gcloud compute scp`; the local name is
    # the basename of the remote report path ("file").
    self.AssertSubProcessCalls([
        "gcloud", "compute", "scp", "--zone", INSTANCE.zone,
        INSTANCE.name + ":" + report_filepath,
        os.path.join(download_dir, "file")
    ])
def PopulatePublicKey(api_client, service_id, version_id, instance_id,
                      public_key, release_track):
    """Enable debug mode on and send SSH keys to a flex instance.

    Common method for SSH-like commands, does the following:
    - Makes sure that the service/version/instance specified exists and is
      of the right type (Flexible).
    - If not already done, prompts and enables debug on the instance.
    - Populates the public key onto the instance.

    Args:
      api_client: An appengine_api_client.AppEngineApiClient.
      service_id: str, The service ID.
      version_id: str, The version ID.
      instance_id: str, The instance ID.
      public_key: ssh.Keys.PublicKey, Public key to send.
      release_track: calliope.base.ReleaseTrack, The current release track.

    Raises:
      InvalidInstanceTypeError: The instance is not supported for SSH.
      MissingVersionError: The version specified does not exist.
      MissingInstanceError: The instance specified does not exist.
      UnattendedPromptError: Not running in a tty.
      OperationCancelledError: User cancelled the operation.

    Returns:
      ConnectionDetails, the details to use for SSH/SCP for the SSH
      connection.
    """
    try:
        version = api_client.GetVersionResource(service=service_id,
                                                version=version_id)
    except apitools_exceptions.HttpNotFoundError:
        raise command_exceptions.MissingVersionError('{}/{}'.format(
            service_id, version_id))
    version = version_util.Version.FromVersionResource(version, None)

    if version.environment is not env.FLEX:
        # Only Flexible environment instances support this key-push flow.
        if version.environment is env.MANAGED_VMS:
            environment = 'Managed VMs'
            msg = 'Use `gcloud compute ssh` for Managed VMs instances.'
        else:
            environment = 'Standard'
            msg = None
        raise command_exceptions.InvalidInstanceTypeError(environment, msg)

    res = resources.REGISTRY.Parse(
        instance_id,
        params={
            'appsId': properties.VALUES.core.project.GetOrFail,
            'versionsId': version_id,
            'instancesId': instance_id,
            'servicesId': service_id,
        },
        collection='appengine.apps.services.versions.instances')
    rel_name = res.RelativeName()
    try:
        instance = api_client.GetInstanceResource(res)
    except apitools_exceptions.HttpNotFoundError:
        raise command_exceptions.MissingInstanceError(rel_name)

    if not instance.vmDebugEnabled:
        # Debug mode is required; confirm with the user before enabling.
        log.warning(_ENABLE_DEBUG_WARNING)
        console_io.PromptContinue(cancel_on_no=True, throw_if_unattended=True)

    user = ssh.GetDefaultSshUsername()
    project = _GetComputeProject(release_track)
    # OS Login (when enabled on the project) overrides the username and
    # makes the explicit key push below unnecessary.
    user, use_oslogin = ssh.CheckForOsloginAndGetUser(None, project, user,
                                                      public_key.ToEntry(),
                                                      release_track)
    remote = ssh.Remote(instance.vmIp, user=user)

    if not use_oslogin:
        ssh_key = '{user}:{key} {user}'.format(user=user,
                                               key=public_key.ToEntry())
        log.status.Print(
            'Sending public key to instance [{}].'.format(rel_name))
        api_client.DebugInstance(res, ssh_key)
    options = {
        'IdentitiesOnly': 'yes',  # No ssh-agent as of yet
        'UserKnownHostsFile': ssh.KnownHosts.DEFAULT_PATH,
        'CheckHostIP': 'no',
        'HostKeyAlias': _HOST_KEY_ALIAS.format(project=api_client.project,
                                               instance_id=instance_id)
    }
    return ConnectionDetails(remote, options)
def Run(self, args):
    """Copies files to/from Cloud TPU VM workers over SCP.

    Resolves the TPU node and its worker IPs, ensures the SSH key is
    available (directly or via OS Login), then runs one SCP command per
    targeted worker — concurrently when multiple workers are involved.
    """
    dst = ssh.FileReference.FromPath(args.destination)
    srcs = [ssh.FileReference.FromPath(src) for src in args.sources]
    ssh.SCPCommand.Verify(srcs, dst, single_remote=True)
    # The remote side (either dst or the srcs) names the TPU node.
    if dst.remote:
        tpu_name = dst.remote.host
    else:
        tpu_name = srcs[0].remote.host

    # If zone is not set, retrieve the one from the config.
    if args.zone is None:
        args.zone = properties.VALUES.compute.zone.Get(required=True)

    # Retrieve the node.
    tpu = tpu_utils.TPUNode(self.ReleaseTrack())
    node = tpu.Get(tpu_name, args.zone)
    if not tpu_utils.IsTPUVMNode(node):
        raise exceptions.BadArgumentException(
            'TPU',
            'this command is only available for Cloud TPU VM nodes. To access '
            'this node, please see '
            'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

    worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                               node.networkEndpoints,
                                               args.internal_ip)

    # Copying VM->local from several workers would clobber the same file.
    if len(worker_ips) > 1 and srcs[0].remote:
        raise exceptions.InvalidArgumentException(
            '--worker',
            'cannot target multiple workers while copying files to '
            'client.')

    tpu_ssh_utils.ValidateTPUState(node.state,
                                   tpu.messages.Node.StateValueValuesEnum)

    # Retrieve GuestAttributes.
    single_pod_worker = (len(node.networkEndpoints) > 1 and
                         len(worker_ips) == 1)
    if single_pod_worker:
        # Retrieve only that worker's GuestAttributes.
        worker_id = list(worker_ips)[0]
        guest_attributes_response = tpu.GetGuestAttributes(
            tpu_name, args.zone, six.text_type((worker_id)))
        host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
            guest_attributes_response.guestAttributes,
            len(node.networkEndpoints), worker_id)
    else:
        # Retrieve the GuestAttributes for all workers in that TPU.
        guest_attributes_response = tpu.GetGuestAttributes(tpu_name,
                                                           args.zone)
        host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
            guest_attributes_response.guestAttributes)

    # Generate the public key.
    ssh_helper = ssh_utils.BaseSSHCLIHelper()
    ssh_helper.Run(args)
    public_key = ssh_helper.keys.GetPublicKey().ToEntry()

    remote = dst.remote or srcs[0].remote
    if not dst.remote:  # Make sure all remotes point to the same ref.
        for src in srcs:
            src.remote = remote

    if remote.user:
        username_requested = True
    else:
        username_requested = False
        remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

    project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

    if not args.plain:
        # If there is an '@' symbol in the user_host arg, the user is
        # requesting to connect as a specific user. This may get overridden
        # by OS Login.
        _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
        oslogin_state = ssh.GetOsloginState(
            None,
            project,
            remote.user,
            public_key,
            expiration_micros,
            self.ReleaseTrack(),
            username_requested=username_requested,
            instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(node))
        remote.user = oslogin_state.user

    # Format the key correctly.
    public_key = '{1}:{0} {1}'.format(public_key, remote.user)
    if not args.plain and not args.dry_run:
        tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name,
                                        args.zone, public_key)

    identity_file = None
    if not args.plain:
        identity_file = ssh_helper.keys.key_file
        # If the user's key is not in the SSH agent, the command will stall.
        # We want to verify it is added before proceeding, and raise an
        # error if it is not.
        if not args.dry_run and len(worker_ips) > 1:
            tpu_ssh_utils.VerifyKeyInAgent(identity_file)

    extra_flags = []
    if args.scp_flag:
        extra_flags.extend(args.scp_flag)

    instance_names = {}
    if (args.IsKnownAndSpecified('tunnel_through_iap') and
            args.tunnel_through_iap):
        # Retrieve the instance names from the GuestAttributes.
        for worker in worker_ips:
            # The GuestAttributes will only have one entry if we're
            # targeting a single worker.
            index = 0 if single_pod_worker else worker
            instance_name = tpu_ssh_utils.GetFromGuestAttributes(
                guest_attributes_response.guestAttributes, index, 'hostname')
            if instance_name is None:
                log.status.Print('Failed to connect to TPU.')
                log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
                raise tpu_exceptions.IapTunnelingUnavailable()
            instance_names[worker] = instance_name

    ssh_threads = []
    # One slot per worker; threads record failures by index.
    exit_statuses = [None] * len(worker_ips)
    for worker, ips in worker_ips.items():
        options = None
        if not args.plain:
            options = ssh_helper.GetConfig(
                tpu_ssh_utils.GetInstanceID(node.id, worker,
                                            host_key_suffixes),
                args.strict_host_key_checking, None)

        iap_tunnel_args = None
        if (args.IsKnownAndSpecified('tunnel_through_iap') and
                args.tunnel_through_iap):
            # Retrieve the instance name from the GuestAttributes.
            instance_name = instance_names[worker]
            iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
                args, self.ReleaseTrack(), project, args.zone, instance_name)

        remote.host = ips.ip_address
        cmd = ssh.SCPCommand(
            srcs,
            dst,
            identity_file=identity_file,
            options=options,
            recursive=args.recurse,
            compress=args.compress,
            extra_flags=extra_flags,
            iap_tunnel_args=iap_tunnel_args)

        if args.dry_run:
            log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
            continue

        if len(worker_ips) > 1:
            # Run the command on multiple workers concurrently.
            ssh_threads.append(
                threading.Thread(
                    target=tpu_ssh_utils.AttemptRunWithRetries,
                    args=('SCP', worker, exit_statuses, cmd, ssh_helper.env,
                          None, True, SCPRunCmd)))
            ssh_threads[-1].start()
        else:
            # Run on a single worker.
            tpu_ssh_utils.AttemptRunWithRetries('SCP', worker, exit_statuses,
                                                cmd, ssh_helper.env, None,
                                                False, SCPRunCmd)

    if len(worker_ips) > 1:
        # Wait for all the threads to complete.
        for i in range(len(ssh_threads)):
            ssh_threads[i].join()

        # Exit with a nonzero status, if any.
        # This ensures that if any command failed on a worker, we don't end
        # up returning 0 for a value.
        for status in exit_statuses:
            if status:
                sys.exit(status)
def Run(self, args):
  """See ssh_utils.BaseSSHCommand.Run.

  Establishes an SSH session to an instance's serial console via the
  serial-port gateway, after ensuring SSH keys exist, the gateway's host
  key is trusted, and the user's key is registered with the project.

  Args:
    args: argparse.Namespace, the parsed command-line arguments.

  Raises:
    ssh_utils.ArgumentError: If args.user_host is not of the form
      [USER@]INSTANCE.
  """
  holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  cua_holder = base_classes.ComputeUserAccountsApiHolder(
      self.ReleaseTrack())
  client = holder.client

  ssh_helper = ssh_utils.BaseSSHHelper()
  ssh_helper.Run(args)
  # Generate the key pair up front (prompting for overwrite/passphrase as
  # configured) so the identity file is guaranteed to exist below.
  ssh_helper.keys.EnsureKeysExist(args.force_key_file_overwrite,
                                  allow_passphrase=True)

  remote = ssh.Remote.FromArg(args.user_host)
  if not remote:
    raise ssh_utils.ArgumentError(
        'Expected argument of the form [USER@]INSTANCE. Received [{0}].'
        .format(args.user_host))
  if not remote.user:
    remote.user = ssh.GetDefaultSshUsername()

  # Known-hosts alias for the gateway endpoint, e.g. "[gateway]:9600".
  hostname = '[{0}]:{1}'.format(args.serial_port_gateway, CONNECTION_PORT)

  # Update google_compute_known_hosts file with published host key.
  # Only attempted for the default gateway; a custom gateway is assumed to
  # be handled by the caller's own known-hosts management.
  if args.serial_port_gateway == SERIAL_PORT_GATEWAY:
    http_client = http.Http()
    # httplib2-style call: returns a (response, content) pair, which is why
    # the status is read via [0]['status'] and compared as a string.
    http_response = http_client.request(HOST_KEY_URL)
    known_hosts = ssh.KnownHosts.FromDefaultFile()
    if http_response[0]['status'] == '200':
      host_key = http_response[1].strip()
      # overwrite=True keeps the entry current if the published key rotates.
      known_hosts.Add(hostname, host_key, overwrite=True)
      known_hosts.Write()
    elif known_hosts.ContainsAlias(hostname):
      # Download failed but we already trust some key for this gateway:
      # proceed on the cached entry and tell the user why.
      log.warn(
          'Unable to download and update Host Key for [{0}] from [{1}]. '
          'Attempting to connect using existing Host Key in [{2}]. If '
          'the connection fails, please try again to update the Host '
          'Key.'.format(SERIAL_PORT_GATEWAY, HOST_KEY_URL,
                        known_hosts.file_path))
    else:
      # No download and no cached entry: fall back to the baked-in key.
      # NOTE(review): the message below spells "connetion" — typo in a
      # user-facing string; left untouched here as fixing it changes
      # runtime output.
      known_hosts.Add(hostname, DEFAULT_HOST_KEY)
      known_hosts.Write()
      log.warn(
          'Unable to download Host Key for [{0}] from [{1}]. To ensure '
          'the security of the SSH connetion, gcloud will attempt to '
          'connect using a hard-coded Host Key value. If the connection '
          'fails, please try again. If the problem persists, try '
          'updating gcloud and connecting again.'.format(
              SERIAL_PORT_GATEWAY, HOST_KEY_URL))

  instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [remote.host], compute_scope.ScopeEnum.ZONE, args.zone,
      holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
  instance = ssh_helper.GetInstance(client, instance_ref)
  project = ssh_helper.GetProject(client, instance_ref.project)

  # Determine the serial user, host tuple (remote). The gateway identifies
  # the target console from fields encoded into the SSH username.
  port = 'port={0}'.format(args.port)
  constructed_username_list = [
      instance_ref.project, instance_ref.zone, instance_ref.Name(),
      remote.user, port
  ]
  if args.extra_args:
    # Extra key=value options are appended verbatim for the gateway.
    for k, v in args.extra_args.items():
      constructed_username_list.append('{0}={1}'.format(k, v))
  # NOTE(review): '******' looks like a redacted separator constant — confirm
  # the real field separator expected by the serial-port gateway protocol.
  serial_user = '******'.join(constructed_username_list)
  serial_remote = ssh.Remote(args.serial_port_gateway, user=serial_user)

  identity_file = ssh_helper.keys.key_file
  # Strict host key checking is forced on: we just ensured a trusted entry
  # for `hostname` above. The alias option is removed because the gateway,
  # not the instance, is the SSH endpoint here.
  options = ssh_helper.GetConfig(hostname, strict_host_key_checking='yes')
  del options['HostKeyAlias']
  cmd = ssh.SSHCommand(serial_remote, identity_file=identity_file,
                       port=CONNECTION_PORT, options=options)
  if args.dry_run:
    # Print the fully built command line instead of executing it.
    log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
    return
  ssh_helper.EnsureSSHKeyExists(client, cua_holder.client, remote.user,
                                instance, project)

  # Don't wait for the instance to become SSHable. We are not connecting to
  # the instance itself through SSH, so the instance doesn't need to have
  # fully booted to connect to the serial port. Also, ignore exit code 255,
  # since the normal way to terminate the serial port connection is ~. and
  # that causes ssh to exit with 255.
  try:
    return_code = cmd.Run(ssh_helper.env, force_connect=True)
  except ssh.CommandError:
    return_code = 255
  if return_code:
    # Propagate the ssh client's exit status to the gcloud process.
    sys.exit(return_code)
def RunScp(self,
           compute_holder,
           args,
           port=None,
           recursive=False,
           compress=False,
           extra_flags=None,
           release_track=None,
           ip_type=ip.IpTypeEnum.EXTERNAL):
  """SCP files between local and remote GCE instance.

  Run this method from subclasses' Run methods.

  Args:
    compute_holder: The ComputeApiHolder.
    args: argparse.Namespace, the args the command was invoked with.
    port: str or None, Port number to use for SSH connection.
    recursive: bool, Whether to use recursive copying using -R flag.
    compress: bool, Whether to use compression.
    extra_flags: [str] or None, extra flags to add to command invocation.
    release_track: obj, The current release track.
    ip_type: IpTypeEnum, Specify using internal ip or external ip address.

  Raises:
    ssh_utils.NetworkError: Network issue which likely is due to failure
      of SSH key propagation.
    ssh.CommandError: The SSH command exited with SSH exit code, which
      usually implies that a connection problem occurred.
  """
  if release_track is None:
    release_track = base.ReleaseTrack.GA
  super(BaseScpHelper, self).Run(args)

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

  # Make sure we have a unique remote: exactly one endpoint of the copy
  # may be a remote instance.
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

  remote = dst.remote or srcs[0].remote
  if not dst.remote:
    # Make sure all remotes point to the same ref, so that the host
    # rewrite further down applies to every source uniformly.
    for src in srcs:
      src.remote = remote

  instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [remote.host], compute_scope.ScopeEnum.ZONE, args.zone,
      compute_holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(
          compute_holder.client))[0]
  instance = self.GetInstance(compute_holder.client, instance_ref)
  project = self.GetProject(compute_holder.client, instance_ref.project)

  if not remote.user:
    remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

  if args.plain:
    # --plain skips gcloud's key management entirely.
    use_oslogin = False
  else:
    public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
    # OS Login may override the requested username; the returned user is
    # what the SCP command must actually authenticate as.
    remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
        instance, project, remote.user, public_key, release_track)

  identity_file = None
  options = None
  if not args.plain:
    identity_file = self.keys.key_file
    options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                             args.strict_host_key_checking)

  iap_tunnel_args = iap_tunnel.SshTunnelArgs.FromArgs(
      args, release_track, instance_ref,
      ssh_utils.GetInternalInterface(instance),
      ssh_utils.GetExternalInterface(instance, no_raise=True))

  if iap_tunnel_args:
    # Tunneling through IAP: the host field is only a known-hosts alias,
    # since the tunnel itself addresses the instance.
    remote.host = ssh_utils.HostKeyAlias(instance)
  elif ip_type is ip.IpTypeEnum.INTERNAL:
    remote.host = ssh_utils.GetInternalIPAddress(instance)
  else:
    remote.host = ssh_utils.GetExternalIPAddress(instance)

  cmd = ssh.SCPCommand(srcs, dst, identity_file=identity_file,
                       options=options, recursive=recursive,
                       compress=compress, port=port,
                       extra_flags=extra_flags,
                       iap_tunnel_args=iap_tunnel_args)

  if args.dry_run:
    # Print the fully built command line instead of executing it.
    log.out.Print(' '.join(cmd.Build(self.env)))
    return

  if args.plain or use_oslogin:
    # OS Login (or --plain) manages keys itself; nothing to propagate.
    keys_newly_added = False
  else:
    keys_newly_added = self.EnsureSSHKeyExists(compute_holder.client,
                                               remote.user, instance,
                                               project)

  if keys_newly_added:
    # A freshly added key takes time to reach the VM; poll until an SSH
    # connection succeeds before attempting the real copy.
    poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                       iap_tunnel_args, port=port)
    log.status.Print('Waiting for SSH key to propagate.')
    # TODO(b/35355795): Don't force_connect
    try:
      poller.Poll(self.env, force_connect=True)
    except retry.WaitException:
      raise ssh_utils.NetworkError()

  if ip_type is ip.IpTypeEnum.INTERNAL:
    # This will never happen when IAP Tunnel is enabled, because ip_type is
    # always EXTERNAL when IAP Tunnel is enabled, even if the instance has
    # no external IP. IAP Tunnel doesn't need verification because it uses
    # unambiguous identifiers for the instance.
    self.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                     options)

  # Errors from the SCP command result in an ssh.CommandError being raised
  cmd.Run(self.env, force_connect=True)