def _FetchInstance(self, args):
  """Resolves args.instance_name to a zonal reference and fetches it.

  Args:
    args: The parsed command-line arguments; `instance_name` and `zone` are
      read.

  Returns:
    A (instance_ref, instance) tuple: the resolved resource reference and the
    instance object returned by BaseSSHCLIHelper.GetInstance.
  """
  api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  compute_client = api_holder.client
  helper = ssh_utils.BaseSSHCLIHelper()
  resolved_refs = flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [args.instance_name],
      scope.ScopeEnum.ZONE,
      args.zone,
      api_holder.resources,
      scope_lister=flags.GetInstanceZoneScopeLister(compute_client))
  ref = resolved_refs[0]
  fetched = helper.GetInstance(compute_client, ref)
  return ref, fetched
def _GetTargetArgs(self, args):
  """Builds the tunnel target description for the VM named in args.

  Args:
    args: The parsed command-line arguments; `instance_name`, `zone` and
      `instance_port` are read.

  Returns:
    A (project, zone, instance_name, interface_name, port) tuple, where
    interface_name is the name of the instance's internal network interface
    as reported by ssh_utils.GetInternalInterface.
  """
  api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  compute_client = api_holder.client
  helper = ssh_utils.BaseSSHCLIHelper()

  ref = flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [args.instance_name], scope.ScopeEnum.ZONE, args.zone,
      api_holder.resources,
      scope_lister=flags.GetInstanceZoneScopeLister(compute_client))[0]
  vm = helper.GetInstance(compute_client, ref)

  nic_name = ssh_utils.GetInternalInterface(vm).name
  return ref.project, ref.zone, vm.name, nic_name, args.instance_port
def Run(self, args):
  """Runs sosreport on the target VM and reports where the output landed.

  The steps are: make sure sosreport is available on the VM (soshelper.
  ObtainSosreport), create the reports directory if needed, run the report,
  print the resulting file name, and, when args.download_dir is set, copy the
  report to the local machine.
  """
  super(SosReport, self).Run(args)
  self._use_accounts_service = False

  api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  target_vm = self.GetInstance(api_holder, args)
  ssh_user = args.user or ssh.GetDefaultSshUsername()
  helper = ssh_utils.BaseSSHCLIHelper()
  helper.Run(args)

  # Shared state handed to every soshelper call below.
  context = {
      "args": args,
      "instance": target_vm,
      "ssh_helper": helper,
      "user": ssh_user,
      "python_path": args.python_path,
  }
  install_dir = args.sosreport_install_path
  output_dir = args.reports_path

  # Download sosreport onto the VM if needed (normally only the first time).
  soshelper.ObtainSosreport(context, install_dir)

  log.out.Print(
      "Creating the path where reports will be written if needed.")
  soshelper.CreatePath(context, output_dir)

  soshelper.RunSosreport(context, install_dir, output_dir)

  generated = soshelper.ObtainReportFilename(context, output_dir)
  log.status.Print(
      'Report generated into "{report_path}".'.format(report_path=generated))

  if not args.download_dir:
    return
  downloaded = soshelper.CopyReportFile(context, args.download_dir, generated)
  log.status.Print('Successfully downloaded report to "{report_path}"'.format(
      report_path=downloaded))
