def test_delete_orphaned_volumes(self):
    """Verify that encrypt() deletes a tagged volume that AWS left behind.

    A volume carrying the encryptor session tag is registered with the
    fake AWS service before encryption starts. After encrypt() returns,
    that volume must no longer be retrievable by id.
    """
    aws_svc, encryptor_image, guest_image = _build_aws_service()
    encrypt_ami.SLEEP_ENABLED = False

    # Register a volume that looks like it was orphaned by an earlier
    # encryptor session.
    orphan = Volume()
    orphan.id = _new_id()
    aws_svc.volumes[orphan.id] = orphan
    aws_svc.tagged_volumes.append(orphan)

    # Sanity-check the fake: before encrypt(), the volume is visible both
    # by id and by tag lookup.
    self.assertEqual(orphan, aws_svc.get_volume(orphan.id))
    tagged = aws_svc.get_volumes(
        tag_key=encrypt_ami.TAG_ENCRYPTOR_SESSION_ID,
        tag_value='123'
    )
    self.assertEqual([orphan], tagged)

    encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=DummyEncryptorService,
        image_id=guest_image.id,
        encryptor_ami=encryptor_image.id
    )

    # The orphaned volume must have been cleaned up.
    self.assertIsNone(aws_svc.get_volume(orphan.id))
def test_terminate_guest(self):
    """Test that we terminate the guest instance if an exception is raised
    while waiting for it to come up.

    The fake service's get_instance callback records the instance id and
    raises TestException. encrypt() must let that exception propagate,
    and terminate_instance() must still be called for the same instance.
    """
    self.terminate_instance_called = False
    self.instance_id = None

    class TestException(Exception):
        pass

    def get_instance_callback(instance):
        # Remember which instance we were waiting on, then fail.
        self.instance_id = instance.id
        raise TestException('Test')

    def terminate_instance_callback(instance):
        self.terminate_instance_called = True
        # The instance being terminated must be the one that failed.
        self.assertEqual(self.instance_id, instance.id)

    aws_svc, encryptor_image, guest_image = _build_aws_service()
    aws_svc.get_instance_callback = get_instance_callback
    aws_svc.terminate_instance_callback = terminate_instance_callback

    # Require the failure to surface. Swallowing it with try/except/pass
    # would let the test pass even when the error path never ran.
    with self.assertRaises(TestException):
        encrypt_ami.encrypt(
            aws_svc=aws_svc,
            enc_svc_cls=DummyEncryptorService,
            image_id=guest_image.id,
            encryptor_ami=encryptor_image.id
        )

    self.assertTrue(self.terminate_instance_called)
def test_instance_type(self):
    """Test that we launch the guest as m3.medium and the encryptor as
    c3.xlarge.

    The first run_instance() call is checked as the guest (m3.medium, not
    EBS-optimized), the second as the encryptor (c3.xlarge, EBS-optimized).
    Any further call fails the test, and exactly two calls are required.
    """
    self.call_count = 0

    def run_instance_callback(instance_type, ebs_optimized,
                              security_group_ids, subnet_id):
        self.call_count += 1
        if self.call_count == 1:
            self.assertEqual('m3.medium', instance_type)
            self.assertFalse(ebs_optimized)
        elif self.call_count == 2:
            self.assertEqual('c3.xlarge', instance_type)
            self.assertTrue(ebs_optimized)
        else:
            self.fail('Unexpected number of calls to run_instance()')

    aws_svc, encryptor_image, guest_image = _build_aws_service()
    aws_svc.run_instance_callback = run_instance_callback
    encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=DummyEncryptorService,
        image_id=guest_image.id,
        encryptor_ami=encryptor_image.id
    )

    # Without this check, the test would pass vacuously if encrypt()
    # launched fewer than two instances.
    self.assertEqual(2, self.call_count)
