def testStableTensorflowVersionListerSingleElement(self):
    """Checks the stable-version lookup when only one version is listed."""
    tpu_node = tpu_utils.TPUNode(self.track)
    zone = 'central2-a'
    self.ExpectTensorflowVersionList(
        'fake-project', zone, [['fake-tf-name', '1.5']])

    self.assertEqual(tpu_node.LatestStableTensorflowVersion(zone), '1.5')
def testNodeIsRunning(self):
    """Checks IsRunning for a READY node and a CREATING node with an IP."""
    node_helper = tpu_utils.TPUNode(self.track)
    node = self.tpu_messages.Node()
    state_enum = self.tpu_messages.Node.StateValueValuesEnum

    node.state = state_enum.READY
    # Single-argument form on purpose: assertTrue's second positional
    # parameter is the failure message, not an expected value.
    self.assertTrue(node_helper.IsRunning(node))

    node.state = state_enum.CREATING
    node.ipAddress = 'fake-address'
    self.assertTrue(node_helper.IsRunning(node))
def Run(self, args):
    """Collects status rows for an execution group's VM and TPU node.

    Args:
      args: Parsed command arguments; `execution_group_name` and `zone` are
        read here.

    Returns:
      A list of GetResult entries. It holds a single 'Not Found' entry when
      the instance lookup raises HttpNotFoundError; otherwise it holds the
      instance rows followed by either the TPU rows or a TPU 'Not Found' row.
    """
    tpu_utils.DefaultArgs.ValidateZone(args)

    instance_helper = tpu_utils.Instance(self.ReleaseTrack())
    try:
        instance = instance_helper.Get(args.execution_group_name, args.zone)
    except HttpNotFoundError:
        # As it stands, we provide vm-only option but no tpu-only option. So
        # if there is no VM, then we can safely short-circuit and claim the
        # execution group is not found.
        return [GetResult('Execution Group Status:', 'Not Found')]

    responses = [
        GetResult(
            'Compute Engine Instance IP Address:',
            instance.networkInterfaces and instance.networkInterfaces[0] and
            instance.networkInterfaces[0].networkIP),
        GetResult('Compute Engine Created:', instance.creationTimestamp),
        GetResult('Compute Engine Machine Type:', instance.machineType),
    ]

    node_helper = tpu_utils.TPUNode(self.ReleaseTrack())
    try:
        node = node_helper.Get(args.execution_group_name, args.zone)
    except HttpNotFoundError:
        node = None
        responses.append(GetResult('TPU Node status:', 'Not Found'))

    if not node:
        return responses

    tpu_rows = (
        ('TPU Accelerator Type:', node.acceleratorType),
        ('TPU IP Address:',
         node.networkEndpoints and node.networkEndpoints[0] and
         node.networkEndpoints[0].ipAddress),
        ('TPU TF Version:', node.tensorflowVersion),
        ('TPU Service Account:', node.serviceAccount),
        ('TPU Created:', node.createTime),
        ('TPU State:', node.state),
        ('TPU Health:', node.health),
        ('TPU Preemptible:',
         node.schedulingConfig and node.schedulingConfig.preemptible),
    )
    for label, value in tpu_rows:
        responses.append(GetResult(label, value))
    return responses
def testNodeNameNoMatch(self):
    """Checks NodeName yields '' for an empty name and two partial paths."""
    node_helper = tpu_utils.TPUNode(self.track)
    node = self.tpu_messages.Node()
    non_matching_names = (
        '',
        'projects/fake-project/locations/fake-node',
        'fake-project/locations/fake-location/nodes/fake-node',
    )
    for candidate in non_matching_names:
        node.name = candidate
        self.assertEqual(node_helper.NodeName(node), '')
def testNodeIsNotRunning(self):
    """Checks IsRunning is false outside READY/CREATING and without an IP."""
    node_helper = tpu_utils.TPUNode(self.track)
    node = self.tpu_messages.Node()
    states = self.tpu_messages.Node.StateValueValuesEnum
    skipped_states = (states.READY, states.CREATING)

    for state in states:
        if state in skipped_states:
            continue
        node.state = state
        self.assertFalse(node_helper.IsRunning(node))

    # A CREATING node with an empty IP address is also expected to be
    # reported as not running.
    node.state = states.CREATING
    node.ipAddress = ''
    self.assertFalse(node_helper.IsRunning(node))
