def testParallelExecute_partialFailed(self):
        """Test ParallelExecute on multiple hosts parallel with some host failed.

        The mock function is configured to raise for 'host2' only. After the
        parallel run both hosts must have been called, the first host must be
        COMPLETED, the second must be ERROR, and the second host's control
        server client must have received an ERRORED state-changed event whose
        display message is the exception text.
        """
        hosts = []
        for config in (self.host_config1, self.host_config2):
            new_host = host_util.Host(config, context=self.mock_context)
            # Swap in a mock client so the submitted events can be inspected.
            new_host._control_server_client = mock.MagicMock()
            hosts.append(new_host)
        failure = Exception('some error message.')
        self.mock_func_exceptions['host2'] = failure
        args = mock.MagicMock(**dict(
            self.default_args,
            parallel=True,
            service_account_json_key_path='path/to/key',
        ))
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)

        host_util._ParallelExecute(wrapped_func, args, hosts)

        # Parallel execution gives no ordering guarantee, so compare the
        # called hosts as an unordered collection.
        self.assertSameElements(['host1', 'host2'],
                                self.mock_func_calls.keys())
        succeeded_host, failed_host = hosts
        self.assertEqual(host_util.HostExecutionState.COMPLETED,
                         succeeded_host.execution_state)
        self.assertEqual(host_util.HostExecutionState.ERROR,
                         failed_host.execution_state)
        submit_event = (
            failed_host.control_server_client.SubmitHostUpdateStateChangedEvent)
        submit_event.assert_called_with(
            failed_host.config.hostname,
            host_util.HostUpdateState.ERRORED,
            display_message=str(failure),
            target_image='image1')
    def testParallelExecute(self):
        """Test ParallelExecute on multiple hosts parallel.

        Runs the wrapped mock function over three hosts with parallel=2 and
        checks that every host was called and ended in the COMPLETED state.
        """
        hosts = []
        for config in (self.host_config1,
                       self.host_config2,
                       self.host_config3):
            new_host = host_util.Host(config, context=self.mock_context)
            # Replace the control server client with a mock.
            new_host._control_server_client = mock.MagicMock()
            hosts.append(new_host)
        args = mock.MagicMock(**dict(
            self.default_args,
            parallel=2,
            service_account_json_key_path='path/to/key',
        ))
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)

        host_util._ParallelExecute(wrapped_func, args, hosts)

        # Parallel execution gives no ordering guarantee, so compare the
        # called hosts as an unordered collection.
        self.assertSameElements(['host1', 'host2', 'host3'],
                                self.mock_func_calls.keys())
        for finished_host in hosts:
            self.assertEqual(host_util.HostExecutionState.COMPLETED,
                             finished_host.execution_state)
    def testSequentialExecute_exitOnError(self):
        """Test _SequentialExecute multiple hosts sequentially and failed.

        The mock function raises for 'host1'. With exit_on_error=True the
        sequential run must raise, an ERRORED event must be submitted for the
        first host, and the second host must never be called, leaving its
        execution state UNKNOWN.
        """
        hosts = []
        for config in (self.host_config1, self.host_config2):
            new_host = host_util.Host(config, context=self.mock_context)
            # Swap in a mock client so the submitted events can be inspected.
            new_host._control_server_client = mock.MagicMock()
            hosts.append(new_host)
        self.mock_func_exceptions['host1'] = Exception()
        args = mock.MagicMock(**dict(
            self.default_args,
            parallel=False,
            exit_on_error=True,
            service_account_json_key_path='path/to/key',
        ))

        with self.assertRaises(Exception):
            wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)
            host_util._SequentialExecute(
                wrapped_func, args, hosts, exit_on_error=True)

        failed_host, untouched_host = hosts
        submit_event = (
            failed_host.control_server_client.SubmitHostUpdateStateChangedEvent)
        submit_event.assert_called_with(
            failed_host.config.hostname,
            host_util.HostUpdateState.ERRORED,
            target_image='image1')

        # Only the first host may have been attempted before the run aborted.
        self.assertSameElements(['host1'], self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.ERROR,
                         failed_host.execution_state)
        self.assertEqual(host_util.HostExecutionState.UNKNOWN,
                         untouched_host.execution_state)
    def testWrapFuncForSetHost_skip(self):
        """Test the wrapped function skips a host that is already COMPLETED.

        The mock function must record no call and neither execution timer
        method may be invoked.
        """
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)
        args = mock.MagicMock(**self.default_args)
        done_host = host_util.Host(self.host_config1, self.ssh_config1)
        done_host.context = self.mock_context
        # Mock the timer methods so their (non-)invocation can be asserted.
        done_host.StartExecutionTimer = mock.MagicMock()
        done_host.StopExecutionTimer = mock.MagicMock()
        # Mark the host as finished before the wrapped function sees it.
        done_host.execution_state = host_util.HostExecutionState.COMPLETED

        wrapped_func(args, done_host)

        self.assertEmpty(self.mock_func_calls)
        done_host.StartExecutionTimer.assert_not_called()
        done_host.StopExecutionTimer.assert_not_called()
    def testWrapFuncForSetHost(self):
        """Test the wrapped function runs on a host and marks it COMPLETED.

        The mock function must be called for 'host1' and the start and stop
        execution timer methods must each be invoked exactly once.
        """
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)
        args = mock.MagicMock(**self.default_args)
        target_host = host_util.Host(self.host_config1, self.ssh_config1)
        target_host.context = self.mock_context
        # Mock the timer methods so their invocation can be asserted.
        target_host.StartExecutionTimer = mock.MagicMock()
        target_host.StopExecutionTimer = mock.MagicMock()

        wrapped_func(args, target_host)

        self.assertSameElements(['host1'], self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.COMPLETED,
                         target_host.execution_state)
        target_host.StartExecutionTimer.assert_called_once()
        target_host.StopExecutionTimer.assert_called_once()
    def testWrapFuncForSetHost_error(self):
        """Test the wrapped function records a failure raised for the host.

        The mock function raises for 'host1'. The wrapped call must raise,
        the host must end in the ERROR state with the exception stored on
        ``host.error``, and both timer methods must be invoked exactly once.
        """
        failure = Exception('Fail to run command.')
        self.mock_func_exceptions['host1'] = failure
        failing_host = host_util.Host(self.host_config1, self.ssh_config1)
        failing_host.context = self.mock_context
        # Mock the timer methods so their invocation can be asserted.
        failing_host.StartExecutionTimer = mock.MagicMock()
        failing_host.StopExecutionTimer = mock.MagicMock()
        args = mock.MagicMock(**self.default_args)
        wrapped_func = host_util._WrapFuncForSetHost(self._MockFunc)

        with self.assertRaises(Exception):
            wrapped_func(args, failing_host)

        self.assertSameElements(['host1'], self.mock_func_calls.keys())
        self.assertEqual(host_util.HostExecutionState.ERROR,
                         failing_host.execution_state)
        self.assertEqual(failure, failing_host.error)
        failing_host.StartExecutionTimer.assert_called_once()
        failing_host.StopExecutionTimer.assert_called_once()