def Run(self, args):
  """Securely copy files from/to a running flex instance.

  Args:
    args: argparse.Namespace, the args the command was invoked with.

  Raises:
    InvalidInstanceTypeError: The instance is not supported for SSH.
    MissingVersionError: The version specified does not exist.
    MissingInstanceError: The instance specified does not exist.
    UnattendedPromptError: Not running in a tty.
    OperationCancelledError: User cancelled the operation.
    ssh.CommandError: The SCP command exited with SCP exit code, which
      usually implies that a connection problem occurred.

  Returns:
    int, The exit code of the SCP command.
  """
  api_client = appengine_api_client.GetApiClientForTrack(
      self.ReleaseTrack())
  env = ssh.Environment.Current()
  env.RequireSSH()
  keys = ssh.Keys.FromFilename()
  keys.EnsureKeysExist(overwrite=False)

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(source) for source in args.sources]

  # Make sure we have a unique remote
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

  remote = dst.remote or srcs[0].remote
  if not dst.remote:
    # Make sure all remotes point to the same ref, so updating `remote`
    # below retargets every source at once.
    for src in srcs:
      src.remote = remote

  # remote.host at this point is whatever host the user typed in the path;
  # PopulatePublicKey uses it together with service/version to find the
  # connection details for the instance.
  connection_details = ssh_common.PopulatePublicKey(
      api_client, args.service, args.version, remote.host,
      keys.GetPublicKey(), self.ReleaseTrack())

  # Update all remote references
  remote.host = connection_details.remote.host
  remote.user = connection_details.remote.user

  cmd = ssh.SCPCommand(srcs, dst, identity_file=keys.key_file,
                       compress=args.compress, recursive=args.recurse,
                       options=connection_details.options)
  return cmd.Run(env)
def Run(self, args):
  """Copy files to or from the prepared remote environment using scp.

  With --dry-run, the assembled scp command line is printed instead of
  being executed.

  Args:
    args: argparse.Namespace, the args the command was invoked with.
  """
  conn = util.PrepareEnvironment(args)
  target = ssh.Remote(host=conn.host, user=conn.user)

  # Every path is resolved against the same remote endpoint.
  source_refs = [ToFileReference(path, target) for path in args.sources]
  destination_ref = ToFileReference(args.destination, target)

  scp = ssh.SCPCommand(
      sources=source_refs,
      destination=destination_ref,
      recursive=args.recurse,
      compress=False,
      port=str(conn.port),
      identity_file=conn.key,
      extra_flags=args.scp_flag,
      # Host key verification is turned off for this connection.
      options={'StrictHostKeyChecking': 'no'},
  )

  if not args.dry_run:
    scp.Run(conn.ssh_env)
    return
  log.Print(' '.join(scp.Build(conn.ssh_env)))
def Run(self, args, port=None, recursive=False, extra_flags=None):
  """SCP files between local and remote GCE instance.

  Run this method from subclasses' Run methods.

  If the SCP command returns a non-zero exit code, the process exits with
  that code via sys.exit rather than raising, so no extra "ERROR" message
  is printed.

  Args:
    args: argparse.Namespace, the args the command was invoked with.
    port: str, int or None, Port number to use for SSH connection.
    recursive: bool, Whether to use recursive copying using -R flag.
    extra_flags: [str] or None, extra flags to add to command invocation.

  Raises:
    ssh_utils.NetworkError: Network issue which likely is due to failure
      of SSH key propagation.
    ssh.CommandError: The SSH command exited with SSH exit code, which
      usually implies that a connection problem occurred.
  """
  super(BaseScpCommand, self).Run(args)

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

  # Make sure we have a unique remote
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

  remote = dst.remote or srcs[0].remote
  if not dst.remote:
    # Make sure all remotes point to the same ref
    for src in srcs:
      src.remote = remote

  # The host part of the remote path is the instance name; resolve it to a
  # zonal instance reference.
  instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [remote.host], compute_scope.ScopeEnum.ZONE, args.zone,
      self.resources,
      scope_lister=flags.GetDefaultScopeLister(self.compute_client))[0]
  instance = self.GetInstance(instance_ref)

  # Now replace the instance name with the actual IP/hostname
  remote.host = ssh_utils.GetExternalIPAddress(instance)
  if not remote.user:
    remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

  identity_file = None
  options = None
  if not args.plain:
    identity_file = self.keys.key_file
    options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                             args.strict_host_key_checking)

  cmd = ssh.SCPCommand(
      srcs, dst, identity_file=identity_file, options=options,
      recursive=recursive, port=port, extra_flags=extra_flags)

  if args.dry_run:
    log.out.Print(' '.join(cmd.Build(self.env)))
    return

  if args.plain:
    keys_newly_added = False
  else:
    keys_newly_added = self.EnsureSSHKeyExists(
        remote.user, instance, instance_ref.project,
        use_account_service=self._use_account_service)

  if keys_newly_added:
    # SSHPoller takes its timeout in milliseconds, but the constant is
    # expressed in seconds (per its _SEC suffix). Passing it unconverted
    # would make the poller give up after that many milliseconds instead of
    # that many seconds.
    poller = ssh.SSHPoller(
        remote, identity_file=identity_file, options=options,
        max_wait_ms=ssh_utils.SSH_KEY_PROPAGATION_TIMEOUT_SEC * 1000)
    log.status.Print('Waiting for SSH key to propagate.')
    # TODO(b/35355795): Don't force_connect
    try:
      poller.Poll(self.env, force_connect=True)
    except retry.WaitException:
      raise ssh_utils.NetworkError()

  return_code = cmd.Run(self.env, force_connect=True)
  if return_code:
    # Can't raise an exception because we don't want any "ERROR" message
    # printed; the output from `ssh` will be enough.
    sys.exit(return_code)