def test_subnet_without_security_groups(self):
    """Verify that when only a subnet is supplied, the temporary security
    group is created in that subnet's VPC.
    """
    expected_vpc = 'vpc-1'
    requested_subnet = 'subnet-1'
    self.security_group_was_created = False

    def on_create_security_group(vpc_id):
        self.security_group_was_created = True
        self.assertEqual(expected_vpc, vpc_id)

    aws_svc, encryptor_image, guest_image = _build_aws_service()
    aws_svc.create_security_group_callback = on_create_security_group

    # Expose a single subnet, belonging to expected_vpc, through the fake.
    fake_subnet = Subnet()
    fake_subnet.id = requested_subnet
    fake_subnet.vpc_id = expected_vpc
    aws_svc.subnets = {fake_subnet.id: fake_subnet}

    encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=DummyEncryptorService,
        image_id=guest_image.id,
        encryptor_ami=encryptor_image.id,
        subnet_id=requested_subnet
    )

    self.assertTrue(self.security_group_was_created)
def test_subnet_with_security_groups(self):
    """Test that the subnet and security groups are passed to the calls to
    AWSService.run_instance().

    Both launches must receive the subnet. The first launch (the
    snapshotter) gets no security groups; the second (the encryptor) gets
    the user-supplied ones. Exactly two launches are required.
    """
    self.call_count = 0

    def run_instance_callback(instance_type, ebs_optimized,
                              security_group_ids, subnet_id):
        self.call_count += 1
        self.assertEqual('subnet-1', subnet_id)
        if self.call_count == 1:
            # Snapshotter.
            self.assertIsNone(security_group_ids)
        elif self.call_count == 2:
            # Encryptor.
            self.assertEqual(['sg-1', 'sg-2'], security_group_ids)
        else:
            self.fail('Unexpected number of calls to run_instance()')

    aws_svc, encryptor_image, guest_image = _build_aws_service()
    aws_svc.run_instance_callback = run_instance_callback
    encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=DummyEncryptorService,
        image_id=guest_image.id,
        encryptor_ami=encryptor_image.id,
        subnet_id='subnet-1',
        security_group_ids=['sg-1', 'sg-2']
    )

    # Guard against a vacuous pass if fewer than two instances launched.
    self.assertEqual(2, self.call_count)
def command_encrypt_ami(values, log):
    """Run the encrypt-ami command and print the encrypted AMI id.

    :param values: the parsed encrypt-ami options; this function reads
        encryptor_ami, region, ami, encrypted_ami_name, subnet_id,
        security_group_ids and brkt_env from it
    :param log: logger used for progress messages
    :return: 0 once the encrypted AMI id has been printed
    """
    session_id = util.make_nonce()

    # Fall back to the published encryptor AMI for the region when the
    # user did not name one explicitly.
    encryptor_ami = values.encryptor_ami
    if not encryptor_ami:
        encryptor_ami = encrypt_ami.get_encryptor_ami(values.region)

    aws_svc = aws_service.AWSService(
        session_id,
        default_tags=encrypt_ami.get_default_tags(session_id, encryptor_ami)
    )
    _connect_and_validate(aws_svc, values, encryptor_ami)

    log.info('Starting encryptor session %s', aws_svc.session_id)
    encrypted_image_id = encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=encryptor_service.EncryptorService,
        image_id=values.ami,
        encryptor_ami=encryptor_ami,
        encrypted_ami_name=values.encrypted_ami_name,
        subnet_id=values.subnet_id,
        security_group_ids=values.security_group_ids,
        brkt_env=values.brkt_env
    )

    # Only the AMI id goes to stdout so callers can capture it; log
    # messages go to stderr.
    print(encrypted_image_id)
    return 0
def test_subnet_and_security_groups(self):
    """Verify that update_ami() forwards the subnet and security group ids
    to every run_instance() call it makes.
    """
    aws_svc, encryptor_image, guest_image = _build_aws_service()
    encrypt_ami.SLEEP_ENABLED = False

    # Produce an encrypted AMI to feed into update_ami().
    source_ami_id = encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=DummyEncryptorService,
        image_id=guest_image.id,
        brkt_env=None,
        encryptor_ami=encryptor_image.id
    )

    self.call_count = 0

    def check_run_instance(instance_type, ebs_optimized,
                           security_group_ids, subnet_id):
        self.call_count += 1
        self.assertEqual('subnet-1', subnet_id)
        self.assertEqual(['sg-1', 'sg-2'], security_group_ids)

    aws_svc.run_instance_callback = check_run_instance

    updated_ami_id = update_ami(
        aws_svc,
        source_ami_id,
        encryptor_image.id,
        'Test updated AMI',
        subnet_id='subnet-1',
        security_group_ids=['sg-1', 'sg-2'],
        enc_svc_class=DummyEncryptorService
    )

    # update_ami() is expected to launch exactly two instances.
    self.assertEqual(2, self.call_count)
    self.assertIsNotNone(updated_ami_id)
