Ejemplo n.º 1
0
    def Run(self, args):
        """Securely copy files to or from a running App Engine flex instance.

        Only a single remote may take part in the copy. Its host and user
        are replaced with the connection details obtained for the instance
        before the SCP command is executed.

        Args:
          args: argparse.Namespace, the args the command was invoked with.

        Raises:
          InvalidInstanceTypeError: The instance is not supported for SSH.
          MissingVersionError: The version specified does not exist.
          MissingInstanceError: The instance specified does not exist.
          UnattendedPromptError: Not running in a tty.
          OperationCancelledError: User cancelled the operation.
          ssh.CommandError: The SCP command exited with SCP exit code, which
            usually implies that a connection problem occurred.

        Returns:
          int, The exit code of the SCP command.
        """
        client = appengine_api_client.GetApiClientForTrack(
            self.ReleaseTrack())

        # An SSH-capable environment and a local key pair are prerequisites;
        # keys are requested with overwrite=False.
        environment = ssh.Environment.Current()
        environment.RequireSSH()
        key_pair = ssh.Keys.FromFilename()
        key_pair.EnsureKeysExist(overwrite=False)

        destination = ssh.FileReference.FromPath(args.destination)
        sources = []
        for path in args.sources:
            sources.append(ssh.FileReference.FromPath(path))
        # Make sure we have a unique remote.
        ssh.SCPCommand.Verify(sources, destination, single_remote=True)

        if destination.remote:
            shared_remote = destination.remote
        else:
            # Copying from the instance: point every source at one shared
            # remote object so the host/user update below reaches all of them.
            shared_remote = sources[0].remote
            for source in sources:
                source.remote = shared_remote

        details = ssh_common.PopulatePublicKey(
            client, args.service, args.version, shared_remote.host,
            key_pair.GetPublicKey(), self.ReleaseTrack())

        # Swap the user-supplied instance identifier for the real endpoint.
        shared_remote.host = details.remote.host
        shared_remote.user = details.remote.user

        scp = ssh.SCPCommand(
            sources,
            destination,
            identity_file=key_pair.key_file,
            compress=args.compress,
            recursive=args.recurse,
            options=details.options)
        return scp.Run(environment)
Ejemplo n.º 2
0
  def Run(self, args):
    """Copy files with scp using connection details prepared from args.

    When args.dry_run is set, the scp command line is printed instead of
    being executed.

    Args:
      args: argparse.Namespace, the args the command was invoked with.
    """
    info = util.PrepareEnvironment(args)
    target = ssh.Remote(host=info.host, user=info.user)

    source_refs = [ToFileReference(path, target) for path in args.sources]
    destination_ref = ToFileReference(args.destination, target)
    scp = ssh.SCPCommand(
        sources=source_refs,
        destination=destination_ref,
        recursive=args.recurse,
        compress=False,
        port=str(info.port),
        identity_file=info.key,
        extra_flags=args.scp_flag,
        # Strict host key checking is switched off for this connection.
        options={'StrictHostKeyChecking': 'no'},
    )

    if not args.dry_run:
      scp.Run(info.ssh_env)
      return

    # Dry run: only show the command that would have been executed.
    log.Print(' '.join(scp.Build(info.ssh_env)))