def RunScp(self,
           compute_holder,
           args,
           port=None,
           recursive=False,
           compress=False,
           extra_flags=None,
           release_track=None,
           ip_type=ip.IpTypeEnum.EXTERNAL):
  """SCP files between local and remote GCE instance.

  Run this method from subclasses' Run methods.

  The remote host in the source/destination path is treated as an instance
  name. It is then replaced by the instance's host key alias (when IAP
  tunnel args are produced), its internal IP (ip_type INTERNAL), or its
  external IP. With --dry-run the scp command line is printed and nothing
  is executed.

  Args:
    compute_holder: The ComputeApiHolder.
    args: argparse.Namespace, the args the command was invoked with.
    port: str or None, Port number to use for SSH connection.
    recursive: bool, Whether to use recursive copying using -R flag.
    compress: bool, Whether to use compression.
    extra_flags: [str] or None, extra flags to add to command invocation.
    release_track: obj, The current release track. Defaults to
      base.ReleaseTrack.GA when None.
    ip_type: IpTypeEnum, Specify using internal ip or external ip address.

  Raises:
    ssh_utils.NetworkError: Network issue which likely is due to failure of
      SSH key propagation.
    ssh.CommandError: The SSH command exited with SSH exit code, which
      usually implies that a connection problem occurred.
  """
  if release_track is None:
    release_track = base.ReleaseTrack.GA
  super(BaseScpHelper, self).Run(args)

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

  # Make sure we have a unique remote
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

  remote = dst.remote or srcs[0].remote
  if not dst.remote:  # Make sure all remotes point to the same ref
    for src in srcs:
      src.remote = remote

  # remote.host still holds the instance name typed by the user; resolve it
  # into a zonal instance reference.
  instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [remote.host],
      compute_scope.ScopeEnum.ZONE,
      args.zone,
      compute_holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(
          compute_holder.client))[0]
  instance = self.GetInstance(compute_holder.client, instance_ref)
  project = self.GetProject(compute_holder.client, instance_ref.project)

  if not remote.user:
    remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
  if args.plain:
    use_oslogin = False
  else:
    # OS Login may replace the requested user name.
    public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
    remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
        instance, project, remote.user, public_key, release_track)

  # In --plain mode no identity file or generated ssh options are passed.
  identity_file = None
  options = None
  if not args.plain:
    identity_file = self.keys.key_file
    options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                             args.strict_host_key_checking)

  iap_tunnel_args = iap_tunnel.SshTunnelArgs.FromArgs(
      args, release_track, instance_ref,
      ssh_utils.GetInternalInterface(instance),
      ssh_utils.GetExternalInterface(instance, no_raise=True))

  # Replace the instance name with the address actually used to connect.
  if iap_tunnel_args:
    remote.host = ssh_utils.HostKeyAlias(instance)
  elif ip_type is ip.IpTypeEnum.INTERNAL:
    remote.host = ssh_utils.GetInternalIPAddress(instance)
  else:
    remote.host = ssh_utils.GetExternalIPAddress(instance)

  cmd = ssh.SCPCommand(srcs, dst, identity_file=identity_file,
                       options=options, recursive=recursive,
                       compress=compress, port=port,
                       extra_flags=extra_flags,
                       iap_tunnel_args=iap_tunnel_args)

  if args.dry_run:
    log.out.Print(' '.join(cmd.Build(self.env)))
    return

  # Keys are only pushed when neither --plain nor OS Login is in use.
  if args.plain or use_oslogin:
    keys_newly_added = False
  else:
    keys_newly_added = self.EnsureSSHKeyExists(compute_holder.client,
                                               remote.user, instance,
                                               project)

  if keys_newly_added:
    poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                       iap_tunnel_args, port=port)
    log.status.Print('Waiting for SSH key to propagate.')
    # TODO(b/35355795): Don't force_connect
    try:
      poller.Poll(self.env, force_connect=True)
    except retry.WaitException:
      raise ssh_utils.NetworkError()

  if ip_type is ip.IpTypeEnum.INTERNAL:
    # This will never happen when IAP Tunnel is enabled, because ip_type is
    # always EXTERNAL when IAP Tunnel is enabled, even if the instance has no
    # external IP. IAP Tunnel doesn't need verification because it uses
    # unambiguous identifiers for the instance.
    self.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                     options)

  # Errors from the SCP command result in an ssh.CommandError being raised
  cmd.Run(self.env, force_connect=True)