def test_encryption_error_console_output_not_available(self):
    """Verify that a failed encryption still raises EncryptionError when
    the fake service has no console output, and that the error then
    carries no console output file.
    """
    aws_svc, encryptor_image, guest_image = _build_aws_service()
    encrypt_ami.SLEEP_ENABLED = False
    aws_svc.console_output_text = None

    try:
        encrypt_ami.encrypt(
            aws_svc=aws_svc,
            enc_svc_cls=FailedEncryptionService,
            image_id=guest_image.id,
            encryptor_ami=encryptor_image.id
        )
    except encrypt_ami.EncryptionError as e:
        self.assertIsNone(e.console_output_file)
    else:
        self.fail('Encryption should have failed')
def test_encryption_error_console_output_available(self):
    """Test that when an encryption failure occurs, we write the console
    log to a temp file.

    The EncryptionError's console_output_file must name a file whose
    contents equal CONSOLE_OUTPUT_TEXT. The file is removed afterwards,
    even if the assertion fails.
    """
    aws_svc, encryptor_image, guest_image = _build_aws_service()
    encrypt_ami.SLEEP_ENABLED = False
    try:
        encrypt_ami.encrypt(
            aws_svc=aws_svc,
            enc_svc_cls=FailedEncryptionService,
            image_id=guest_image.id,
            encryptor_ami=encryptor_image.id
        )
    except encrypt_ami.EncryptionError as e:
        path = e.console_output_file.name
        # Register removal before asserting, so a failed comparison does
        # not leak the temp file.
        self.addCleanup(os.remove, path)
        with open(path) as f:
            content = f.read()
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(CONSOLE_OUTPUT_TEXT, content)
    else:
        self.fail('Encryption should have failed')
def test_smoke(self):
    """Smoke test: run the complete encryption flow against the fake AWS
    service and check that an encrypted AMI id comes back.
    """
    aws_svc, encryptor_image, guest_image = _build_aws_service()
    encrypt_ami.SLEEP_ENABLED = False

    new_ami_id = encrypt_ami.encrypt(
        aws_svc=aws_svc,
        image_id=guest_image.id,
        encryptor_ami=encryptor_image.id,
        enc_svc_cls=DummyEncryptorService
    )

    self.assertIsNotNone(new_ami_id)
def test_encrypted_ami_name(self):
    """Verify that a caller-supplied name ends up on the encrypted AMI."""
    aws_svc, encryptor_image, guest_image = _build_aws_service()
    encrypt_ami.SLEEP_ENABLED = False

    ami_name = 'Am I an AMI?'
    encrypted_id = encrypt_ami.encrypt(
        aws_svc=aws_svc,
        enc_svc_cls=DummyEncryptorService,
        image_id=guest_image.id,
        encryptor_ami=encryptor_image.id,
        encrypted_ami_name=ami_name
    )

    image = aws_svc.get_image(encrypted_id)
    self.assertEqual(ami_name, image.name)