Ejemplo n.º 3
0
  def Run(self, args, port=None, recursive=False, extra_flags=None):
    """Copy files between the local machine and a remote GCE instance.

    Intended to be called from the Run methods of subclasses. If the SCP
    command finishes with a non-zero exit code, the process exits with that
    code.

    Args:
      args: argparse.Namespace, the args the command was invoked with.
      port: str, int or None, Port number to use for SSH connection.
      recursive: bool, Whether to copy recursively.
      extra_flags: [str] or None, extra flags to add to command invocation.

    Raises:
      ssh_utils.NetworkError: Network issue which likely is due to failure
        of SSH key propagation.
      ssh.CommandError: The SSH command exited with SSH exit code, which
        usually implies that a connection problem occurred.
    """
    super(BaseScpCommand, self).Run(args)

    destination = ssh.FileReference.FromPath(args.destination)
    sources = []
    for path in args.sources:
      sources.append(ssh.FileReference.FromPath(path))

    # Make sure we have a unique remote.
    ssh.SCPCommand.Verify(sources, destination, single_remote=True)

    if destination.remote:
      remote = destination.remote
    else:
      # Every source shares one remote object so that the host/user updates
      # below are visible through all of them.
      remote = sources[0].remote
      for source in sources:
        source.remote = remote

    resolved_refs = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
        [remote.host], compute_scope.ScopeEnum.ZONE, args.zone, self.resources,
        scope_lister=flags.GetDefaultScopeLister(self.compute_client))
    instance_ref = resolved_refs[0]
    instance = self.GetInstance(instance_ref)

    # The remote so far holds the instance name; swap in its external address.
    remote.host = ssh_utils.GetExternalIPAddress(instance)
    if not remote.user:
      remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

    if args.plain:
      # Plain mode: no identity file and no extra SSH options are passed.
      identity_file = None
      options = None
    else:
      identity_file = self.keys.key_file
      options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                               args.strict_host_key_checking)

    scp = ssh.SCPCommand(
        sources, destination, identity_file=identity_file, options=options,
        recursive=recursive, port=port, extra_flags=extra_flags)

    if args.dry_run:
      log.out.Print(' '.join(scp.Build(self.env)))
      return

    keys_newly_added = False
    if not args.plain:
      keys_newly_added = self.EnsureSSHKeyExists(
          remote.user, instance, instance_ref.project,
          use_account_service=self._use_account_service)

    if keys_newly_added:
      # NOTE(review): max_wait_ms receives a constant whose name ends in
      # _SEC -- confirm the units agree.
      poller = ssh.SSHPoller(
          remote, identity_file=identity_file, options=options,
          max_wait_ms=ssh_utils.SSH_KEY_PROPAGATION_TIMEOUT_SEC)
      log.status.Print('Waiting for SSH key to propagate.')
      # TODO(b/35355795): Don't force_connect
      try:
        poller.Poll(self.env, force_connect=True)
      except retry.WaitException:
        raise ssh_utils.NetworkError()

    exit_code = scp.Run(self.env, force_connect=True)
    if exit_code:
      # Exit directly rather than raising: an exception would print an
      # "ERROR" message, and the output from `ssh` is enough on its own.
      sys.exit(exit_code)