def SSHToInstance(self, args, instance):
  """Helper to manage authentication followed by SSH to the instance.

  Args:
    args: The parsed command arguments; normalized via _DefaultArgsForSSH.
      `name` is parsed with ssh_utils.GetUserAndInstance, so it may carry an
      explicit `user@` prefix.
    instance: The instance to connect to; its external NAT IP is the target.

  Raises:
    ssh.CommandError: If ssh itself fails.
    SystemExit: With the remote command's non-zero return code.
  """
  args = self._DefaultArgsForSSH(args)

  external_nat = ssh_utils.GetExternalIPAddress(instance)
  log.status.Print(
      'Trying to SSH to VM with NAT IP:{}'.format(external_nat))
  args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE

  ssh_helper = ssh_utils.BaseSSHCLIHelper()
  ssh_helper.Run(args)
  identity_file = ssh_helper.keys.key_file

  user, _ = ssh_utils.GetUserAndInstance(args.name)
  # Connect as the same `user` that is handed to _WaitForSSHKeysToPropagate
  # below. Using the default SSH username here would diverge from `user`
  # whenever args.name has the form `user@name`.
  remote = ssh.Remote(external_nat, user)

  host_keys = self._GetHostKeyFromInstance(args.zone, ssh_helper, instance)
  options = self._GetSSHOptions(args.name, ssh_helper,
                                instance, host_keys)
  self._WaitForSSHKeysToPropagate(ssh_helper, remote, identity_file, user,
                                  instance, options)

  extra_flags = []
  # Ctpu seems to be forwarding some other ports on what
  # seems like the TPU node. Need to understand better before enabling.
  if args.forward_ports:
    extra_flags.extend([
        '-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888'
    ])

  ssh_cmd_args = {
      'remote': remote,
      'identity_file': identity_file,
      'options': options,
      'extra_flags': extra_flags
  }

  cmd = ssh.SSHCommand(**ssh_cmd_args)
  # Errors from SSH itself result in an ssh.CommandError being raised
  return_code = cmd.Run(
      ssh_helper.env,
      force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
  if return_code:
    # This is the return code of the remote command. Problems with SSH itself
    # will result in ssh.CommandError being raised above.
    sys.exit(return_code)
def Run(self, args):
  """Connects to a Compute Engine instance over SSH.

  See ssh_utils.BaseSSHCLICommand.Run. The connection uses the instance's
  internal IP when self._use_internal_ip is set, otherwise its external IP.
  Unless --plain is given, OS Login may replace the login user, and a key is
  added to the instance when OS Login is not in use.

  Raises:
    ssh_utils.NetworkError: If a newly added key does not become usable
      before the poller gives up.
    SystemExit: With the remote command's non-zero return code.
  """
  api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  compute_client = api_holder.client

  helper = ssh_utils.BaseSSHCLIHelper()
  helper.Run(args)

  user, vm_name = ssh_utils.GetUserAndInstance(args.user_host)
  vm_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [vm_name], compute_scope.ScopeEnum.ZONE, args.zone,
      api_holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(compute_client))[0]
  vm = helper.GetInstance(compute_client, vm_ref)
  project = helper.GetProject(compute_client, vm_ref.project)

  use_oslogin = False
  if not args.plain:
    user, use_oslogin = helper.CheckForOsloginAndGetUser(
        vm, project, user, self.ReleaseTrack())

  address_getter = (ssh_utils.GetInternalIPAddress if self._use_internal_ip
                    else ssh_utils.GetExternalIPAddress)
  ip_address = address_getter(vm)
  remote = ssh.Remote(ip_address, user)

  identity_file, options = None, None
  if not args.plain:
    identity_file = helper.keys.key_file
    options = helper.GetConfig(ssh_utils.HostKeyAlias(vm),
                               args.strict_host_key_checking)

  extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, ip_address)
  remainder = list(args.ssh_args) if args.ssh_args else []

  # None when no --command was given; otherwise the command split on spaces.
  command_list = args.command.split(' ') if args.command else None
  tty = containers.GetTty(args.container, command_list)
  remote_command = containers.GetRemoteCommand(args.container, command_list)

  cmd = ssh.SSHCommand(
      remote,
      identity_file=identity_file,
      options=options,
      extra_flags=extra_flags,
      remote_command=remote_command,
      tty=tty,
      remainder=remainder)

  if args.dry_run:
    log.out.Print(' '.join(cmd.Build(helper.env)))
    return

  keys_newly_added = (
      False if args.plain or use_oslogin else
      helper.EnsureSSHKeyExists(compute_client, remote.user, vm, project))

  if keys_newly_added:
    poller = ssh.SSHPoller(
        remote, identity_file=identity_file, options=options,
        extra_flags=extra_flags,
        # NOTE(review): a *_SEC constant is passed as max_wait_ms; confirm
        # which unit SSHPoller expects here.
        max_wait_ms=ssh_utils.SSH_KEY_PROPAGATION_TIMEOUT_SEC)
    log.status.Print('Waiting for SSH key to propagate.')
    # TODO(b/35355795): Don't force_connect
    try:
      poller.Poll(helper.env, force_connect=True)
    except retry.WaitException:
      raise ssh_utils.NetworkError()

  if self._use_internal_ip:
    helper.PreliminarilyVerifyInstance(vm.id, remote, identity_file, options)

  return_code = cmd.Run(helper.env, force_connect=True)
  if return_code:
    # Exit with the remote status rather than raising, so no "ERROR" line is
    # printed; ssh's own output is enough.
    sys.exit(return_code)
def Run(self, args):
  """Traceroutes to each selected VM and, optionally, back from it.

  After printing the selected VMs (and prompting unless --dry-run), each VM
  gets a forward traceroute. With --reverse-traceroute, a traceroute is also
  run from the VM back to this machine's IP. That IP is taken from
  --external-route-ip or obtained once via ObtainSelfIp and reused for later
  VMs. Per-VM errors are logged and do not stop the loop.
  """
  super(Routes, self).Run(args)
  self._use_accounts_service = False

  api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  registry = api_holder.resources
  helper = ssh_utils.BaseSSHCLIHelper()
  helper.Run(args)

  # Fields that the per-instance helper methods rely on.
  self._args = args
  self._ssh_helper = helper

  project = properties.VALUES.core.project.GetOrFail()
  filters = _RoutesArgs.GetFilters(args)
  vms = _RoutesQueries.ObtainInstances(
      args.names,
      service=self.compute.instances,
      project=project,
      zones=args.zones,
      filters=filters,
      http=self.http,
      batch_url=self.batch_url)

  user = args.user or ssh.GetDefaultSshUsername()

  dry_run = args.dry_run
  do_reverse = args.reverse_traceroute
  traceroute_args = args.traceroute_args
  route_ip = args.external_route_ip

  internal_helpers.PrintHeader(vms)
  if (vms and not dry_run and
      not console_io.PromptContinue('The following VMs will be tracerouted.')):
    return
  # Flush so the prompt does not end up printed after the instance data.
  log.out.flush()

  for vm in vms:
    title = 'Checking instance %s' % vm.name
    log.out.Print(title)
    log.out.Print('-' * len(title))

    try:
      self.TracerouteInstance(vm, traceroute_args, dry_run, registry)
    except exceptions.ToolException as err:
      log.error('Error routing to instance')
      log.error(str(err))
      continue

    if do_reverse:
      try:
        if not self.CheckTraceroute(vm, user, dry_run, registry):
          log.out.Print(
              'Please make sure traceroute is installed in PATH to move on.')
        else:
          # Look the self IP up only once; it is reused for later VMs.
          if not route_ip:
            route_ip = self.ObtainSelfIp(vm, user, dry_run, registry)
          if not route_ip:
            log.out.Print('Unable to obtain self ip. Aborting.')
          else:
            self.ReverseTracerouteInstance(vm, user, route_ip,
                                           traceroute_args, dry_run,
                                           registry)
      except ssh.CommandError as err:
        log.error(str(err))

    log.out.Print('')  # Separator