def Run(self, args):
  """Copy files between the client and one or more Cloud TPU VM workers.

  When several workers are targeted, one SCP thread is started per worker.
  In that case the process exits with the first non-zero worker status once
  all threads finish. With --dry-run the per-worker scp command lines are
  printed and nothing is executed.

  Args:
    args: argparse.Namespace, the args the command was invoked with.

  Raises:
    exceptions.BadArgumentException: The TPU is not a TPU VM node.
    exceptions.InvalidArgumentException: Multiple workers were targeted
      while copying files from the TPU to the client.
    tpu_exceptions.IapTunnelingUnavailable: IAP tunneling was requested but
      a worker's hostname is missing from the GuestAttributes.
  """
  import copy  # pylint: disable=g-import-not-at-top

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(src) for src in args.sources]
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)
  if dst.remote:
    tpu_name = dst.remote.host
  else:
    tpu_name = srcs[0].remote.host

  # If zone is not set, retrieve the one from the config.
  if args.zone is None:
    args.zone = properties.VALUES.compute.zone.Get(required=True)

  # Retrieve the node.
  tpu = tpu_utils.TPUNode(self.ReleaseTrack())
  node = tpu.Get(tpu_name, args.zone)
  if not tpu_utils.IsTPUVMNode(node):
    raise exceptions.BadArgumentException(
        'TPU',
        'this command is only available for Cloud TPU VM nodes. To access '
        'this node, please see '
        'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

  worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                             node.networkEndpoints,
                                             args.internal_ip)

  if len(worker_ips) > 1 and srcs[0].remote:
    raise exceptions.InvalidArgumentException(
        '--worker', 'cannot target multiple workers while copying files to '
        'client.')

  tpu_ssh_utils.ValidateTPUState(node.state,
                                 tpu.messages.Node.StateValueValuesEnum)

  # Retrieve GuestAttributes.
  single_pod_worker = len(node.networkEndpoints) > 1 and len(worker_ips) == 1
  if single_pod_worker:
    # Retrieve only that worker's GuestAttributes.
    worker_id = list(worker_ips)[0]
    guest_attributes_response = tpu.GetGuestAttributes(
        tpu_name, args.zone, six.text_type((worker_id)))
    host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
        guest_attributes_response.guestAttributes,
        len(node.networkEndpoints), worker_id)
  else:
    # Retrieve the GuestAttributes for all workers in that TPU.
    guest_attributes_response = tpu.GetGuestAttributes(tpu_name, args.zone)
    host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
        guest_attributes_response.guestAttributes)

  # Generate the public key.
  ssh_helper = ssh_utils.BaseSSHCLIHelper()
  ssh_helper.Run(args)
  public_key = ssh_helper.keys.GetPublicKey().ToEntry()

  remote = dst.remote or srcs[0].remote
  if not dst.remote:
    # Make sure all remotes point to the same ref.
    for src in srcs:
      src.remote = remote

  if remote.user:
    username_requested = True
  else:
    username_requested = False
    remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

  project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

  if not args.plain:
    # If there is an '@' symbol in the user_host arg, the user is requesting
    # to connect as a specific user. This may get overridden by OS Login.
    _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
    oslogin_state = ssh.GetOsloginState(
        None, project, remote.user, public_key, expiration_micros,
        self.ReleaseTrack(),
        username_requested=username_requested,
        instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(node))
    remote.user = oslogin_state.user

  # Format the key correctly.
  public_key = '{1}:{0} {1}'.format(public_key, remote.user)
  if not args.plain and not args.dry_run:
    tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name, args.zone,
                                    public_key)

  identity_file = None
  if not args.plain:
    identity_file = ssh_helper.keys.key_file
    # If the user's key is not in the SSH agent, the command will stall. We
    # want to verify it is added before proceeding, and raise an error if it
    # is not.
    if not args.dry_run and len(worker_ips) > 1:
      tpu_ssh_utils.VerifyKeyInAgent(identity_file)

  extra_flags = []
  if args.scp_flag:
    extra_flags.extend(args.scp_flag)

  use_iap = (args.IsKnownAndSpecified('tunnel_through_iap') and
             args.tunnel_through_iap)

  instance_names = {}
  if use_iap:
    # Retrieve the instance names from the GuestAttributes.
    for worker in worker_ips:
      # The GuestAttributes will only have one entry if we're targeting a
      # single worker.
      index = 0 if single_pod_worker else worker
      instance_name = tpu_ssh_utils.GetFromGuestAttributes(
          guest_attributes_response.guestAttributes, index, 'hostname')
      if instance_name is None:
        log.status.Print('Failed to connect to TPU.')
        log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
        raise tpu_exceptions.IapTunnelingUnavailable()
      instance_names[worker] = instance_name

  ssh_threads = []
  # AttemptRunWithRetries receives the worker id together with this list.
  # Worker ids are not necessarily 0..len(worker_ips)-1 when only a subset
  # of workers is targeted (e.g. --worker=1,3), so size the list to cover
  # every endpoint of the node as well.
  exit_statuses = [None] * max(len(worker_ips), len(node.networkEndpoints))
  for worker, ips in worker_ips.items():
    options = None
    if not args.plain:
      options = ssh_helper.GetConfig(
          tpu_ssh_utils.GetInstanceID(node.id, worker, host_key_suffixes),
          args.strict_host_key_checking, None)

    iap_tunnel_args = None
    if use_iap:
      # Retrieve the instance name from the GuestAttributes.
      instance_name = instance_names[worker]
      iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
          args, self.ReleaseTrack(), project, args.zone, instance_name)

    # Give each worker its own copy of the file references (deep-copied
    # together so sources that shared one remote still share one copy).
    # The command keeps references to these objects and may build its argv
    # later on a worker thread. Mutating a single shared remote here could
    # otherwise redirect an earlier worker's copy to this worker's IP.
    worker_srcs, worker_dst = copy.deepcopy((srcs, dst))
    worker_remote = worker_dst.remote or worker_srcs[0].remote
    worker_remote.host = ips.ip_address
    cmd = ssh.SCPCommand(worker_srcs, worker_dst,
                         identity_file=identity_file, options=options,
                         recursive=args.recurse, compress=args.compress,
                         extra_flags=extra_flags,
                         iap_tunnel_args=iap_tunnel_args)

    if args.dry_run:
      log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
      continue

    if len(worker_ips) > 1:
      # Run the command on multiple workers concurrently.
      thread = threading.Thread(
          target=tpu_ssh_utils.AttemptRunWithRetries,
          args=('SCP', worker, exit_statuses, cmd, ssh_helper.env, None,
                True, SCPRunCmd))
      ssh_threads.append(thread)
      thread.start()
    else:
      # Run on a single worker.
      tpu_ssh_utils.AttemptRunWithRetries('SCP', worker, exit_statuses, cmd,
                                          ssh_helper.env, None, False,
                                          SCPRunCmd)

  # Wait for all the threads to complete (no-op when none were started).
  for thread in ssh_threads:
    thread.join()

  # Exit with a nonzero status, if any.
  # This ensures that if any command failed on a worker, we don't end up
  # returning 0 for a value.
  for status in exit_statuses:
    if status:
      sys.exit(status)