Ejemplo n.º 4
0
    def RunScp(self,
               compute_holder,
               args,
               port=None,
               recursive=False,
               compress=False,
               extra_flags=None,
               release_track=None,
               ip_type=ip.IpTypeEnum.EXTERNAL):
        """Copy files between the local machine and a remote GCE instance.

        Intended to be called from the Run methods of subclasses.

        Args:
          compute_holder: The ComputeApiHolder.
          args: argparse.Namespace, the args the command was invoked with.
          port: str or None, Port number to use for SSH connection.
          recursive: bool, Whether to copy recursively.
          compress: bool, Whether to use compression.
          extra_flags: [str] or None, extra flags to add to command
            invocation.
          release_track: obj, The current release track. Defaults to GA when
            None.
          ip_type: IpTypeEnum, Specify using internal ip or external ip
            address.

        Raises:
          ssh_utils.NetworkError: Network issue which likely is due to failure
            of SSH key propagation.
          ssh.CommandError: The SSH command exited with SSH exit code, which
            usually implies that a connection problem occurred.
        """
        if release_track is None:
            release_track = base.ReleaseTrack.GA
        super(BaseScpHelper, self).Run(args)

        destination = ssh.FileReference.FromPath(args.destination)
        sources = []
        for path in args.sources:
            sources.append(ssh.FileReference.FromPath(path))

        # Make sure we have a unique remote.
        ssh.SCPCommand.Verify(sources, destination, single_remote=True)

        if destination.remote:
            remote = destination.remote
        else:
            # Every source shares one remote object so that the host/user
            # updates below are visible through all of them.
            remote = sources[0].remote
            for source in sources:
                source.remote = remote

        resolved_refs = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [remote.host],
            compute_scope.ScopeEnum.ZONE,
            args.zone,
            compute_holder.resources,
            scope_lister=instance_flags.GetInstanceZoneScopeLister(
                compute_holder.client))
        instance_ref = resolved_refs[0]
        instance = self.GetInstance(compute_holder.client, instance_ref)
        project = self.GetProject(compute_holder.client, instance_ref.project)

        if not remote.user:
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

        if args.plain:
            # Plain mode: no OS Login lookup, identity file or SSH options.
            use_oslogin = False
            identity_file = None
            options = None
        else:
            public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
            # The OS Login check may replace the user name chosen above.
            remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
                instance, project, remote.user, public_key, release_track)
            identity_file = self.keys.key_file
            options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                                     args.strict_host_key_checking)

        iap_tunnel_args = iap_tunnel.SshTunnelArgs.FromArgs(
            args, release_track, instance_ref,
            ssh_utils.GetInternalInterface(instance),
            ssh_utils.GetExternalInterface(instance, no_raise=True))

        # Replace the instance name held by the remote with what SCP should
        # actually connect to.
        if iap_tunnel_args:
            remote.host = ssh_utils.HostKeyAlias(instance)
        elif ip_type is ip.IpTypeEnum.INTERNAL:
            remote.host = ssh_utils.GetInternalIPAddress(instance)
        else:
            remote.host = ssh_utils.GetExternalIPAddress(instance)

        scp = ssh.SCPCommand(
            sources,
            destination,
            identity_file=identity_file,
            options=options,
            recursive=recursive,
            compress=compress,
            port=port,
            extra_flags=extra_flags,
            iap_tunnel_args=iap_tunnel_args)

        if args.dry_run:
            log.out.Print(' '.join(scp.Build(self.env)))
            return

        keys_newly_added = False
        if not (args.plain or use_oslogin):
            keys_newly_added = self.EnsureSSHKeyExists(
                compute_holder.client, remote.user, instance, project)

        if keys_newly_added:
            poller = ssh_utils.CreateSSHPoller(
                remote, identity_file, options, iap_tunnel_args, port=port)

            log.status.Print('Waiting for SSH key to propagate.')
            # TODO(b/35355795): Don't force_connect
            try:
                poller.Poll(self.env, force_connect=True)
            except retry.WaitException:
                raise ssh_utils.NetworkError()

        if ip_type is ip.IpTypeEnum.INTERNAL:
            # This will never happen when IAP Tunnel is enabled, because
            # ip_type is always EXTERNAL when IAP Tunnel is enabled, even if
            # the instance has no external IP. IAP Tunnel doesn't need
            # verification because it uses unambiguous identifiers for the
            # instance.
            self.PreliminarilyVerifyInstance(instance.id, remote,
                                             identity_file, options)

        # Errors from the SCP command result in an ssh.CommandError being
        # raised.
        scp.Run(self.env, force_connect=True)
