def testGetMaxWorker_ParallelAllHosts(self):
     """Checks that parallel=True yields 3 workers for 3 hosts."""
     three_hosts = [
         host_util.Host(host_util.lab_config.CreateHostConfig())
         for _ in range(3)
     ]
     parsed_args = mock.MagicMock(parallel=True)
     max_worker = host_util._GetMaxWorker(parsed_args, three_hosts)
     self.assertEqual(3, max_worker)
 def testExecutionTimeElapsed(self, start_time, end_time, current_time,
                              expected_expression, mock_time):
     """Checks Host.execution_time_elapsed against a mocked clock.

     The extra parameters are presumably supplied by a parameterization
     decorator plus a patch of the time module used by host_util (neither is
     visible here). `mock_time.time()` is pinned to `current_time` before the
     host is built; the start/end timestamps are then set directly.
     """
     mock_time.time.return_value = current_time
     host = host_util.Host(host_util.lab_config.CreateHostConfig())
     host._execution_start_time = start_time
     host._execution_end_time = end_time
     elapsed = host.execution_time_elapsed
     self.assertEqual(expected_expression, elapsed)
    def testParallelExecute_partialFailed(self):
        """Test _ParallelExecute on two hosts where the second host fails.

        The mocked function raises for 'host2'. The test expects host1 to end
        COMPLETED and host2 to end ERROR. host2's control server client must
        receive an ERRORED state-change event carrying the exception message.
        """
        hosts = [
            host_util.Host(host_config, context=self.mock_context)
            for host_config in [self.host_config1, self.host_config2]
        ]
        for host in hosts:
            host._control_server_client = mock.MagicMock()
        execution_exception = Exception('some error message.')
        # Only the call for 'host2' raises; 'host1' runs normally.
        self.mock_func_exceptions['host2'] = execution_exception
        args_dict = self.default_args.copy()
        args_dict.update(
            parallel=True,
            service_account_json_key_path='path/to/key',
        )
        args = mock.MagicMock(**args_dict)

        host_util._ParallelExecute(
            host_util._WrapFuncForSetHost(self._MockFunc), args, hosts)

        # We don't know the order of the call since it's parallel.
        self.assertSameElements(['host1', 'host2'],
                                self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.COMPLETED,
                         hosts[0].execution_state)
        self.assertEqual(host_util.HostExecutionState.ERROR,
                         hosts[1].execution_state)
        (hosts[1].control_server_client.SubmitHostUpdateStateChangedEvent.
         assert_called_with(hosts[1].config.hostname,
                            host_util.HostUpdateState.ERRORED,
                            display_message=str(execution_exception),
                            target_image='image1'))
    def testParallelExecute(self):
        """Test _ParallelExecute runs all three hosts and each one completes."""
        host_configs = [self.host_config1, self.host_config2, self.host_config3]
        hosts = [
            host_util.Host(config, context=self.mock_context)
            for config in host_configs
        ]
        for host in hosts:
            host._control_server_client = mock.MagicMock()
        # Defaults overridden with a worker count of 2 and a key path.
        args = mock.MagicMock(**dict(
            self.default_args,
            parallel=2,
            service_account_json_key_path='path/to/key',
        ))

        host_util._ParallelExecute(
            host_util._WrapFuncForSetHost(self._MockFunc), args, hosts)

        # Execution order is nondeterministic, so compare without ordering.
        self.assertSameElements(['host1', 'host2', 'host3'],
                                self.mock_func_calls.keys())
        for host in hosts:
            self.assertEqual(host_util.HostExecutionState.COMPLETED,
                             host.execution_state)
    def testSequentialExecute_exitOnError(self):
        """Test _SequentialExecute stops at the first failing host.

        The mocked function raises for 'host1', so the call raises. host1 ends
        in ERROR with an ERRORED event submitted. host2 is never run and its
        state stays UNKNOWN.
        """
        hosts = [
            host_util.Host(config, context=self.mock_context)
            for config in (self.host_config1, self.host_config2)
        ]
        for host in hosts:
            host._control_server_client = mock.MagicMock()
        self.mock_func_exceptions['host1'] = Exception()
        args = mock.MagicMock(**dict(
            self.default_args,
            parallel=False,
            exit_on_error=True,
            service_account_json_key_path='path/to/key',
        ))

        with self.assertRaises(Exception):
            wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)
            host_util._SequentialExecute(wrapped_func, args, hosts,
                                         exit_on_error=True)

        failed_host, skipped_host = hosts
        event_submitter = (
            failed_host.control_server_client.SubmitHostUpdateStateChangedEvent)
        event_submitter.assert_called_with(failed_host.config.hostname,
                                           host_util.HostUpdateState.ERRORED,
                                           target_image='image1')

        self.assertSameElements(['host1'], self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.ERROR,
                         failed_host.execution_state)
        self.assertEqual(host_util.HostExecutionState.UNKNOWN,
                         skipped_host.execution_state)
 def testGetMaxWorker(self, parallel, expected_max_worker):
     """Checks _GetMaxWorker for a given `parallel` value over 3 hosts.

     The parameters presumably come from a parameterization decorator that is
     not visible here.
     """
     hosts = [host_util.Host(host_util.lab_config.CreateHostConfig())
              for _ in range(3)]
     actual = host_util._GetMaxWorker(mock.MagicMock(parallel=parallel), hosts)
     self.assertEqual(expected_max_worker, actual)
 def testHostContext(self):
     """Test Host.context creates one context from the host's login info."""
     host = host_util.Host(self.host_config1, self.ssh_config1)
     context = host.context
     self.assertIsNotNone(context)
     config = self.host_config1
     self.mock_create_context.assert_called_once_with(
         config.hostname,
         config.host_login_name,
         ssh_config=self.ssh_config1,
         sudo_ssh_config=None)
    def testWrapFuncForSetHost_skip(self):
        """A host already in COMPLETED state is skipped by the wrapper.

        Neither the wrapped function nor the host's execution timers run.
        """
        host = host_util.Host(self.host_config1, self.ssh_config1)
        host.context = self.mock_context
        host.StartExecutionTimer, host.StopExecutionTimer = (
            mock.MagicMock(), mock.MagicMock())
        host.execution_state = host_util.HostExecutionState.COMPLETED
        args = mock.MagicMock(**self.default_args)
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)

        wrapped_func(args, host)

        self.assertEmpty(self.mock_func_calls)
        for timer in (host.StartExecutionTimer, host.StopExecutionTimer):
            timer.assert_not_called()
    def testHostContext_unknownException(self):
        """Test Host.context when context creation raises an unexpected error."""
        self.mock_create_context.side_effect = Exception('Connection timeout.')

        context = None
        with self.assertRaises(Exception):
            host = host_util.Host(self.host_config1, self.ssh_config1)
            context = host.context
        # The assignment inside the block never completed, so it stays None.
        self.assertIsNone(context)
        self.mock_create_context.assert_called_once_with(
            'host1', 'user1',
            ssh_config=self.ssh_config1, sudo_ssh_config=None)
    def testWrapFuncForSetHost(self):
        """The wrapped function runs once, marks the host COMPLETED, and is timed."""
        host = host_util.Host(self.host_config1, self.ssh_config1)
        host.context = self.mock_context
        host.StartExecutionTimer, host.StopExecutionTimer = (
            mock.MagicMock(), mock.MagicMock())
        args = mock.MagicMock(**self.default_args)
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)

        wrapped_func(args, host)

        self.assertSameElements(['host1'], self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.COMPLETED,
                         host.execution_state)
        for timer in (host.StartExecutionTimer, host.StopExecutionTimer):
            timer.assert_called_once()
    def testWrapFuncForSetHost_error(self):
        """A failing wrapped function raises and leaves the host in ERROR.

        The injected exception is recorded on `host.error`. The execution
        timers are still started and stopped exactly once.
        """
        host = host_util.Host(self.host_config1, self.ssh_config1)
        host.context = self.mock_context
        host.StartExecutionTimer, host.StopExecutionTimer = (
            mock.MagicMock(), mock.MagicMock())
        args = mock.MagicMock(**self.default_args)
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)
        injected_error = Exception('Fail to run command.')
        self.mock_func_exceptions['host1'] = injected_error

        with self.assertRaises(Exception):
            wrapped_func(args, host)

        self.assertSameElements(['host1'], self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.ERROR,
                         host.execution_state)
        self.assertEqual(injected_error, host.error)
        for timer in (host.StartExecutionTimer, host.StopExecutionTimer):
            timer.assert_called_once()
    def testHostContext_withSSHInfo(self):
        """Test Host.context forwards explicit ssh and sudo ssh configs."""

        def _NewSshConfig():
            # A fresh (distinct) SshConfig object on every call.
            return ssh_util.SshConfig(user='******',
                                      hostname='host1',
                                      password='******',
                                      ssh_key='/ssh_key')

        ssh_config = _NewSshConfig()
        sudo_ssh_config = _NewSshConfig()
        host = host_util.Host(self.host_config1,
                              ssh_config=ssh_config,
                              sudo_ssh_config=sudo_ssh_config)

        self.assertIsNotNone(host.context)
        self.mock_create_context.assert_called_once_with(
            'host1', 'user1',
            ssh_config=ssh_config, sudo_ssh_config=sudo_ssh_config)
 def setUp(self):
     """Builds five hosts whose states change at staggered patched times.

     host_util._GetCurrentTime is patched and advanced 1 -> 11 -> 21 -> 31.
     Final states: host0 COMPLETED (via Step1, Step2), host2 Step2,
     host4 ERROR, and host1/host3 remain at Step1.
     """
     super(ExecutionStatePrinterTest, self).setUp()
     # NOTE(review): these patchers are presumably stopped in a tearDown that
     # is not visible here.
     self.logger_patcher = mock.patch('__main__.host_util.logger')
     self.mock_logger = self.logger_patcher.start()
     self.now_patcher = mock.patch('__main__.host_util._GetCurrentTime')
     self.mock_now = self.now_patcher.start()

     # t=1: create the hosts.
     self.mock_now.return_value = 1
     self.hosts = [
         host_util.Host(
             host_util.lab_config.CreateHostConfig(
                 hostname='host' + str(index), host_login_name='auser'))
         for index in range(5)
     ]

     # t=11: every host starts its timer and enters Step1.
     self.mock_now.return_value = 11
     for host in self.hosts:
         host.StartExecutionTimer()
         host.execution_state = 'Step1'

     # t=21: hosts 0 and 2 advance to Step2.
     self.mock_now.return_value = 21
     for index in (0, 2):
         self.hosts[index].execution_state = 'Step2'

     # t=31: host 0 completes and host 4 fails.
     self.mock_now.return_value = 31
     self.hosts[0].execution_state = host_util.HostExecutionState.COMPLETED
     self.hosts[4].execution_state = host_util.HostExecutionState.ERROR
     self.printer = host_util.ExecutionStatePrinter(self.hosts)