def RunScp(self,
           compute_holder,
           args,
           port=None,
           recursive=False,
           compress=False,
           extra_flags=None,
           release_track=None,
           ip_type=ip.IpTypeEnum.EXTERNAL):
  """SCP files between local and remote GCE instance.

  Run this method from subclasses' Run methods.

  When --tunnel-through-iap is set, a local IAP tunnel listener is started
  and scp connects to 'localhost' on the listener's port. The listener is
  stopped however this method exits (dry run, success, or any exception).

  Args:
    compute_holder: The ComputeApiHolder.
    args: argparse.Namespace, the args the command was invoked with.
    port: str or None, Port number to use for SSH connection.
    recursive: bool, Whether to use recursive copying using -R flag.
    compress: bool, Whether to use compression.
    extra_flags: [str] or None, extra flags to add to command invocation.
    release_track: obj, The current release track. Defaults to
      base.ReleaseTrack.GA when None.
    ip_type: IpTypeEnum, Specify using internal ip or external ip address.

  Raises:
    ssh_utils.NetworkError: Network issue which likely is due to failure of
      SSH key propagation.
    ssh.CommandError: The SSH command exited with SSH exit code, which
      usually implies that a connection problem occurred.
  """
  if release_track is None:
    release_track = base.ReleaseTrack.GA
  super(BaseScpHelper, self).Run(args)

  dst = ssh.FileReference.FromPath(args.destination)
  srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

  # Make sure we have a unique remote
  ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

  remote = dst.remote or srcs[0].remote
  if not dst.remote:
    # Make sure all remotes point to the same ref
    for src in srcs:
      src.remote = remote

  instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
      [remote.host],
      compute_scope.ScopeEnum.ZONE,
      args.zone,
      compute_holder.resources,
      scope_lister=instance_flags.GetInstanceZoneScopeLister(
          compute_holder.client))[0]
  instance = self.GetInstance(compute_holder.client, instance_ref)
  project = self.GetProject(compute_holder.client, instance_ref.project)

  # Now replace the instance name with the actual IP/hostname
  if ip_type is ip.IpTypeEnum.INTERNAL:
    remote.host = ssh_utils.GetInternalIPAddress(instance)
  else:
    remote.host = ssh_utils.GetExternalIPAddress(instance)

  if not remote.user:
    remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
  if args.plain:
    use_oslogin = False
  else:
    public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
    remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
        instance, project, remote.user, public_key, release_track)

  identity_file = None
  options = None
  if not args.plain:
    identity_file = self.keys.key_file
    options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                             args.strict_host_key_checking)

  tunnel_helper = None
  cmd_port = port
  interface = None
  if hasattr(args, 'tunnel_through_iap') and args.tunnel_through_iap:
    tunnel_helper, interface = ssh_utils.CreateIapTunnelHelper(
        args, instance_ref, instance, ip_type, port=port)
    tunnel_helper.StartListener()

  # Once the listener is running, every exit path must stop it. Previously
  # only the dry-run, WaitException and SCP-run paths did, so errors such as
  # a failing key push or instance verification leaked the listener.
  try:
    if tunnel_helper:
      cmd_port = str(tunnel_helper.GetLocalPort())
      # scp talks to the local end of the tunnel instead of the instance.
      if dst.remote:
        dst.remote.host = 'localhost'
      else:
        for src in srcs:
          src.remote.host = 'localhost'

    cmd = ssh.SCPCommand(srcs, dst, identity_file=identity_file,
                         options=options, recursive=recursive,
                         compress=compress, port=cmd_port,
                         extra_flags=extra_flags)

    if args.dry_run:
      log.out.Print(' '.join(cmd.Build(self.env)))
      return

    if args.plain or use_oslogin:
      keys_newly_added = False
    else:
      keys_newly_added = self.EnsureSSHKeyExists(compute_holder.client,
                                                 remote.user, instance,
                                                 project)

    if keys_newly_added:
      # The poller gets its own listener that accepts multiple connections.
      poller_tunnel_helper = None
      if tunnel_helper:
        poller_tunnel_helper, _ = ssh_utils.CreateIapTunnelHelper(
            args, instance_ref, instance, ip_type, port=port,
            interface=interface)
        poller_tunnel_helper.StartListener(accept_multiple_connections=True)
      try:
        poller = ssh_utils.CreateSSHPoller(remote, identity_file, options,
                                           poller_tunnel_helper, port=port)
        log.status.Print('Waiting for SSH key to propagate.')
        # TODO(b/35355795): Don't force_connect
        poller.Poll(self.env, force_connect=True)
      except retry.WaitException:
        raise ssh_utils.NetworkError()
      finally:
        if poller_tunnel_helper:
          poller_tunnel_helper.StopListener()

    if ip_type is ip.IpTypeEnum.INTERNAL and not tunnel_helper:
      # The IAP Tunnel connection uses instance name and network interface
      # name, so do not need to additionally verify the instance. Also, the
      # SSHCommand used within the function does not support IAP Tunnels.
      self.PreliminarilyVerifyInstance(instance.id, remote, identity_file,
                                       options)

    # Errors from the SCP command result in an ssh.CommandError being raised
    cmd.Run(self.env, force_connect=True)
  finally:
    if tunnel_helper:
      tunnel_helper.StopListener()