Ejemplo n.º 5
0
    def Run(self, args):
        """Copy files between the local machine and Cloud TPU VM workers.

        Resolves the TPU node named by the single remote in the
        sources/destination, prepares SSH credentials, and runs one SCP
        command per targeted worker. With more than one worker the commands
        run concurrently in threads and the process exits with the first
        non-zero status recorded. With args.dry_run set, the commands are
        printed instead of executed.

        Args:
          args: argparse.Namespace, the args the command was invoked with.
            args.zone is filled in from the compute/zone property when unset.

        Raises:
          exceptions.BadArgumentException: The node is not a TPU VM node.
          exceptions.InvalidArgumentException: Several workers were targeted
            while copying from the TPU to the client.
          tpu_exceptions.IapTunnelingUnavailable: IAP tunneling was requested
            but a worker's hostname is missing from the guest attributes.
        """
        dst = ssh.FileReference.FromPath(args.destination)
        srcs = [ssh.FileReference.FromPath(src) for src in args.sources]
        ssh.SCPCommand.Verify(srcs, dst, single_remote=True)
        if dst.remote:
            tpu_name = dst.remote.host
        else:
            tpu_name = srcs[0].remote.host

        # If zone is not set, retrieve the one from the config.
        if args.zone is None:
            args.zone = properties.VALUES.compute.zone.Get(required=True)

        # Retrieve the node.
        tpu = tpu_utils.TPUNode(self.ReleaseTrack())
        node = tpu.Get(tpu_name, args.zone)
        if not tpu_utils.IsTPUVMNode(node):
            raise exceptions.BadArgumentException(
                'TPU',
                'this command is only available for Cloud TPU VM nodes. To access '
                'this node, please see '
                'https://cloud.google.com/tpu/docs/creating-deleting-tpus.')

        worker_ips = tpu_ssh_utils.ParseWorkerFlag(args.worker,
                                                   node.networkEndpoints,
                                                   args.internal_ip)
        multiple_workers = len(worker_ips) > 1

        # A remote source means files flow towards the client, which only
        # makes sense from a single worker.
        if multiple_workers and srcs[0].remote:
            raise exceptions.InvalidArgumentException(
                '--worker',
                'cannot target multiple workers while copying files to '
                'client.')

        tpu_ssh_utils.ValidateTPUState(node.state,
                                       tpu.messages.Node.StateValueValuesEnum)

        # Retrieve GuestAttributes.
        single_pod_worker = len(
            node.networkEndpoints) > 1 and len(worker_ips) == 1
        if single_pod_worker:
            # Retrieve only that worker's GuestAttributes.
            worker_id = list(worker_ips)[0]
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone, six.text_type(worker_id))
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes,
                len(node.networkEndpoints), worker_id)
        else:
            # Retrieve the GuestAttributes for all workers in that TPU.
            guest_attributes_response = tpu.GetGuestAttributes(
                tpu_name, args.zone)
            host_key_suffixes = tpu_ssh_utils.GetHostKeySuffixes(
                guest_attributes_response.guestAttributes)

        # Generate the public key.
        ssh_helper = ssh_utils.BaseSSHCLIHelper()
        ssh_helper.Run(args)
        public_key = ssh_helper.keys.GetPublicKey().ToEntry()

        remote = dst.remote or srcs[0].remote
        if not dst.remote:  # Make sure all remotes point to the same ref.
            for src in srcs:
                src.remote = remote

        # Remember whether the caller named a user explicitly before a
        # default is filled in.
        username_requested = bool(remote.user)
        if not username_requested:
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)

        project = tpu_utils.GetProject(self.ReleaseTrack(), ssh_helper)

        if not args.plain:
            # If there is an '@' symbol in the user_host arg, the user is
            # requesting to connect as a specific user. This may get
            # overridden by OS Login.
            _, expiration_micros = ssh_utils.GetSSHKeyExpirationFromArgs(args)
            oslogin_state = ssh.GetOsloginState(
                None,
                project,
                remote.user,
                public_key,
                expiration_micros,
                self.ReleaseTrack(),
                username_requested=username_requested,
                instance_enable_oslogin=tpu_ssh_utils.TpuHasOsLoginEnabled(
                    node))
            remote.user = oslogin_state.user

        # Format the key correctly.
        public_key = '{1}:{0} {1}'.format(public_key, remote.user)
        if not args.plain and not args.dry_run:
            tpu_ssh_utils.AddSSHKeyIfNeeded(project, tpu, node, tpu_name,
                                            args.zone, public_key)

        identity_file = None
        if not args.plain:
            identity_file = ssh_helper.keys.key_file
            # If the user's key is not in the SSH agent, the command will
            # stall. We want to verify it is added before proceeding, and
            # raise an error if it is not.
            if not args.dry_run and multiple_workers:
                tpu_ssh_utils.VerifyKeyInAgent(identity_file)

        extra_flags = list(args.scp_flag) if args.scp_flag else []

        # Evaluated once: the answer cannot change between workers.
        use_iap = (args.IsKnownAndSpecified('tunnel_through_iap')
                   and args.tunnel_through_iap)

        instance_names = {}
        if use_iap:
            # Retrieve the instance names from the GuestAttributes.
            for worker in worker_ips:
                # The GuestAttributes will only have one entry if we're
                # targeting a single worker.
                index = 0 if single_pod_worker else worker
                instance_name = tpu_ssh_utils.GetFromGuestAttributes(
                    guest_attributes_response.guestAttributes, index,
                    'hostname')
                if instance_name is None:
                    log.status.Print('Failed to connect to TPU.')
                    log.status.Print(tpu_ssh_utils.IAP_TROUBLESHOOTING_HELP)
                    raise tpu_exceptions.IapTunnelingUnavailable()
                instance_names[worker] = instance_name

        ssh_threads = []
        # NOTE(review): this list is sized by the number of targeted workers
        # while `worker` is handed to AttemptRunWithRetries as the slot
        # identifier -- confirm the helper cannot index past the end when a
        # non-contiguous subset of workers is selected.
        exit_statuses = [None] * len(worker_ips)
        for worker, ips in worker_ips.items():
            options = None
            if not args.plain:
                options = ssh_helper.GetConfig(
                    tpu_ssh_utils.GetInstanceID(node.id, worker,
                                                host_key_suffixes),
                    args.strict_host_key_checking, None)

            iap_tunnel_args = None
            if use_iap:
                iap_tunnel_args = tpu_ssh_utils.CreateSshTunnelArgs(
                    args, self.ReleaseTrack(), project, args.zone,
                    instance_names[worker])

            # All file references share `remote`, so retargeting it here
            # retargets the command built just below.
            remote.host = ips.ip_address
            cmd = ssh.SCPCommand(srcs,
                                 dst,
                                 identity_file=identity_file,
                                 options=options,
                                 recursive=args.recurse,
                                 compress=args.compress,
                                 extra_flags=extra_flags,
                                 iap_tunnel_args=iap_tunnel_args)

            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(ssh_helper.env)))
                continue

            if multiple_workers:
                # Run the command on multiple workers concurrently.
                thread = threading.Thread(
                    target=tpu_ssh_utils.AttemptRunWithRetries,
                    args=('SCP', worker, exit_statuses, cmd,
                          ssh_helper.env, None, True, SCPRunCmd))
                ssh_threads.append(thread)
                thread.start()
            else:
                # Run on a single worker. The recorded status is not
                # inspected below; presumably the helper reports failures
                # itself in this mode -- verify against its implementation.
                tpu_ssh_utils.AttemptRunWithRetries('SCP', worker,
                                                    exit_statuses, cmd,
                                                    ssh_helper.env, None,
                                                    False, SCPRunCmd)

        if multiple_workers:
            # Wait for all the threads to complete.
            for thread in ssh_threads:
                thread.join()

            # Exit with a nonzero status, if any.
            # This ensures that if any command failed on a worker, we don't
            # end up returning 0 for a value.
            for status in exit_statuses:
                if status:
                    sys.exit(status)