def main():
    """Command-line entry point for brkt-cli.

    Parses ``sys.argv`` and configures logging for this module,
    aws_service and encryptor_service. It then validates the image name,
    region, key pair and AMIs, runs the encrypt-ami workflow, and prints
    the resulting AMI id to stdout.

    :return: process exit status -- 0 on success, 1 on a handled error
    :raises EC2ResponseError: for EC2 error codes not handled below
    """
    # main() (re)binds the module-level logger.
    global log

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-v',
        '--verbose',
        dest='verbose',
        action='store_true',
        help='Print status information to the console'
    )
    parser.add_argument(
        '--version',
        action='version',
        version='brkt-cli version %s' % VERSION
    )

    subparsers = parser.add_subparsers()
    encrypt_ami_parser = subparsers.add_parser('encrypt-ami')
    encrypt_ami_args.setup_encrypt_ami_args(encrypt_ami_parser)

    argv = sys.argv[1:]
    values = parser.parse_args(argv)
    region = values.region

    # Initialize logging. Log messages are written to stderr and are
    # prefixed with a compact timestamp, so that the user knows how long
    # each operation took.
    if values.verbose:
        log_level = logging.DEBUG
    else:
        # Boto logs auth errors and 401s at ERROR level by default.
        boto.log.setLevel(logging.FATAL)
        log_level = logging.INFO

    # Set the log level of our modules explicitly. We can't set the
    # default log level to INFO because we would see INFO messages from
    # boto and other 3rd party libraries in the command output.
    logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%H:%M:%S')
    log = logging.getLogger(__name__)
    log.setLevel(log_level)
    aws_service.log.setLevel(log_level)
    encryptor_service.log.setLevel(log_level)

    if values.encrypted_ami_name:
        try:
            aws_service.validate_image_name(values.encrypted_ami_name)
        except aws_service.ImageNameError as e:
            print(e.message, file=sys.stderr)
            return 1

    # Validate the region.
    regions = [str(r.name) for r in boto.vpc.regions()]
    if region not in regions:
        print(
            'Invalid region %s. Must be one of %s.' %
            (region, str(regions)),
            file=sys.stderr
        )
        return 1

    encryptor_ami = values.encryptor_ami
    if not encryptor_ami:
        try:
            encryptor_ami = encrypt_ami.get_encryptor_ami(region)
        except Exception:
            # Not a bare except: KeyboardInterrupt and SystemExit must
            # still propagate.
            log.exception('Failed to get encryptor AMI.')
            return 1

    session_id = util.make_nonce()
    default_tags = encrypt_ami.get_default_tags(session_id, encryptor_ami)

    try:
        # Connect to AWS.
        aws_svc = aws_service.AWSService(
            session_id, encryptor_ami, default_tags=default_tags)
        aws_svc.connect(region, key_name=values.key_name)
    except NoAuthHandlerFound:
        msg = (
            'Unable to connect to AWS. Are your AWS_ACCESS_KEY_ID and '
            'AWS_SECRET_ACCESS_KEY environment variables set?'
        )
        if values.verbose:
            log.exception(msg)
        else:
            log.error(msg)
        return 1

    try:
        if values.key_name:
            # Validate the key pair name.
            aws_svc.get_key_pair(values.key_name)

        if not values.no_validate_ami:
            error = aws_svc.validate_guest_ami(values.ami)
            if error:
                print(error, file=sys.stderr)
                return 1

            error = aws_svc.validate_encryptor_ami(encryptor_ami)
            if error:
                print(error, file=sys.stderr)
                return 1

        log.info('Starting encryptor session %s', aws_svc.session_id)

        encrypted_image_id = encrypt_ami.encrypt(
            aws_svc=aws_svc,
            enc_svc_cls=encryptor_service.EncryptorService,
            image_id=values.ami,
            encryptor_ami=encryptor_ami,
            encrypted_ami_name=values.encrypted_ami_name
        )
        # Print the AMI ID to stdout, in case the caller wants to process
        # the output. Log messages go to stderr.
        print(encrypted_image_id)
        return 0
    except EC2ResponseError as e:
        if e.error_code == 'AuthFailure':
            msg = 'Check your AWS login credentials and permissions'
            if values.verbose:
                log.exception(msg)
            else:
                log.error(msg + ': ' + e.error_message)
        elif e.error_code == 'InvalidKeyPair.NotFound':
            if values.verbose:
                log.exception(e.error_message)
            else:
                log.error(e.error_message)
        elif e.error_code == 'UnauthorizedOperation':
            if values.verbose:
                log.exception(e.error_message)
            else:
                log.error(e.error_message)
            log.error(
                'Unauthorized operation. Check the IAM policy for your '
                'AWS account.'
            )
        else:
            raise
    except util.BracketError as e:
        if values.verbose:
            log.exception(e.message)
        else:
            log.error(e.message)
    except KeyboardInterrupt:
        if values.verbose:
            log.exception('Interrupted by user')
        else:
            log.error('Interrupted by user')
    # Every handled error above falls through to a failure exit status.
    return 1