def Run(self, args):
  """See ssh_utils.BaseSSHCLICommand.Run.

  Connects over SSH either to a Compute Engine instance or, when both
  --network and --region are specified, to an on-prem host through an IAP
  tunnel. The on-prem path forces --plain. As a result, the instance,
  project, host-key and expiration variables, which are only assigned on the
  Compute Engine path, are never read for it.

  Raises:
    ssh_utils.NetworkError: If a newly added SSH key is not usable before the
      poller gives up.
    ssh.CommandError: If ssh itself fails. For non-on-prem targets, a
      recommendation message is printed before re-raising.
    SystemExit: With the remote command's non-zero return code.
  """
  # Specifying both --network and --region selects an on-prem target.
  on_prem = (
      args.IsKnownAndSpecified('network') and
      args.IsKnownAndSpecified('region'))
  if on_prem:
    # With --plain, identity-file/config generation and EnsureSSHKeyExists
    # are skipped below.
    args.plain = True

  # These two lines are needed to ensure reauth is performed as needed, even
  # for on-prem, which doesn't use the resulting variables.
  holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  client = holder.client

  ssh_helper = ssh_utils.BaseSSHCLIHelper()
  ssh_helper.Run(args)
  # Default state; only replaced below on the non-plain Compute Engine path.
  oslogin_state = ssh.OsloginState()

  if on_prem:
    # For on-prem targets, the host part of user_host is used directly as
    # the IP address.
    user, ip = ssh_utils.GetUserAndInstance(args.user_host)
    remote = ssh.Remote(ip, user)

    iap_tunnel_args = iap_tunnel.CreateOnPremSshTunnelArgs(
        args, self.ReleaseTrack(), ip)
    instance_address = ip
    internal_address = ip
  else:
    user, instance_name = ssh_utils.GetUserAndInstance(args.user_host)
    instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [instance_name], compute_scope.ScopeEnum.ZONE, args.zone,
        holder.resources,
        scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
    instance = ssh_helper.GetInstance(client, instance_ref)
    project = ssh_helper.GetProject(client, instance_ref.project)
    host_keys = ssh_helper.GetHostKeysFromGuestAttributes(
        client, instance_ref, instance, project)
    iap_tunnel_args = iap_tunnel.CreateSshTunnelArgs(
        args, self.ReleaseTrack(), instance_ref,
        ssh_utils.GetExternalInterface(instance, no_raise=True))

    internal_address = ssh_utils.GetInternalIPAddress(instance)

    if args.troubleshoot:
      # --troubleshoot runs diagnostics and returns without connecting.
      log.status.Print(TROUBLESHOOT_HEADER.format(
          instance_ref, args.zone or instance_ref.zone,
          datetime.datetime.now()
      ))
      RunTroubleshooting(project, args.zone or instance_ref.zone,
                         instance, iap_tunnel_args)
      return

    # Only an empty (non-None) result is logged; None is not reported.
    if not host_keys and host_keys is not None:
      log.debug('Unable to retrieve host keys from instance metadata. '
                'Continuing.')
    expiration, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(
        args)
    if args.plain:
      oslogin_state.oslogin_enabled = False
    else:
      public_key = ssh_helper.keys.GetPublicKey().ToEntry(
          include_comment=True)
      # If there is an '@' symbol in the user_host arg, the user is requesting
      # to connect as a specific user. This may get overridden by OS Login.
      username_requested = '@' in args.user_host
      oslogin_state = ssh.GetOsloginState(
          instance, project, user, public_key, expiration_micros,
          self.ReleaseTrack(), username_requested=username_requested)
      user = oslogin_state.user

    log.debug(oslogin_state)

    if iap_tunnel_args:
      # IAP Tunnel only uses instance_address for the purpose of --ssh-flag
      # substitution. In this case, dest_addr doesn't do much, it just matches
      # against entries in the user's ssh_config file. It's best to use
      # something unique to avoid false positive matches, thus we use
      # HostKeyAlias.
      instance_address = internal_address
      dest_addr = ssh_utils.HostKeyAlias(instance)
    elif args.internal_ip:
      instance_address = internal_address
      dest_addr = instance_address
    else:
      instance_address = ssh_utils.GetExternalIPAddress(instance)
      dest_addr = instance_address
    remote = ssh.Remote(dest_addr, user)

  # identity_file_list will be None if security keys are not enabled.
  identity_file_list = ssh.WriteSecurityKeys(oslogin_state)
  identity_file = None
  options = None
  if not args.plain:
    # Use the regular key file only when no security-key identity files
    # were written.
    if not identity_file_list:
      identity_file = ssh_helper.keys.key_file
    options = ssh_helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                   args.strict_host_key_checking,
                                   host_keys_to_add=host_keys)

  extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, instance_address,
                                               internal_address)
  remainder = []

  if args.ssh_args:
    remainder.extend(args.ssh_args)

  # Transform args.command into arg list or None if no command
  command_list = args.command.split(' ') if args.command else None
  tty = containers.GetTty(args.container, command_list)
  remote_command = containers.GetRemoteCommand(args.container, command_list)

  # Do not include default port since that will prevent users from
  # specifying a custom port (b/121998342).
  ssh_cmd_args = {'remote': remote,
                  'identity_file': identity_file,
                  'options': options,
                  'extra_flags': extra_flags,
                  'remote_command': remote_command,
                  'tty': tty,
                  'iap_tunnel_args': iap_tunnel_args,
                  'remainder': remainder,
                  'identity_list': identity_file_list}

  cmd = ssh.SSHCommand(**ssh_cmd_args)

  if args.dry_run:
    # Add quotes around any arguments that contain spaces.
    log.out.Print(' '.join('"{0}"'.format(arg) if ' ' in arg else arg
                           for arg in cmd.Build(ssh_helper.env)))
    return

  # Raise errors if instance requires a security key but the local
  # environment doesn't support them. This is after the 'dry-run' because
  # we want to allow printing the command regardless.
  if self.enable_security_keys:
    ssh_utils.ConfirmSecurityKeyStatus(oslogin_state)

  # EnsureSSHKeyExists is skipped for --plain (which includes on-prem) and
  # when OS Login is enabled.
  if args.plain or oslogin_state.oslogin_enabled:
    keys_newly_added = False
  else:
    keys_newly_added = ssh_helper.EnsureSSHKeyExists(
        client, remote.user, instance, project, expiration=expiration)

  if keys_newly_added:
    poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                       iap_tunnel_args,
                                       extra_flags=extra_flags)
    log.status.Print('Waiting for SSH key to propagate.')
    # TODO(b/35355795): Don't force_connect
    try:
      poller.Poll(
          ssh_helper.env,
          force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
    except retry.WaitException:
      raise ssh_utils.NetworkError()

  # `instance` only exists on the Compute Engine path, hence the on_prem guard.
  if args.internal_ip and not on_prem:
    ssh_helper.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                           options)

  # Errors from SSH itself result in an ssh.CommandError being raised
  try:
    return_code = cmd.Run(
        ssh_helper.env,
        force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
  except ssh.CommandError as e:
    # instance_name/instance_ref/project are only defined for instance
    # targets, so the recommendation is printed only when not on-prem.
    if not on_prem:
      log.status.Print(self.createRecommendMessage(args, instance_name,
                                                   instance_ref, project))
    raise e

  if return_code:
    # This is the return code of the remote command. Problems with SSH itself
    # will result in ssh.CommandError being raised above.
    sys.exit(return_code)
def Run(self, args):
  """Opens an SSH session to a Compute Engine instance.

  See ssh_utils.BaseSSHCLICommand.Run. The steps are:
    * resolve the instance;
    * optionally collect its host keys from guest attributes;
    * determine the login user (OS Login is consulted unless --plain);
    * pick the destination (IAP tunnel, internal IP or external IP);
    * make sure a key is present;
    * run ssh.

  Raises:
    ssh_utils.NetworkError: If a newly added key does not become usable in
      time.
    SystemExit: With the remote command's non-zero return code.
  """
  api_holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  compute_client = api_holder.client

  helper = ssh_utils.BaseSSHCLIHelper()
  helper.Run(args)

  user, vm_name = ssh_utils.GetUserAndInstance(args.user_host)
  vm_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [vm_name], compute_scope.ScopeEnum.ZONE, args.zone,
      api_holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(compute_client))[0]
  vm = helper.GetInstance(compute_client, vm_ref)
  project = helper.GetProject(compute_client, vm_ref.project)

  host_keys = {}
  if self.get_host_keys:
    host_keys = helper.GetHostKeysFromGuestAttributes(compute_client, vm_ref)
    if not host_keys:
      log.warning('Unable to retrieve host keys from instance metadata. '
                  'Continuing.')

  expiration, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)

  use_oslogin = False
  if not args.plain:
    public_key = helper.keys.GetPublicKey().ToEntry(include_comment=True)
    user, use_oslogin = ssh.CheckForOsloginAndGetUser(
        vm, project, user, public_key, expiration_micros,
        self.ReleaseTrack())

  iap_tunnel_args = iap_tunnel.SshTunnelArgs.FromArgs(
      args, self.ReleaseTrack(), vm_ref,
      ssh_utils.GetExternalInterface(vm, no_raise=True))
  internal_address = ssh_utils.GetInternalIPAddress(vm)

  if iap_tunnel_args:
    # Through IAP, the instance address only feeds --ssh-flag substitution,
    # and the destination merely matches entries in the user's ssh_config.
    # The unique HostKeyAlias avoids false-positive matches there.
    instance_address = internal_address
    dest_addr = ssh_utils.HostKeyAlias(vm)
  elif args.internal_ip:
    instance_address = dest_addr = internal_address
  else:
    instance_address = dest_addr = ssh_utils.GetExternalIPAddress(vm)
  remote = ssh.Remote(dest_addr, user)

  identity_file, options = None, None
  if not args.plain:
    identity_file = helper.keys.key_file
    options = helper.GetConfig(ssh_utils.HostKeyAlias(vm),
                               args.strict_host_key_checking,
                               host_keys_to_add=host_keys)

  extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, instance_address,
                                               internal_address)
  remainder = list(args.ssh_args) if args.ssh_args else []

  # None when no --command was given; otherwise the command split on spaces.
  command_list = args.command.split(' ') if args.command else None
  tty = containers.GetTty(args.container, command_list)
  remote_command = containers.GetRemoteCommand(args.container, command_list)

  # No port is passed on purpose, so users can still specify a custom one
  # (b/121998342).
  cmd = ssh.SSHCommand(
      remote=remote,
      identity_file=identity_file,
      options=options,
      extra_flags=extra_flags,
      remote_command=remote_command,
      tty=tty,
      iap_tunnel_args=iap_tunnel_args,
      remainder=remainder)

  if args.dry_run:
    log.out.Print(' '.join(cmd.Build(helper.env)))
    return

  keys_newly_added = (
      False if args.plain or use_oslogin else
      helper.EnsureSSHKeyExists(compute_client, remote.user, vm, project,
                                expiration=expiration))

  if keys_newly_added:
    poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                       iap_tunnel_args,
                                       extra_flags=extra_flags)
    log.status.Print('Waiting for SSH key to propagate.')
    # TODO(b/35355795): Don't force_connect
    try:
      poller.Poll(helper.env, force_connect=True)
    except retry.WaitException:
      raise ssh_utils.NetworkError()

  if args.internal_ip:
    helper.PreliminarilyVerifyInstance(vm.id, remote, identity_file, options)

  # Failures of ssh itself raise ssh.CommandError; a non-zero value here is
  # the remote command's own exit status.
  return_code = cmd.Run(helper.env, force_connect=True)
  if return_code:
    sys.exit(return_code)