def testStableTensorflowVersionListerMultipleElements(self):
    """Checks the stable-version lookup against several version listings.

    Each case pairs the listing handed to ExpectTensorflowVersionList with
    the version LatestStableTensorflowVersion is expected to return for it.
    """
    node_helper = tpu_utils.TPUNode(self.track)
    cases = [
        ([['fake-tf-1.5', '1.5'], ['fake-tf-1.6', '1.6']], '1.6'),
        # Invert order.
        ([['fake-tf-1.6', '1.6'], ['fake-tf-1.5', '1.5']], '1.6'),
        ([['fake-tf-1.6', '1.6'],
          ['fake-tf-1.5', '1.5'],
          ['fake-tf-1.7', '1.7']], '1.7'),
        ([['fake-tf-1.6', '1.6'],
          ['fake-tf-nightly', 'nightly']], '1.6'),
        ([['fake-tf-1.7-RC0', '1.7-RC0'],
          ['fake-tf-1.7', '1.7'],
          ['fake-tf-nightly', 'nightly']], '1.7'),
        ([['fake-tf-nightly-20180201', 'nightly-20180201'],
          ['fake-tf-nightly', 'nightly'],
          ['fake-tf-1.7-RC0', '1.7-RC0'],
          ['fake-tf-1.6', '1.6'],
          ['fake-tf-1.5', '1.5']], '1.6'),
    ]
    for listing, expected in cases:
        self.ExpectTensorflowVersionList(
            'fake-project', 'central2-a', listing)
        stable_tf_version = node_helper.LatestStableTensorflowVersion(
            'central2-a')
        self.assertEqual(stable_tf_version, expected)
def testStableTensorflowVersionNotFound(self):
    """Checks the lookup raises when only nightly versions are listed."""
    node_helper = tpu_utils.TPUNode(self.track)
    nightly_only_listings = (
        [['fake-tf-nightly-20180201', 'nightly-20180201']],
        [['fake-tf-nightly', 'nightly']],
    )
    for listing in nightly_only_listings:
        self.ExpectTensorflowVersionList(
            'fake-project', 'central2-a', listing)
        with self.assertRaisesRegex(HttpNotFoundError,
                                    'No stable release found'):
            node_helper.LatestStableTensorflowVersion('central2-a')
def Run(self, args):
    """Deletes the execution group's VM (unless tpu_only) and TPU node.

    Both delete requests are issued before either operation is waited on.
    An HttpNotFoundError at any step is reported on the status log and that
    step is skipped.

    Args:
      args: Parsed command arguments; `execution_group_name`, `zone` and
        `tpu_only` are read here.

    Returns:
      A list with the result of each delete operation that was waited on.
    """
    tpu_utils.DefaultArgs.ValidateZone(args)

    vm_missing = 'Instance:{} not found, possibly already deleted.'
    tpu_missing = 'TPU Node:{} not found, possibly already deleted.'
    responses = []
    vm_delete_op = None
    tpu_delete_op = None

    if not args.tpu_only:
        instance = tpu_utils.Instance(self.ReleaseTrack())
        try:
            vm_delete_op = instance.Delete(
                args.execution_group_name, args.zone)
        except HttpNotFoundError:
            log.status.Print(vm_missing.format(args.execution_group_name))

    tpu = tpu_utils.TPUNode(self.ReleaseTrack())
    try:
        tpu_delete_op = tpu.Delete(args.execution_group_name, args.zone)
    except HttpNotFoundError:
        log.status.Print(tpu_missing.format(args.execution_group_name))

    # `instance` is only bound when the VM delete was attempted, which is
    # also the only way vm_delete_op can be set.
    if vm_delete_op:
        try:
            responses.append(
                instance.WaitForOperationNoResources(
                    vm_delete_op, 'Deleting GCE VM'))
        except HttpNotFoundError:
            log.status.Print(vm_missing.format(args.execution_group_name))

    if tpu_delete_op:
        try:
            responses.append(
                tpu.WaitForOperationNoResources(
                    tpu_delete_op, 'Deleting TPU node'))
        except HttpNotFoundError:
            log.status.Print(tpu_missing.format(args.execution_group_name))

    return responses
def Run(self, args):
    """Lists execution groups in a zone with a status for each.

    Every Compute Engine instance returned for the zone gets one row:
    'Paused' when no TPU node shares its name, 'Running' when both the
    instance and the matching node are reported as running, and
    'Unknown Status' otherwise.

    Args:
      args: Parsed command arguments; only `zone` is read here.

    Returns:
      The ListResult rows, passed through sorted().
    """
    instance_helper = tpu_utils.Instance(self.ReleaseTrack())
    instances = {vm.name: vm for vm in instance_helper.List(args.zone)}

    node_helper = tpu_utils.TPUNode(self.ReleaseTrack())
    nodes = {
        node_helper.NodeName(tpu_node): tpu_node
        for tpu_node in node_helper.List(args.zone)
    }

    responses = []
    for name, instance in instances.items():
        if name not in nodes:
            status = 'Paused'
        elif (instance_helper.IsRunning(instance) and
              node_helper.IsRunning(nodes[name])):
            status = 'Running'
        else:
            status = 'Unknown Status'
        responses.append(ListResult(name, status))
    return sorted(responses)
