예제 #1
0
    def _stream_docker_logs(self):
        """Follow the logs of the most recently created container on the remote host.

        After a fixed 5 second pause, asks Docker (over ``self.ssh``) for the ID
        of the latest container via ``docker ps -l -q`` and then streams that
        container's output with ``docker logs --follow``, logged at ``info`` level.
        """
        # fixed pause before querying docker (presumably so a just-launched
        # container shows up as the latest one)
        time.sleep(5)

        latest_container_query = '{} docker ps -l -q'.format(self.ssh)
        container_id = run_cmd(latest_container_query, logger=self.logger,
                               return_output=True)

        # NOTE(review): container_id is interpolated verbatim; this assumes
        # run_cmd returns it without a trailing newline — confirm.
        follow_logs = '{} docker logs {} --follow'.format(self.ssh, container_id)
        run_cmd(follow_logs, logger=self.logger, level='info')
예제 #2
0
    def _set_remote_dirs(self):
        """Create the remote work dir and derive the remote image/job dir paths.

        Runs ``mkdir -p`` on ``self.remote_workdir`` over ``self.ssh``, then sets
        ``self.remote_image_dir`` / ``self.remote_job_dir`` to the work dir joined
        with the final path component of ``self.image_dir`` / ``self.job_dir``.
        Expects path objects exposing ``.joinpath`` and ``.name`` (pathlib-style).
        """
        workdir = self.remote_workdir
        run_cmd('{} mkdir -p {}'.format(self.ssh, workdir), logger=self.logger)

        image_leaf = self.image_dir.name
        self.remote_image_dir = workdir.joinpath(image_leaf)

        job_leaf = self.job_dir.name
        self.remote_job_dir = workdir.joinpath(job_leaf)
예제 #3
0
    def _sync_s3_local(self):
        """Pull the job directory from S3 into the local job directory.

        Calls ``self._set_s3_dirs()`` first, then runs ``aws s3 sync`` (excluding
        ``logs``) from ``self.s3_job_dir`` to ``self.job_dir``, logging at ``info``
        level. The sync is skipped when ``self.job_dir`` is itself an ``s3://``
        location.
        """
        self.logger.info('Syncing files s3 <> local...')
        self._set_s3_dirs()

        # only sync if job dir is a local dir. str() is needed because job_dir may
        # be a pathlib.Path, and ``'s3://' in Path(...)`` raises TypeError.
        if 's3://' not in str(self.job_dir):
            cmd = 'aws s3 sync --exclude logs --quiet {} {}'.format(
                self.s3_job_dir, self.job_dir)
            run_cmd(cmd, logger=self.logger, level='info')
예제 #4
0
    def _set_remote_dirs(self):
        """Create the remote work dir and derive the remote image/job dir paths.

        Runs ``mkdir -p`` on ``self.remote_workdir`` over ``self.ssh``, then sets
        ``self.remote_image_dir`` / ``self.remote_job_dir`` to the work dir joined
        with the last path component of ``self.image_dir`` / ``self.job_dir``.
        """
        cmd = '{} mkdir -p {}'.format(self.ssh, self.remote_workdir)
        run_cmd(cmd, logger=self.logger)

        # normpath first: basename('data/images/') is '' (which would make the
        # remote dir the work dir itself), and basename('a/b/..') is '..'
        # (which would escape the work dir).
        self.remote_image_dir = os.path.join(
            self.remote_workdir, os.path.basename(os.path.normpath(self.image_dir)))
        self.remote_job_dir = os.path.join(
            self.remote_workdir, os.path.basename(os.path.normpath(self.job_dir)))
예제 #5
0
    def _launch_train_container(self, **kwargs):
        """Start the training container detached (``docker run -d``) on the remote host.

        Each keyword argument whose value is not ``None`` is passed as a
        ``-e KEY=VALUE`` environment flag. ``self.remote_image_dir`` and
        ``self.remote_job_dir`` are bind-mounted at ``$WORKDIR/image_dir`` and
        ``$WORKDIR/job_dir``.
        """
        self.logger.info('Launching training container...')

        # training parameters reach the container as environment variables;
        # parameters set to None are omitted entirely
        env_flags = []
        for param, setting in kwargs.items():
            if setting is None:
                continue
            env_flags.append('-e {}={}'.format(param, setting))
        env_section = ' '.join(env_flags)

        run_template = (
            '{} docker run -d -v {}:$WORKDIR/image_dir -v {}:$WORKDIR/job_dir {} '
            'idealo/tensorflow-image-atm:1.13.1'
        )
        docker_run = run_template.format(
            self.ssh, self.remote_image_dir, self.remote_job_dir, env_section)

        run_cmd(docker_run, logger=self.logger)