def Run(self, args):
  """See ssh_utils.BaseSSHCLICommand.Run.

  The steps are:
    * resolve the target instance;
    * optionally start an IAP tunnel listener (--tunnel-through-iap);
    * make sure the user's key is present, waiting for it to propagate if it
      was just added;
    * run ssh.

  Every tunnel listener started here, including the extra one used by the
  key-propagation poller, is stopped on every exit path. That covers dry
  runs and errors raised while adding keys, polling or verifying the
  instance.

  Raises:
    ssh_utils.NetworkError: If a newly added key does not become usable
      before the poller gives up.
    SystemExit: With the remote command's non-zero return code.
  """
  holder = base_classes.ComputeApiHolder(self.ReleaseTrack())
  client = holder.client

  ssh_helper = ssh_utils.BaseSSHCLIHelper()
  ssh_helper.Run(args)
  user, instance_name = ssh_utils.GetUserAndInstance(args.user_host)
  instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [instance_name], compute_scope.ScopeEnum.ZONE, args.zone,
      holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(client))[0]
  instance = ssh_helper.GetInstance(client, instance_ref)
  project = ssh_helper.GetProject(client, instance_ref.project)
  if args.plain:
    use_oslogin = False
  else:
    public_key = ssh_helper.keys.GetPublicKey().ToEntry(
        include_comment=True)
    user, use_oslogin = ssh.CheckForOsloginAndGetUser(
        instance, project, user, public_key, self.ReleaseTrack())
  if args.internal_ip:
    ip_address = ssh_utils.GetInternalIPAddress(instance)
  else:
    ip_address = ssh_utils.GetExternalIPAddress(instance)

  remote = ssh.Remote(ip_address, user)

  identity_file = None
  options = None
  if not args.plain:
    identity_file = ssh_helper.keys.key_file
    options = ssh_helper.GetConfig(ssh_utils.HostKeyAlias(instance),
                                   args.strict_host_key_checking)

  extra_flags = ssh.ParseAndSubstituteSSHFlags(args, remote, ip_address)
  remainder = []

  if args.ssh_args:
    remainder.extend(args.ssh_args)

  # Transform args.command into arg list or None if no command
  command_list = args.command.split(' ') if args.command else None
  tty = containers.GetTty(args.container, command_list)
  remote_command = containers.GetRemoteCommand(args.container, command_list)

  target_remote = remote
  port = ssh_utils.DEFAULT_SSH_PORT
  ip_type = (ip.IpTypeEnum.INTERNAL if args.internal_ip else
             ip.IpTypeEnum.EXTERNAL)
  tunnel_helper = None
  interface = None
  if hasattr(args, 'tunnel_through_iap') and args.tunnel_through_iap:
    tunnel_helper, interface = ssh_utils.CreateIapTunnelHelper(
        args, instance_ref, instance, ip_type)
    tunnel_helper.StartListener()

  # Once the listener is running, the finally clause below stops it however
  # we leave. That includes dry run, errors from key handling, polling or
  # verification, and the ssh run itself.
  try:
    if tunnel_helper:
      # ssh connects to the local end of the IAP tunnel.
      target_remote = ssh.Remote('localhost', user)
      port = tunnel_helper.GetLocalPort()

    cmd = ssh.SSHCommand(target_remote, port=str(port),
                         identity_file=identity_file, options=options,
                         extra_flags=extra_flags,
                         remote_command=remote_command, tty=tty,
                         remainder=remainder)
    if args.dry_run:
      log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
      return

    if args.plain or use_oslogin:
      keys_newly_added = False
    else:
      keys_newly_added = ssh_helper.EnsureSSHKeyExists(
          client, remote.user, instance, project)

    if keys_newly_added:
      # The poller gets its own tunnel listener, which accepts multiple
      # connections. It is recorded for cleanup only after it has actually
      # started.
      poller_tunnel_helper = None
      try:
        if tunnel_helper:
          new_tunnel_helper, _ = ssh_utils.CreateIapTunnelHelper(
              args, instance_ref, instance, ip_type, interface=interface)
          new_tunnel_helper.StartListener(accept_multiple_connections=True)
          poller_tunnel_helper = new_tunnel_helper
        poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                           poller_tunnel_helper,
                                           extra_flags=extra_flags)

        log.status.Print('Waiting for SSH key to propagate.')
        # TODO(b/35355795): Don't force_connect
        try:
          poller.Poll(ssh_helper.env, force_connect=True)
        except retry.WaitException:
          raise ssh_utils.NetworkError()
      finally:
        if poller_tunnel_helper:
          poller_tunnel_helper.StopListener()

    if args.internal_ip and not tunnel_helper:
      # The IAP Tunnel connection uses instance name and network interface
      # name, so do not need to additionally verify the instance. Also, the
      # SSHCommand used within the function does not support IAP Tunnels.
      ssh_helper.PreliminarilyVerifyInstance(instance.id, remote,
                                             identity_file, options)

    # Errors from SSH itself result in an ssh.CommandError being raised
    return_code = cmd.Run(ssh_helper.env, force_connect=True)
  finally:
    if tunnel_helper:
      tunnel_helper.StopListener()

  if return_code:
    # This is the return code of the remote command. Problems with SSH itself
    # will result in ssh.CommandError being raised above.
    sys.exit(return_code)