Ejemplo n.º 6
0
    def RunScp(self,
               compute_holder,
               args,
               port=None,
               recursive=False,
               compress=False,
               extra_flags=None,
               release_track=None,
               ip_type=ip.IpTypeEnum.EXTERNAL):
        """Copy files between the local machine and a remote GCE instance.

        Intended to be called from the Run methods of subclasses. When
        args.tunnel_through_iap is set, a local IAP tunnel listener is started
        and SCP is pointed at it on localhost. The listener is stopped on
        every exit path, including errors.

        Args:
          compute_holder: The ComputeApiHolder.
          args: argparse.Namespace, the args the command was invoked with.
          port: str or None, Port number to use for SSH connection.
          recursive: bool, Whether to copy recursively.
          compress: bool, Whether to use compression.
          extra_flags: [str] or None, extra flags to add to command
            invocation.
          release_track: obj, The current release track. Defaults to GA when
            None.
          ip_type: IpTypeEnum, Specify using internal ip or external ip
            address.

        Raises:
          ssh_utils.NetworkError: Network issue which likely is due to failure
            of SSH key propagation.
          ssh.CommandError: The SSH command exited with SSH exit code, which
            usually implies that a connection problem occurred.
        """
        if release_track is None:
            release_track = base.ReleaseTrack.GA
        super(BaseScpHelper, self).Run(args)

        dst = ssh.FileReference.FromPath(args.destination)
        srcs = [ssh.FileReference.FromPath(src) for src in args.sources]

        # Make sure we have a unique remote
        ssh.SCPCommand.Verify(srcs, dst, single_remote=True)

        remote = dst.remote or srcs[0].remote
        if not dst.remote:  # Make sure all remotes point to the same ref
            for src in srcs:
                src.remote = remote

        instance_ref = instance_flags.SSH_INSTANCE_RESOLVER.ResolveResources(
            [remote.host],
            compute_scope.ScopeEnum.ZONE,
            args.zone,
            compute_holder.resources,
            scope_lister=instance_flags.GetInstanceZoneScopeLister(
                compute_holder.client))[0]
        instance = self.GetInstance(compute_holder.client, instance_ref)
        project = self.GetProject(compute_holder.client, instance_ref.project)

        # Now replace the instance name with the actual IP/hostname
        if ip_type is ip.IpTypeEnum.INTERNAL:
            remote.host = ssh_utils.GetInternalIPAddress(instance)
        else:
            remote.host = ssh_utils.GetExternalIPAddress(instance)
        if not remote.user:
            remote.user = ssh.GetDefaultSshUsername(warn_on_account_user=True)
        if args.plain:
            use_oslogin = False
        else:
            public_key = self.keys.GetPublicKey().ToEntry(include_comment=True)
            # The OS Login check may replace the user name chosen above.
            remote.user, use_oslogin = ssh.CheckForOsloginAndGetUser(
                instance, project, remote.user, public_key, release_track)

        identity_file = None
        options = None
        if not args.plain:
            identity_file = self.keys.key_file
            options = self.GetConfig(ssh_utils.HostKeyAlias(instance),
                                     args.strict_host_key_checking)

        tunnel_helper = None
        cmd_port = port
        interface = None
        if hasattr(args, 'tunnel_through_iap') and args.tunnel_through_iap:
            tunnel_helper, interface = ssh_utils.CreateIapTunnelHelper(
                args, instance_ref, instance, ip_type, port=port)
            tunnel_helper.StartListener()

        # From here on the tunnel listener (if any) is live, so everything
        # runs under a finally that stops it -- otherwise any unexpected
        # error below would leave the listener running.
        try:
            if tunnel_helper:
                # SCP connects to the local end of the tunnel instead of the
                # instance address.
                cmd_port = str(tunnel_helper.GetLocalPort())
                if dst.remote:
                    dst.remote.host = 'localhost'
                else:
                    for src in srcs:
                        src.remote.host = 'localhost'

            cmd = ssh.SCPCommand(srcs,
                                 dst,
                                 identity_file=identity_file,
                                 options=options,
                                 recursive=recursive,
                                 compress=compress,
                                 port=cmd_port,
                                 extra_flags=extra_flags)

            if args.dry_run:
                log.out.Print(' '.join(cmd.Build(self.env)))
                return

            if args.plain or use_oslogin:
                keys_newly_added = False
            else:
                keys_newly_added = self.EnsureSSHKeyExists(
                    compute_holder.client, remote.user, instance, project)

            if keys_newly_added:
                self._WaitForKeyPropagation(args, instance_ref, instance,
                                            ip_type, port, interface, remote,
                                            identity_file, options,
                                            use_tunnel=bool(tunnel_helper))

            if ip_type is ip.IpTypeEnum.INTERNAL and not tunnel_helper:
                # The IAP Tunnel connection uses instance name and network
                # interface name, so do not need to additionally verify the
                # instance. Also, the SSHCommand used within the function does
                # not support IAP Tunnels.
                self.PreliminarilyVerifyInstance(instance.id, remote,
                                                 identity_file, options)

            # Errors from the SCP command result in an ssh.CommandError being
            # raised
            cmd.Run(self.env, force_connect=True)
        finally:
            if tunnel_helper:
                tunnel_helper.StopListener()

    def _WaitForKeyPropagation(self, args, instance_ref, instance, ip_type,
                               port, interface, remote, identity_file,
                               options, use_tunnel):
        """Poll over SSH until a newly added key is usable on the instance.

        When use_tunnel is true, a separate IAP tunnel listener accepting
        multiple connections is started for the poller and is always stopped
        again before returning or raising.

        Raises:
          ssh_utils.NetworkError: Polling gave up waiting.
        """
        poller_tunnel_helper = None
        if use_tunnel:
            poller_tunnel_helper, _ = ssh_utils.CreateIapTunnelHelper(
                args,
                instance_ref,
                instance,
                ip_type,
                port=port,
                interface=interface)
            poller_tunnel_helper.StartListener(
                accept_multiple_connections=True)

        try:
            poller = ssh_utils.CreateSSHPoller(remote,
                                               identity_file,
                                               options,
                                               poller_tunnel_helper,
                                               port=port)

            log.status.Print('Waiting for SSH key to propagate.')
            # TODO(b/35355795): Don't force_connect
            try:
                poller.Poll(self.env, force_connect=True)
            except retry.WaitException:
                raise ssh_utils.NetworkError()
        finally:
            if poller_tunnel_helper:
                poller_tunnel_helper.StopListener()