예제 #6
0
    def _sync_s3_remote(self):
        """Sync the S3 image and job dirs onto the remote host.

        Calls ``self._set_s3_dirs()`` and ``self._set_remote_dirs()``, then runs
        ``aws s3 sync`` (excluding ``logs``) on the remote host through
        ``self.ssh``: image dir first, then job dir.
        """
        self.logger.info('Syncing files s3 <> remote...')
        self._set_s3_dirs()
        self._set_remote_dirs()

        transfers = (
            (self.s3_image_dir, self.remote_image_dir),
            (self.s3_job_dir, self.remote_job_dir),
        )
        for source, target in transfers:
            remote_sync = '{} aws s3 sync --exclude logs --quiet {} {}'.format(
                self.ssh, source, target)
            run_cmd(remote_sync, logger=self.logger)
예제 #7
0
    def test_run_cmd_2(self, mocker):
        """Running ``echo`` at ``info`` level logs once via info and never via debug."""
        debug_spy = mocker.patch('logging.Logger.debug')
        info_spy = mocker.patch('logging.Logger.info')

        run_cmd(
            'echo Hello world',
            logger=logging.Logger(__name__),
            level='info',
            return_output=False,
        )

        debug_spy.assert_not_called()
        info_spy.assert_called_once()
예제 #8
0
    def _sync_local_s3(self):
        """Push the image and job dirs up to their S3 counterparts.

        Calls ``self._set_s3_dirs()`` first. ``self.image_dir`` then
        ``self.job_dir`` are each synced (excluding ``logs``) to
        ``self.s3_image_dir`` / ``self.s3_job_dir``, but only when the source is
        not already an ``s3://`` location.
        """
        self.logger.info('Syncing files local <> s3...')
        self._set_s3_dirs()

        # (source attribute, destination attribute); attributes are read lazily
        # so the S3 destination is only looked up when a sync actually runs
        pairs = (('image_dir', 's3_image_dir'), ('job_dir', 's3_job_dir'))
        for src_attr, dst_attr in pairs:
            source_dir = getattr(self, src_attr)
            # only local dirs are pushed; anything already on s3 is skipped
            if 's3://' in str(source_dir):
                continue
            upload = 'aws s3 sync --quiet --exclude logs {} {}'.format(
                source_dir, getattr(self, dst_attr))
            run_cmd(upload, logger=self.logger)
예제 #9
0
    def destroy(self):
        """Runs Terraform destroy.

        Executes ``terraform destroy -auto-approve`` from ``self.tf_dir``, passing
        ``region``, ``instance_type``, ``vpc_id``, ``s3_bucket`` (from
        ``self.s3_bucket_wo``) and ``name`` (from ``self.cloud_tag``) as ``-var``
        values.
        """
        self.logger.info('Running terraform destroy...')

        tf_vars = [
            ('region', self.region),
            ('instance_type', self.instance_type),
            ('vpc_id', self.vpc_id),
            ('s3_bucket', self.s3_bucket_wo),
            ('name', self.cloud_tag),
        ]
        var_flags = ' '.join('-var "{}={}"'.format(key, val) for key, val in tf_vars)
        destroy_cmd = 'cd {} && terraform destroy -auto-approve {}'.format(
            self.tf_dir, var_flags)

        run_cmd(destroy_cmd, logger=self.logger)
예제 #10
0
    def test_run_cmd_4(self, mocker):
        """A failing command raises and is reported via ``error`` only.

        ``echo2`` is presumably not an installed command, so ``run_cmd`` is
        expected to raise. With ``level='debug'``, neither ``debug`` nor ``info``
        may be logged, and ``error`` must be called exactly once.
        """
        mp_debug = mocker.patch('logging.Logger.debug')
        mp_info = mocker.patch('logging.Logger.info')
        mp_error = mocker.patch('logging.Logger.error')

        cmd = 'echo2 Hello world'
        logger = logging.Logger(__name__)
        level = 'debug'
        return_output = False

        # only the fact that it raises matters; the exception itself is not inspected
        with pytest.raises(Exception):
            run_cmd(cmd, logger, level, return_output)

        mp_debug.assert_not_called()
        mp_info.assert_not_called()
        mp_error.assert_called_once()
예제 #11
0
 def _set_ssh(self):
     """Build the ssh command prefix for the Terraform-provisioned instance.

     Reads the ``public_ip`` Terraform output from ``self.tf_dir`` and stores an
     ``ssh ... ec2-user@<ip>`` command prefix in ``self.ssh``.
     """
     ip_query = 'cd {} && terraform output public_ip'.format(self.tf_dir)
     public_ip = run_cmd(ip_query, logger=self.logger, return_output=True)
     # NOTE(review): host key checking is disabled and ~/.ssh/id_rsa is the key
     # used; public_ip is interpolated verbatim (assumes run_cmd returns it
     # without quotes or a trailing newline — confirm).
     self.ssh = ('ssh -o StrictHostKeyChecking=no -i ~/.ssh/id_rsa '
                 'ec2-user@{}').format(public_ip)
예제 #12
0
 def init(self):
     """Runs Terraform initialization.

     Logs a progress message, then runs ``terraform init`` from ``self.tf_dir``.
     """
     self.logger.info('Running terraform init...')
     run_cmd('cd {} && terraform init'.format(self.tf_dir), logger=self.logger)