def Run(self, args):
  """Copies files to or from one or more workers of a Cloud TPU VM node.

  Either the destination or the sources reference the TPU; this is checked
  by SCPCommand.Verify with single_remote=True. Copies to several workers run
  concurrently on threads. A copy from the TPU back to the client may target
  only a single worker.

  Raises:
    exceptions.BadArgumentException: If the node is not a TPU VM node.
    exceptions.InvalidArgumentException: If several workers are targeted
      while copying files to the client.
    tpu_exceptions.IapTunnelingUnavailable: If IAP tunneling is requested but
      a worker's hostname is missing from the guest attributes.
    SystemExit: With a non-zero status recorded for any worker.
  """
  import copy  # pylint: disable=g-import-not-at-top

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(src) for src in args.sources]
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)
  if dst.remote:
    tpu_name = dst.remote.host
  else:
    tpu_name = srcs[0].remote.host

  # If zone is not set, retrieve the one from the config.
  if args.zone is None:
    args.zone = properties.VALUES.compute.zone.Get(required=True)

  # Retrieve the node.
  tpu = tpu_utils.TPUNode(self.ReleaseTrack())
  node = tpu.Get(tpu_name, args.zone)
  if not tpu_utils.IsTPUVMNode(node):
    raise exceptions.BadArgumentException(
        'TPU',
        'this command is only available for Cloud TPU VM nodes. To access '
        'this node, please see '
        'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

  worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                             node.networkEndpoints,
                                             args.internal_ip)

  if len(worker_ips) > 1 and srcs[0].remote:
    raise exceptions.InvalidArgumentException(
        '--worker', 'cannot target multiple workers while copying files to '
        'client.')

  tpu_ssh_utils.ValidateTPUState(node.state,
                                 tpu.messages.Node.StateValueValuesEnum)

  # Retrieve GuestAttributes.
  single_pod_worker = (
      len(node.networkEndpoints) > 1 and len(worker_ips) == 1)
  if single_pod_worker:
    # Retrieve only that worker's GuestAttributes.
    worker_id = list(worker_ips)[0]
    guest_attributes_response = tpu.GetGuestAttributes(
        tpu_name, args.zone, six.text_type((worker_id)))
    host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
        guest_attributes_response.guestAttributes,
        len(node.networkEndpoints), worker_id)
  else:
    # Retrieve the GuestAttributes for all workers in that TPU.
    guest_attributes_response = tpu.GetGuestAttributes(tpu_name, args.zone)
    host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
        guest_attributes_response.guestAttributes)

  # Generate the public key.
  ssh_helper = ssh_utils.BaseSSHCLIHelper()
  ssh_helper.Run(args)
  public_key = ssh_helper.keys.GetPublicKey().ToEntry()

  remote = dst.remote or srcs[0].remote
  if not dst.remote:
    # Make sure all remotes point to the same ref.
    for src in srcs:
      src.remote = remote

  if remote.user:
    username_requested = True
  else:
    username_requested = False
    remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

  project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

  if not args.plain:
    # A user given explicitly in the remote path counts as "requested", but
    # may still be overridden by OS Login.
    _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
    oslogin_state = ssh.GetOsloginState(
        None, project, remote.user, public_key, expiration_micros,
        self.ReleaseTrack(),
        username_requested=username_requested,
        instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(node))
    remote.user = oslogin_state.user

  # Format the key correctly.
  public_key = '{1}:{0} {1}'.format(public_key, remote.user)
  if not args.plain and not args.dry_run:
    tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name, args.zone,
                                    public_key)

  identity_file = None
  if not args.plain:
    identity_file = ssh_helper.keys.key_file
    # If the user's key is not in the SSH agent, the command will stall. We
    # want to verify it is added before proceeding, and raise an error if it
    # is not.
    if not args.dry_run and len(worker_ips) > 1:
      tpu_ssh_utils.VerifyKeyInAgent(identity_file)

  extra_flags = []
  if args.scp_flag:
    extra_flags.extend(args.scp_flag)

  use_iap = (args.IsKnownAndSpecified('tunnel_through_iap') and
             args.tunnel_through_iap)

  instance_names = {}
  if use_iap:
    # Retrieve the instance names from the GuestAttributes.
    for worker in worker_ips:
      # The GuestAttributes will only have one entry if we're targeting a
      # single worker.
      index = 0 if single_pod_worker else worker
      instance_name = tpu_ssh_utils.GetFromGuestAttributes(
          guest_attributes_response.guestAttributes, index, 'hostname')
      if instance_name is None:
        log.status.Print('Failed to connect to TPU.')
        log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
        raise tpu_exceptions.IapTunnelingUnavailable()
      instance_names[worker] = instance_name

  ssh_threads = []
  # AttemptRunWithRetries receives each worker id together with this list,
  # presumably using the id as the index for that worker's status. Worker ids
  # appear to index node.networkEndpoints, so the list is sized to cover
  # every endpoint. Sizing it by the number of targeted workers alone is not
  # enough for a subset such as workers 1 and 3.
  exit_statuses = [None] * max(len(worker_ips), len(node.networkEndpoints))
  for worker, ips in worker_ips.items():
    options = None
    if not args.plain:
      options = ssh_helper.GetConfig(
          tpu_ssh_utils.GetInstanceID(node.id, worker, host_key_suffixes),
          args.strict_host_key_checking, None)

    iap_tunnel_args = None
    if use_iap:
      # Retrieve the instance name from the GuestAttributes.
      instance_name = instance_names[worker]
      iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
          args, self.ReleaseTrack(), project, args.zone, instance_name)

    # Give this worker's command its own Remote and FileReferences. With
    # several workers the command runs on a background thread, so it must
    # not share a mutable Remote that this loop retargets for the next
    # worker.
    worker_remote = copy.copy(remote)
    worker_remote.host = ips.ip_address
    if dst.remote:
      worker_srcs = srcs
      worker_dst = copy.copy(dst)
      worker_dst.remote = worker_remote
    else:
      worker_srcs = []
      for src in srcs:
        worker_src = copy.copy(src)
        worker_src.remote = worker_remote
        worker_srcs.append(worker_src)
      worker_dst = dst

    cmd = ssh.SCPCommand(
        worker_srcs,
        worker_dst,
        identity_file=identity_file,
        options=options,
        recursive=args.recurse,
        compress=args.compress,
        extra_flags=extra_flags,
        iap_tunnel_args=iap_tunnel_args)

    if args.dry_run:
      log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
      continue

    if len(worker_ips) > 1:
      # Run the command on multiple workers concurrently.
      thread = threading.Thread(
          target=tpu_ssh_utils.AttemptRunWithRetries,
          args=('SCP', worker, exit_statuses, cmd, ssh_helper.env, None,
                True, SCPRunCmd))
      ssh_threads.append(thread)
      thread.start()
    else:
      # Run on a single worker.
      tpu_ssh_utils.AttemptRunWithRetries('SCP', worker, exit_statuses, cmd,
                                          ssh_helper.env, None, False,
                                          SCPRunCmd)

  # Wait for all the threads to complete (no threads for a single worker).
  for thread in ssh_threads:
    thread.join()

  # Exit with a nonzero status, if any. This ensures that if any command
  # failed on a worker, we don't end up returning 0.
  for status in exit_statuses:
    if status:
      sys.exit(status)