def Run(self, args):
    """Creates the TPU node (unless vm_only) and starts the VM.

    Both requests are issued first; the resulting operations are waited on
    afterwards, VM first. A conflict on TPU creation or a missing VM is
    reported on the status log and ends the command with no results.

    Args:
      args: Parsed command arguments; `execution_group_name`, `zone`,
        `vm_only`, `accelerator_type`, `tf_version`, `preemptible` and
        `network` are read here.

    Returns:
      A list with the result of each operation that was waited on.
    """
    tpu_utils.DefaultArgs.ValidateZone(args)

    responses = []
    tpu = tpu_utils.TPUNode(self.ReleaseTrack())
    tpu_create_op = None
    vm_start_op = None

    if not args.vm_only:
        try:
            tpu_create_op = tpu.Create(
                args.execution_group_name, args.accelerator_type,
                args.tf_version, args.zone, args.preemptible, args.network)
        except HttpConflictError:
            log.status.Print(
                'TPU Node with name:{} already exists, '
                'try a different name'.format(args.execution_group_name))
            return responses

    instance = tpu_utils.Instance(self.ReleaseTrack())
    try:
        vm_start_op = instance.Start(args.execution_group_name, args.zone)
    except HttpNotFoundError:
        log.status.Print(
            'Instance:{} not found, possibly deleted.'.format(
                args.execution_group_name))
        return responses

    if vm_start_op:
        responses.append(
            instance.WaitForOperation(vm_start_op, 'Starting GCE VM'))

    if tpu_create_op:
        progress_message = 'Creating TPU node:{}'.format(
            args.execution_group_name)
        responses.append(tpu.WaitForOperation(tpu_create_op, progress_message))

    return responses
def Run(self, args):
    """Creates the TPU node and/or VM of an execution group, then SSHes in.

    With `dry_run` set, only self.DryRun(args) is invoked. Otherwise a
    missing `tf_version` is filled in from the latest stable version for the
    zone, the create requests are issued (TPU first), and the operations are
    waited on. Lookup and conflict errors are reported on the error log and
    end the command with an empty result list.

    Args:
      args: Parsed command arguments. `tf_version` is written back onto
        `args` when it was not supplied.

    Returns:
      A list with the TPU operation result (unless vm_only), followed by the
      VM operation result and the SSH result (unless tpu_only).
    """
    if args.dry_run:
        self.DryRun(args)
        return []

    responses = []
    tpu = tpu_utils.TPUNode(self.ReleaseTrack())

    if not args.tf_version:
        try:
            args.tf_version = tpu.LatestStableTensorflowVersion(args.zone)
        except HttpNotFoundError:
            log.err.Print(
                'Could not find stable Tensorflow version, please '
                'set tensorflow version flag using --tf-version')
            return responses

    create_tpu = not args.vm_only
    create_vm = not args.tpu_only

    if create_tpu:
        try:
            tpu_operation_ref = tpu.Create(
                args.name, args.accelerator_type, args.tf_version,
                args.zone, args.preemptible, args.network)
        except HttpConflictError:
            log.err.Print(
                'TPU Node with name:{} already exists, '
                'try a different name'.format(args.name))
            return responses

    if create_vm:
        instance = tpu_utils.Instance(self.ReleaseTrack())
        gce_image = (
            args.gce_image or
            instance.ResolveImageFromTensorflowVersion(
                args.tf_version, 'ml-images', args.use_dl_images))
        try:
            instance_operation_ref = instance.Create(
                args.name, args.zone, args.machine_type,
                utils.BytesToGb(args.disk_size), args.preemptible_vm,
                gce_image, args.network)
        except HttpConflictError:
            err_msg = ('VM with name:{} already exists, '
                       'try a different name.').format(args.name)
            if create_tpu:
                # The TPU create request above was already accepted, so the
                # message tells the user it needs to be deleted.
                err_msg += (' TPU Node:{} creation is underway and will '
                            'need to be deleted.'.format(args.name))
            log.err.Print(err_msg)
            return responses

    if create_tpu:
        responses.append(
            tpu.WaitForOperation(
                tpu_operation_ref,
                'Creating TPU node:{}'.format(args.name)))

    if create_vm:
        instance_create_response = instance.WaitForOperation(
            instance_operation_ref, 'Creating GCE VM:{}'.format(args.name))
        responses.append(instance_create_response)

        # The SSH step uses the VM creation result, so it stays in this
        # branch.
        ssh_helper = tpu_utils.SSH(self.ReleaseTrack())
        responses.append(
            ssh_helper.SSHToInstance(args, instance_create_response))

    return responses
def testNodeNameMatch(self):
    """Checks NodeName extracts 'fake-node' from a full node resource path."""
    tpu_node_helper = tpu_utils.TPUNode(self.track)
    node = self.tpu_messages.Node()
    node.name = (
        'projects/fake-project/locations/fake-location/nodes/fake-node')

    self.assertEqual(tpu_node_helper.NodeName(node), 'fake-node')