def command_encrypt_ami(values, log):
    """Validate the encrypt-ami options, connect to AWS and encrypt the AMI.

    On success the encrypted AMI id is printed to stdout. The errors
    handled below are logged through ``log``, with a traceback when
    ``values.verbose`` is set.

    :param values: the parsed encrypt-ami options; this function reads
        region, verbose, encrypted_ami_name, encryptor_ami, key_name,
        no_validate_ami and ami from it
    :param log: logger for progress and error messages
    :return: 0 on success, 1 on a handled error
    :raises EC2ResponseError: for EC2 error codes not handled below
    """
    region = values.region

    if not values.verbose:
        # Boto logs auth errors and 401s at ERROR level by default; keep
        # them out of the output unless verbose mode was requested.
        boto.log.setLevel(logging.FATAL)

    if values.encrypted_ami_name:
        try:
            aws_service.validate_image_name(values.encrypted_ami_name)
        except aws_service.ImageNameError as e:
            print(e.message, file=sys.stderr)
            return 1

    # Validate the region.
    regions = [str(r.name) for r in boto.vpc.regions()]
    if region not in regions:
        print(
            'Invalid region %s. Must be one of %s.' %
            (region, str(regions)),
            file=sys.stderr
        )
        return 1

    encryptor_ami = values.encryptor_ami
    if not encryptor_ami:
        try:
            encryptor_ami = encrypt_ami.get_encryptor_ami(region)
        except Exception:
            # Not a bare except: KeyboardInterrupt and SystemExit must
            # still propagate.
            log.exception('Failed to get encryptor AMI.')
            return 1

    session_id = util.make_nonce()
    default_tags = encrypt_ami.get_default_tags(session_id, encryptor_ami)

    try:
        # Connect to AWS.
        aws_svc = aws_service.AWSService(
            session_id, default_tags=default_tags)
        aws_svc.connect(region, key_name=values.key_name)
    except NoAuthHandlerFound:
        msg = (
            'Unable to connect to AWS. Are your AWS_ACCESS_KEY_ID and '
            'AWS_SECRET_ACCESS_KEY environment variables set?'
        )
        if values.verbose:
            log.exception(msg)
        else:
            log.error(msg)
        return 1

    try:
        if values.key_name:
            # Validate the key pair name.
            aws_svc.get_key_pair(values.key_name)

        if not values.no_validate_ami:
            error = aws_svc.validate_guest_ami(values.ami)
            if error:
                print(error, file=sys.stderr)
                return 1

            error = aws_svc.validate_encryptor_ami(encryptor_ami)
            if error:
                print(error, file=sys.stderr)
                return 1

        log.info('Starting encryptor session %s', aws_svc.session_id)

        encrypted_image_id = encrypt_ami.encrypt(
            aws_svc=aws_svc,
            enc_svc_cls=encryptor_service.EncryptorService,
            image_id=values.ami,
            encryptor_ami=encryptor_ami,
            encrypted_ami_name=values.encrypted_ami_name
        )
        # Print the AMI ID to stdout, in case the caller wants to process
        # the output. Log messages go to stderr.
        print(encrypted_image_id)
        return 0
    except EC2ResponseError as e:
        if e.error_code == 'AuthFailure':
            msg = 'Check your AWS login credentials and permissions'
            if values.verbose:
                log.exception(msg)
            else:
                log.error(msg + ': ' + e.error_message)
        elif e.error_code == 'InvalidKeyPair.NotFound':
            if values.verbose:
                log.exception(e.error_message)
            else:
                log.error(e.error_message)
        elif e.error_code == 'UnauthorizedOperation':
            if values.verbose:
                log.exception(e.error_message)
            else:
                log.error(e.error_message)
            log.error(
                'Unauthorized operation. Check the IAM policy for your '
                'AWS account.'
            )
        else:
            raise
    except util.BracketError as e:
        if values.verbose:
            log.exception(e.message)
        else:
            log.error(e.message)
    except KeyboardInterrupt:
        if values.verbose:
            log.exception('Interrupted by user')
        else:
            log.error('Interrupted by user')
    # Every handled error above falls through to a failure exit status.
    return 1