def Run(self, args):
  """SSHes into one or more workers of a Cloud TPU VM node.

  Targeting several workers requires --command; the command then runs on
  each worker concurrently on its own thread. With --output-directory, each
  run is handed a FileWriter for `<dir>/<worker>.log`.

  Raises:
    exceptions.InvalidArgumentException: For an invalid --output-directory,
      or when several workers are targeted without --command.
    exceptions.BadArgumentException: If the node is not a TPU VM node.
    tpu_exceptions.IapTunnelingUnavailable: If IAP tunneling is requested but
      a worker's hostname is missing from the guest attributes.
    SystemExit: With a non-zero status recorded for any worker.
  """
  user, tpu_name = ssh_utils.GetUserAndInstance(args.user_tpu)

  # If zone is not set, retrieve the one from the config.
  if args.zone is None:
    args.zone = properties.VALUES.compute.zone.Get(required=True)

  # Validate the output path. The errors name the flag as users type it
  # (--output-directory), matching the hint inside the message.
  if args.output_directory:
    if not args.command:
      raise exceptions.InvalidArgumentException(
          '--output-directory', 'cannot be specified without the `--command` '
          'flag. Please specify the `--command` flag or remove the '
          '--output-directory flag.')
    output_directory_path = os.path.abspath(
        os.path.expandvars(os.path.expanduser(args.output_directory)))
    if not os.path.isdir(output_directory_path):
      raise exceptions.InvalidArgumentException(
          '--output-directory', 'Failed to find directory {}. Please create '
          'it or specify another directory'.format(output_directory_path))

  # Retrieve the node.
  tpu = tpu_utils.TPUNode(self.ReleaseTrack())
  node = tpu.Get(tpu_name, args.zone)
  if not tpu_utils.IsTPUVMNode(node):
    raise exceptions.BadArgumentException(
        'TPU',
        'this command is only available for Cloud TPU VM nodes. To access '
        'this node, please see '
        'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

  tpu_ssh_utils.ValidateTPUState(node.state,
                                 tpu.messages.Node.StateValueValuesEnum)

  worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                             node.networkEndpoints,
                                             args.internal_ip)

  if len(worker_ips) > 1 and not args.command:
    raise exceptions.InvalidArgumentException(
        '--worker', 'cannot target multiple workers without the `--command` '
        'flag.')

  # Retrieve GuestAttributes.
  single_pod_worker = (
      len(node.networkEndpoints) > 1 and len(worker_ips) == 1)
  if single_pod_worker:
    # Retrieve only that worker's GuestAttributes.
    worker_id = list(worker_ips)[0]
    guest_attributes_response = tpu.GetGuestAttributes(
        tpu_name, args.zone, six.text_type((worker_id)))
    host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
        guest_attributes_response.guestAttributes,
        len(node.networkEndpoints), worker_id)
  else:
    # Retrieve the GuestAttributes for all workers in that TPU.
    guest_attributes_response = tpu.GetGuestAttributes(tpu_name, args.zone)
    host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
        guest_attributes_response.guestAttributes)

  # Generate the public key.
  ssh_helper = ssh_utils.BaseSSHCLIHelper()
  ssh_helper.Run(args)
  public_key = ssh_helper.keys.GetPublicKey().ToEntry()

  project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

  if not args.plain:
    # If there is an '@' symbol in the user_tpu arg, the user is requesting
    # to connect as a specific user. This may get overridden by OS Login.
    username_requested = '@' in args.user_tpu
    _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
    oslogin_state = ssh.GetOsloginState(
        None, project, user, public_key, expiration_micros,
        self.ReleaseTrack(),
        username_requested=username_requested,
        instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(node))
    user = oslogin_state.user

  # Format the key correctly.
  public_key = '{1}:{0} {1}'.format(public_key, user)
  if not args.plain and not args.dry_run:
    tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name, args.zone,
                                    public_key)

  command_list = args.command.split(' ') if args.command else None

  remainder = []
  if args.ssh_args:
    remainder.extend(args.ssh_args)

  if args.output_directory:
    log.status.Print(
        'Preparing SSH command execution; output will be logged '
        'to {}'.format(output_directory_path))

  use_iap = (args.IsKnownAndSpecified('tunnel_through_iap') and
             args.tunnel_through_iap)

  instance_names = {}
  if use_iap:
    # Retrieve the instance names from the GuestAttributes.
    for worker in worker_ips:
      # The GuestAttributes will only have one entry if we're targeting a
      # single worker.
      index = 0 if single_pod_worker else worker
      instance_name = tpu_ssh_utils.GetFromGuestAttributes(
          guest_attributes_response.guestAttributes, index, 'hostname')
      if instance_name is None:
        log.status.Print('Failed to connect to TPU.')
        log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
        raise tpu_exceptions.IapTunnelingUnavailable()
      instance_names[worker] = instance_name

  ssh_threads = []
  # AttemptRunWithRetries receives each worker id together with this list,
  # presumably using the id as the index for that worker's status. Worker ids
  # appear to index node.networkEndpoints, so the list is sized to cover
  # every endpoint, not only the number of targeted workers.
  exit_statuses = [None] * max(len(worker_ips), len(node.networkEndpoints))
  for worker, ips in worker_ips.items():
    identity_file = None
    options = None
    if not args.plain:
      identity_file = ssh_helper.keys.key_file
      options = ssh_helper.GetConfig(
          tpu_ssh_utils.GetInstanceID(node.id, worker, host_key_suffixes),
          args.strict_host_key_checking, None)

    remote = ssh.Remote(ips.ip_address, user)
    extra_flags = ssh.ParseAndSubstituteSSHFlags(
        args, remote, ips.ip_address, ips.internal_address)

    iap_tunnel_args = None
    if use_iap:
      # Retrieve the instance name from the GuestAttributes.
      instance_name = instance_names[worker]
      iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
          args, self.ReleaseTrack(), project, args.zone, instance_name)

    cmd = ssh.SSHCommand(
        remote=remote,
        identity_file=identity_file,
        remote_command=command_list,
        extra_flags=extra_flags,
        options=options,
        remainder=remainder,
        iap_tunnel_args=iap_tunnel_args)

    if args.dry_run:
      log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
      continue

    output_file_writer = None
    if args.output_directory:
      output_file_writer = FileWriter(os.path.join(
          output_directory_path, '{}.log'.format(six.text_type(worker))))

    if len(worker_ips) > 1:
      # Run the command on multiple workers concurrently.
      thread = threading.Thread(
          target=tpu_ssh_utils.AttemptRunWithRetries,
          args=('SSH', worker, exit_statuses, cmd, ssh_helper.env,
                output_file_writer, True, SSHRunCmd))
      ssh_threads.append(thread)
      thread.start()
    else:
      # Run on a single worker.
      tpu_ssh_utils.AttemptRunWithRetries('SSH', worker, exit_statuses, cmd,
                                          ssh_helper.env, output_file_writer,
                                          False, SSHRunCmd)

  # Wait for all the threads to complete (no threads for a single worker).
  for thread in ssh_threads:
    thread.join()

  # Exit with a nonzero code, if there are any. This ensures that if any
  # command failed on a worker, we don't end up returning 0.
  for status in exit_statuses:
    if status:
      sys.exit(status)
def SSHToInstance(self, args, instance):
  """Sets up SSH authentication for the instance, then opens an SSH session.

  The connection goes to the instance's external NAT IP. The instance may
  have just been created and not yet accept connections, so the SSH command
  is retried at a fixed interval.

  Args:
    args: command arguments. They are first passed through
      `self._DefaultArgsForSSH`; `ssh_key_file` is then set on the result.
    instance: the instance to connect to.

  Raises:
    ssh.CommandError: if SSH itself still fails on the last attempt.
    SystemExit: with the remote command's exit code when that code is
      non-zero.
  """
  args = self._DefaultArgsForSSH(args)
  nat_ip = ssh_utils.GetExternalIPAddress(instance)
  log.status.Print(
      'Trying to SSH to VM with NAT IP:{}'.format(nat_ip))
  args.ssh_key_file = ssh.Keys.DEFAULT_KEY_FILE

  helper = ssh_utils.BaseSSHCLIHelper()
  helper.Run(args)
  key_file = helper.keys.key_file

  requested_user, _ = ssh_utils.GetUserAndInstance(args.name)
  host_keys = self._GetHostKeyFromInstance(args.zone, helper, instance)
  ssh_options = self._GetSSHOptions(args.name, helper, instance, host_keys)

  key_entry = helper.keys.GetPublicKey().ToEntry(include_comment=True)
  project = helper.GetProject(
      self.client, properties.VALUES.core.project.Get(required=True))
  user, use_oslogin = ssh.CheckForOsloginAndGetUser(
      instance,
      project,
      requested_user,
      key_entry,
      None,
      self.release_track,
      username_requested=False)

  remote = ssh.Remote(nat_ip, user)
  # Without OS Login, wait until the SSH keys have propagated to the instance.
  if not use_oslogin:
    self._WaitForSSHKeysToPropagate(helper, remote, key_file, user,
                                    instance, ssh_options)

  # ctpu also forwards further ports from what appears to be the TPU node;
  # those stay disabled here until that behavior is better understood.
  forward_flags = (
      ['-A', '-L', '6006:localhost:6006', '-L', '8888:localhost:8888']
      if args.forward_ports else [])

  cmd = ssh.SSHCommand(
      remote=remote,
      identity_file=key_file,
      options=ssh_options,
      extra_flags=forward_flags)

  attempts = 10
  wait_seconds = 30
  last_attempt = attempts - 1
  # A just-created instance can take a while to start accepting SSH, so retry
  # at a fixed interval for roughly five minutes. No backoff is needed: the
  # goal is to wait for readiness, not to throttle.
  for attempt in range(attempts):
    try:
      log.status.Print('SSH Attempt #{}...'.format(attempt))
      # Failures of SSH itself surface as ssh.CommandError; a non-zero value
      # returned here is the exit code of the remote command.
      exit_code = cmd.Run(
          helper.env,
          force_connect=properties.VALUES.ssh.putty_force_connect.GetBool())
      if exit_code:
        sys.exit(exit_code)
    except ssh.CommandError as err:
      if attempt == last_attempt:
        raise err
      log.status.Print(
          'Retrying: SSH command error: {}'.format(six.text_type(err)))
      time.sleep(wait_seconds)
    else:
      # The session ran without an SSH-level error; stop retrying.
      return