Example #1
    def test_setup_overrides_default_username(self):
        """
        Testing setup_overrides fails if user doesn't change username
        from bootstrap.example.yml default
        """
        real_configdict = ConfigDict("bootstrap")
        real_configdict.load()

        real_configdict.raw["bootstrap"]["overrides"] = {
            "infrastructure_provisioning": {
                "tfvars": {
                    "tags": {
                        "owner": "your.username"
                    }
                }
            }
        }
        test_override_path = whereami.dsi_repo_path("dsi", "tests")
        with self.assertRaises(AssertionError):
            bootstrap.setup_overrides(real_configdict, test_override_path)

        # Removing created file
        try:
            os.remove(os.path.join(test_override_path, "overrides.yml"))
        except OSError:
            pass

    def test_build_hosts_file(self):
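        # Each expected line is "ip<TAB>aliases": every node gets its numbered aliases and a
        # *.dsitest.dev name, and only the first node in each category (mongod, mongos,
        # configsvr, workload_client) also gets the bare short alias (md, ms, cs, wc).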
        expected = [
            "10.2.1.1\tmd md0 mongod0 mongod0.dsitest.dev",
            "10.2.1.2\tmd1 mongod1 mongod1.dsitest.dev",
            "10.2.1.3\tmd2 mongod2 mongod2.dsitest.dev",
            "10.2.1.4\tmd3 mongod3 mongod3.dsitest.dev",
            "10.2.1.5\tmd4 mongod4 mongod4.dsitest.dev",
            "10.2.1.6\tmd5 mongod5 mongod5.dsitest.dev",
            "10.2.1.7\tmd6 mongod6 mongod6.dsitest.dev",
            "10.2.1.8\tmd7 mongod7 mongod7.dsitest.dev",
            "10.2.1.9\tmd8 mongod8 mongod8.dsitest.dev",
            "10.2.1.100\tms ms0 mongos0 mongos0.dsitest.dev",
            "10.2.1.101\tms1 mongos1 mongos1.dsitest.dev",
            "10.2.1.102\tms2 mongos2 mongos2.dsitest.dev",
            "10.2.1.51\tcs cs0 configsvr0 configsvr0.dsitest.dev",
            "10.2.1.52\tcs1 configsvr1 configsvr1.dsitest.dev",
            "10.2.1.53\tcs2 configsvr2 configsvr2.dsitest.dev",
            "10.2.1.10\twc wc0 workload_client0 workload_client0.dsitest.dev",
        ]

        real_config_dict = ConfigDict(
            "infrastructure_provisioning",
            whereami.dsi_repo_path("docs", "config-specs"))
        real_config_dict.load()
        real_config_dict.save = MagicMock(name="save")

        provisioner = ip.Provisioner(real_config_dict,
                                     provisioning_file=self.provision_log_path)
        hosts_contents = provisioner._build_hosts_file()
        self.assertEqual(expected, hosts_contents)
Example #3
    def test_setup_overrides_file_exists_empty_config(self):
        """
        Testing setup_overrides, path = True and config vals not given
        """
        real_configdict = ConfigDict("bootstrap")
        real_configdict.load()

        test_override_path = whereami.dsi_repo_path("dsi", "tests")
        test_override_str = yaml.dump({}, default_flow_style=False)

        # Creating 'overrides.yml' in current dir
        with open(os.path.join(test_override_path, "overrides.yml"),
                  "w+") as test_override_file:
            test_override_file.write(test_override_str)

        # Call to setup_overrides updates 'overrides.yml' in current dir
        bootstrap.setup_overrides(real_configdict, test_override_path)

        test_override_dict = {}
        with open(os.path.join(test_override_path, "overrides.yml"),
                  "r") as test_override_file:
            test_override_dict = yaml.safe_load(test_override_file)

        self.assertEqual(test_override_dict, {})

        # Removing created file
        os.remove(os.path.join(test_override_path, "overrides.yml"))
Example #4
    def test_setup_overrides_no_file_config_vals(self):
        """
        Testing setup_overrides where path = False and config vals given
        """
        real_configdict = ConfigDict("bootstrap")
        real_configdict.load()

        master_overrides = {}
        master_overrides.update({
            "infrastructure_provisioning": {
                "tfvars": {
                    "ssh_key_file": "test_ssh_key_file.pem",
                    "ssh_key_name": "test_ssh_key_name",
                    "tags": {
                        "owner": "testuser",
                        "expire-on-delta": 24
                    },
                }
            }
        })
        real_configdict.raw["bootstrap"] = {}
        real_configdict.raw["bootstrap"]["overrides"] = master_overrides

        test_override_path = whereami.dsi_repo_path("dsi", "tests")
        test_override_dict = {}

        # Call to setup_overrides creates 'overrides.yml' in current dir
        bootstrap.setup_overrides(real_configdict, test_override_path)
        with open(os.path.join(test_override_path, "overrides.yml"),
                  "r") as test_override_file:
            test_override_dict = yaml.safe_load(test_override_file)
        self.assertEqual(test_override_dict, master_overrides)

        # Removing created file
        os.remove(os.path.join(test_override_path, "overrides.yml"))
Example #5
    def test_variable_reference_contains_invalid_id(self):
        """
        Variable references cannot evaluate to blocks containing duplicate ids.
        """
        with in_dir(FIXTURE_FILES.fixture_file_path("nested-invalid-ids")):
            with self.assertRaises(config.InvalidConfigurationException):
                conf = ConfigDict("mongodb_setup")
                conf.load()
Example #6
    def test_find_nested_config_dicts(self):
        """
        We check for duplicate ids in lists of lists correctly.
        """
        with in_dir(FIXTURE_FILES.fixture_file_path("invalid-ids-in-lists")):
            with self.assertRaises(config.InvalidConfigurationException):
                conf = ConfigDict("mongodb_setup")
                conf.load()
Example #7
def main():
    """ Main function """
    args = parse_command_line()
    setup_logging(args.debug, args.log_file)
    config = ConfigDict("infrastructure_provisioning")
    config.load()
    provisioner = Provisioner(config, verbose=args.debug)
    provisioner.provision_resources()
Example #8
    def test_setup_overrides_file_exists_config_vals(self):
        """
        Testing setup_overrides where path = True and config vals given
        """
        real_configdict = ConfigDict("bootstrap")
        real_configdict.load()

        test_override_path = whereami.dsi_repo_path("dsi", "tests")
        master_overrides = {}
        master_overrides.update({
            "infrastructure_provisioning": {
                "tfvars": {
                    "ssh_key_file": "test_ssh_key_file1.pem",
                    "ssh_key_name": "test_ssh_key_name1",
                    "tags": {
                        "owner": "testuser1",
                        "expire-on-delta": 24
                    },
                }
            }
        })
        real_configdict.raw["bootstrap"]["overrides"] = master_overrides
        test_override_str = yaml.dump(
            {
                "infrastructure_provisioning": {
                    "tfvars": {
                        "ssh_key_file": "test_ssh_key_file2.pem",
                        "ssh_key_name": "test_ssh_key_name2",
                        "tags": {
                            "owner": "testuser2",
                            "expire-on-delta": 48
                        },
                    }
                }
            },
            default_flow_style=False,
        )

        # Creating 'overrides.yml' in current dir
        with open(os.path.join(test_override_path, "overrides.yml"),
                  "w") as test_override_file:
            test_override_file.write(test_override_str)

        # Call to setup_overrides updates 'overrides.yml' in current dir
        bootstrap.setup_overrides(real_configdict, test_override_path)

        test_override_dict = {}
        with open(os.path.join(test_override_path, "overrides.yml"),
                  "r") as test_override_file:
            test_override_dict = yaml.safe_load(test_override_file)

        self.assertEqual(test_override_dict, master_overrides)

        # Removing created file
        os.remove(os.path.join(test_override_path, "overrides.yml"))
Example #9
    def test_load_new(self):
        """Test loading ConfigDict with new naming convention .yml files"""
        test_conf = ConfigDict(
            "bootstrap",
            whereami.dsi_repo_path("dsi", "tests", "test_config_files",
                                   "new_format"))
        test_conf.load()
        self.assertFalse("cluster_type" in test_conf.raw["bootstrap"])
        self.assertTrue(
            "infrastructure_provisioning" in test_conf.raw["bootstrap"])
        self.assertFalse("cluster_type" in test_conf.defaults["bootstrap"])
        self.assertTrue(
            "infrastructure_provisioning" in test_conf.defaults["bootstrap"])
Example #10
def load_config_dict(module):
    """
    Load ConfigDict for the given module with id checks mocked out.

    :param str module: Name of module for ConfigDict.
    """
    # pylint: disable=import-outside-toplevel
    from dsi.common.config import ConfigDict

    with patch("dsi.common.config.ConfigDict.assert_valid_ids") as mock_assert_valid_ids:
        conf = ConfigDict(module)
        conf.load()
        mock_assert_valid_ids.assert_called_once()
        return conf
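
A minimal usage sketch (hypothetical test, not taken from the examples above): because
assert_valid_ids is patched out inside load_config_dict, the returned ConfigDict skips id
validation but is otherwise loaded normally.

    def test_bootstrap_defaults():
        conf = load_config_dict("bootstrap")
        # .module is set by the ConfigDict constructor (see ConfigDictTestCase.setUp below)
        assert conf.module == "bootstrap"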
Example #11
    def test_setup_overrides_no_file_empty_config(self):
        """
        Testing setup_overrides, path = False and config vals not given
        """
        real_configdict = ConfigDict("bootstrap")
        real_configdict.load()

        test_override_path = whereami.dsi_repo_path("dsi", "tests")

        # Call to setup_overrides creates 'overrides.yml' in current dir
        bootstrap.setup_overrides(real_configdict, test_override_path)

        self.assertFalse(
            os.path.exists(os.path.join(test_override_path, "overrides.yml")))
Example #12
def main(argv=sys.argv[1:]):
    """
    Parse args and call workload_setup.yml operations
    """
    parser = argparse.ArgumentParser(description="Workload Setup")

    parser.add_argument("-d", "--debug", action="store_true", help="enable debug output")
    parser.add_argument("--log-file", help="path to log file")

    args = parser.parse_args(argv)
    setup_logging(args.debug, args.log_file)

    config = ConfigDict("workload_setup")
    config.load()

    setup = WorkloadSetupRunner(config)
    setup.setup_workloads()
Example #13
    def test_validate_delays(self):
        blacklisted_configs = ["mongodb_setup.atlas.yml"]  # Some configs don't have topologies.
        directory = whereami.dsi_repo_path("configurations", "mongodb_setup")
        errors = []

        # There are a few files that aren't configuration files.
        names = [name for name in os.listdir(directory) if name.startswith("mongodb_setup")]

        # We need references to provisioning output
        infrastructure_provisioning = whereami.dsi_repo_path(
            "docs/config-specs/infrastructure_provisioning.out.yml"
        )
        with copied_file(infrastructure_provisioning, "infrastructure_provisioning.out.yml"):
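            # copied_file appears to place a copy of the source file in the current working
            # directory under the given name, so ConfigDict("mongodb_setup") can resolve the
            # provisioning output referenced by variable references, and cleans it up on exit.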

            for conf_name in names:
                if conf_name in blacklisted_configs:
                    continue
                with copied_file(os.path.join(directory, conf_name), "mongodb_setup.yml"):
                    config = ConfigDict("mongodb_setup")
                    config.load()

                    topologies = config["mongodb_setup"]["topology"]
                    network_delays = config["mongodb_setup"]["network_delays"]
                    delay_configs = network_delays["clusters"]

                    try:
                        # The DelayNodes throw exceptions when given bad delays, so we
                        # can validate the configuration by simply constructing a DelayGraph
                        # pylint: disable=unused-variable
                        DelayGraph.client_ip = config["infrastructure_provisioning"]["out"][
                            "workload_client"
                        ][0]["private_ip"]

                        version_flag = str_to_version_flag(
                            network_delays.get("version_flag", "default")
                        )
                        DelayGraph.client_node = DelayNode(version_flag)
                        delays = DelayGraph.from_topologies(topologies, delay_configs, version_flag)
                    # pylint: disable=broad-except
                    except Exception as e:
                        errors.append(e)

        # Reset the delay graph's client variables.
        DelayGraph.client_node = DelayNode()
        DelayGraph.client_ip = "workload_client"
        self.assertEqual(errors, [])

    def test_setup_hostnames(self, mock_exec_command, mock_create_file,
                             mock_ssh):
        _ = mock_ssh
        real_config_dict = ConfigDict(
            "infrastructure_provisioning",
            whereami.dsi_repo_path("docs", "config-specs"))
        real_config_dict.load()
        real_config_dict.save = MagicMock(name="save")

        provisioner = ip.Provisioner(real_config_dict,
                                     provisioning_file=self.provision_log_path)
        provisioner.setup_hostnames()
        out = provisioner.config["infrastructure_provisioning"]["out"]
        self.assertEqual(out["mongod"][0]["private_hostname"],
                         "mongod0.dsitest.dev")
        self.assertEqual(out["configsvr"][2]["private_hostname"],
                         "configsvr2.dsitest.dev")
        self.assertEqual(mock_create_file.call_count, 16)
        self.assertEqual(mock_exec_command.call_count, 16)
Example #15
def main(argv):
    """ Main function. Parse command line options, and run tests.

    :returns: int the exit status to return to the caller (0 for OK)
    """
    parser = argparse.ArgumentParser(description="DSI Test runner")

    parser.add_argument("-d",
                        "--debug",
                        action="store_true",
                        help="enable debug output")
    parser.add_argument("--log-file", help="path to log file")
    args = parser.parse_args(argv)
    log.setup_logging(args.debug, args.log_file)

    config = ConfigDict("test_control")
    config.load()

    error = run_tests(config)
    return 1 if error else 0
Example #16
    def test_setup_overrides_type_error(self):
        """
        Testing setup_overrides doesn't throw TypeError.
        """
        real_configdict = ConfigDict("bootstrap")
        real_configdict.load()

        real_configdict.raw["bootstrap"]["overrides"] = {
            "infrastructure_provisioning": {
                "tfvars": {
                    "tags": None
                }
            }
        }
        test_override_path = whereami.dsi_repo_path("dsi", "tests")
        bootstrap.setup_overrides(real_configdict, test_override_path)

        # Removing created file
        try:
            os.remove(os.path.join(test_override_path, "overrides.yml"))
        except OSError:
            pass
Example #17
def main(argv):
    """ Main function. Parse command line options, and run analysis.

    Note that the return value here determines whether Evergreen considers the entire task passed
    or failed. Non-zero return value means failure.

    :returns: int the exit status to return to the caller (0 for OK)
    """
    parser = argparse.ArgumentParser(description="Analyze DSI test results.")

    parser.add_argument("-d",
                        "--debug",
                        action="store_true",
                        help="enable debug output")
    parser.add_argument("--log-file", help="path to log file")
    args = parser.parse_args(argv)
    setup_logging(args.debug, args.log_file)

    config = ConfigDict("analysis")
    config.load()

    analyzer = ResultsAnalyzer(config)
    analyzer.analyze_all()
    return 1 if analyzer.failures > 0 else 0
Example #18
def main():
    """ Handle the main functionality (parse args /setup logging ) and then start the mongodb
    cluster."""
    args = parse_command_line()
    setup_logging(args.debug, args.log_file)

    config = ConfigDict("mongodb_setup")
    config.load()

    # Start MongoDB cluster(s) using config given in mongodb_setup.topology (if any).
    # Note: This also installs mongo client binary onto workload client.
    mongo = MongodbSetup(config=config)

    # Reset delays before starting so delays don't break setup.
    mongo.client.reset_delays()
    for cluster in mongo.clusters:
        cluster.reset_delays()

    start_cluster(mongo, config)

    # Establish delays *after* the cluster is started so delays won't interfere with setup.
    mongo.client.establish_delays()
    for cluster in mongo.clusters:
        cluster.establish_delays()
Example #19
class ConfigDictTestCase(unittest.TestCase):
    """Unit tests for ConfigDict library."""
    def setUp(self):
        """Init a ConfigDict object and load the configuration files from docs/config-specs/"""
        self.conf = ConfigDict("mongodb_setup",
                               whereami.dsi_repo_path("docs", "config-specs"))
        self.conf.load()
        self.assertEqual(self.conf.module, "mongodb_setup")

    def test_load_new(self):
        """Test loading ConfigDict with old naming convention .yml files"""
        test_conf = ConfigDict(
            "bootstrap",
            whereami.dsi_repo_path("dsi", "tests", "test_config_files",
                                   "new_format"))
        test_conf.load()
        self.assertFalse("cluster_type" in test_conf.raw["bootstrap"])
        self.assertTrue(
            "infrastructure_provisioning" in test_conf.raw["bootstrap"])
        self.assertFalse("cluster_type" in test_conf.defaults["bootstrap"])
        self.assertTrue(
            "infrastructure_provisioning" in test_conf.defaults["bootstrap"])

    def test_none_valued_keys(self):
        config_dict = self.conf
        self.assertEqual(config_dict["runtime"]["overridden_none"],
                         "hey there")
        self.assertEqual(config_dict["runtime"]["override_with_none"], None)
        self.assertEqual(config_dict["runtime"]["overridden_dict"], None)
        self.assertEqual(config_dict["runtime"]["overridden_list"], None)
        with self.assertRaises(KeyError):
            config_dict["runtime"]["nonexistant"]  # pylint: disable=pointless-statement

    def test_traverse_entire_dict(self):
        """Traverse entire dict (also tests that the structure of docs/config-specs/ files are ok)"""
        # We actually could compare the result to a constant megadict here, but maintaining that
        # would quickly become tedious. In practice, there's huge value just knowing we can traverse
        # the entire structure without errors.
        str(self.conf)

    @unittest.skip("dict(instance_of_ConfigDict) does not work")
    def test_cast_as_dict(self):
        """It is possible to cast a ConfigDict to a dict"""
        # TODO: this doesn't actually work. Seems like a limitation of python when sub-classing
        # native type like dict: http://stackoverflow.com/questions/18317905/overloaded-iter-is-bypassed-when-deriving-from-dict
        complete_dict = dict(self.conf)
        sub_dict = dict(
            self.conf["workload_setup"]["tasks"][0]["on_workload_client"])
        self.assertEqual(
            complete_dict["workload_setup"]["tasks"][0]["on_workload_client"]
            ["retrieve_files"][0],
            {
                "source": "http://url1",
                "target": "file"
            },
        )
        self.assertEqual(
            sub_dict["retrieve_files"][0],
            {
                "source": "remote_file_path",
                "target": "local_file_path"
            },
        )

    def test_convert_to_dict(self):
        """It is possible to convert a ConfigDict to a dict with self.as_dict()"""
        complete_dict = self.conf.as_dict()
        sub_dict = self.conf["workload_setup"]["ycsb"][0][
            "on_workload_client"].as_dict()
        self.assertEqual(
            complete_dict["workload_setup"]["ycsb"][0]["on_workload_client"]
            ["retrieve_files"][0],
            {
                "source": "remote_file_path",
                "target": "local_file_path"
            },
        )
        self.assertEqual(
            sub_dict["retrieve_files"][0],
            {
                "source": "remote_file_path",
                "target": "local_file_path"
            },
        )

    def test_basic_checks(self):
        """Basic checks"""
        self.assert_equal_dicts(
            self.conf["workload_setup"]["ycsb"][0]["on_workload_client"]
            ["retrieve_files"][0],
            {
                "source": "remote_file_path",
                "target": "local_file_path"
            },
        )
        expected_result = [{
            "source": "remote_file_path",
            "target": "local_file_path"
        }]
        actual_result = self.conf["workload_setup"]["ycsb"][0][
            "on_workload_client"]["retrieve_files"]
        self.assertEqual(len(actual_result), len(expected_result))
        for actual, expected in zip(actual_result, expected_result):
            self.assert_equal_dicts(actual, expected)
        self.assert_equal_dicts(
            self.conf["infrastructure_provisioning"]["out"]["mongos"][2],
            {
                "public_ip": "53.1.1.102",
                "private_ip": "10.2.1.102"
            },
        )
        self.assertEqual(
            self.conf["infrastructure_provisioning"]["out"]["workload_client"]
            [0]["public_ip"],
            "53.1.1.101",
        )
        self.assertEqual(
            type(self.conf["infrastructure_provisioning"]["out"]
                 ["workload_client"][0]["public_ip"]),
            type(""),
        )

    def test_overrides(self):
        """Test value from overrides.yml"""
        self.assertEqual(
            self.conf["infrastructure_provisioning"]["tfvars"]
            ["configsvr_instance_type"],
            "t1.micro",
        )
        self.assertEqual(
            self.conf["infrastructure_provisioning"]["tfvars"].as_dict(),
            {
                "cluster_name": "shard",
                "mongos_instance_type": "c3.8xlarge",
                "availability_zone": "us-west-2a",
                "workload_instance_count": 1,
                "region": "us-west-2",
                "image": "amazon2",
                "mongod_instance_count": 9,
                "configsvr_instance_count": 3,
                "mongos_instance_count": 3,
                "ssh_key_file": "~/.ssh/linustorvalds.pem",
                "ssh_user": "******",
                "mongod_instance_type": "c3.8xlarge",
                "ssh_key_name": "linus.torvalds",
                "workload_instance_type": "c3.8xlarge",
                "tags": {
                    "Project": "sys-perf",
                    "owner": "*****@*****.**",
                    "Variant": "Linux 3-shard cluster",
                    "expire-on-delta": 2,
                },
                "configsvr_instance_type": "t1.micro",
                "expire-on-delta": 24,
            },
        )

    def test_defaults(self):
        """Test value from defaults.yml"""
        self.assertEqual(
            self.conf["mongodb_setup"]["mongod_config_file"]["net"]["port"],
            27017)
        self.assertEqual(
            self.conf["mongodb_setup"]["mongod_config_file"]
            ["processManagement"]["fork"], True)

    def test_copy(self):
        """Copy value into new python variable"""
        out = self.conf["infrastructure_provisioning"]["out"]
        self.conf.raw["infrastructure_provisioning"]["out"]["workload_client"][
            0]["private_ip"] = "foo"
        out.raw["workload_client"][0]["public_ip"] = "bar"
        self.assertTrue(isinstance(out, ConfigDict))
        self.assert_equal_lists(
            self.conf.raw["infrastructure_provisioning"]["out"]
            ["workload_client"],
            [{
                "public_ip": "bar",
                "private_ip": "foo"
            }],
        )
        self.assert_equal_lists(
            self.conf.root["infrastructure_provisioning"]["out"]
            ["workload_client"],
            [{
                "public_ip": "bar",
                "private_ip": "foo"
            }],
        )
        self.assert_equal_lists(out.raw["workload_client"],
                                [{
                                    "public_ip": "bar",
                                    "private_ip": "foo"
                                }])
        self.assert_equal_lists(
            out.root["infrastructure_provisioning"]["out"]["workload_client"],
            [{
                "public_ip": "bar",
                "private_ip": "foo"
            }],
        )
        self.assert_equal_dicts(out.overrides, {})
        self.assertEqual(out["workload_client"][0]["public_ip"], "bar")

    def test_items(self):
        actual = {k for k, v in self.conf["bootstrap"].items()}
        expect = {
            "production",
            "analysis",
            "workload_setup",
            "terraform",
            "infrastructure_provisioning",
            "overrides",
            "storageEngine",
            "test_control",
            "platform",
            "mongodb_setup",
        }

        self.assertEqual(actual, expect)

    def test_variable_references(self):
        """Test ${variable.references}"""
        self.assertEqual(
            self.conf["mongodb_setup"]["topology"][0]["mongos"][0]
            ["private_ip"], "10.2.1.100")
        self.assertEqual(
            self.conf["mongodb_setup"]["meta"]["hosts"],
            "10.2.1.100:27017,10.2.1.101:27017,10.2.1.102:27017",
        )

        # reference to reference
        self.assertEqual(self.conf["mongodb_setup"]["meta"]["hostname"],
                         "10.2.1.100")

        # recursive reference ${a.${foo}.c} where "foo: b"
        value = self.conf["test_control"]["run"][0]["workload_config"][
            "tests"]["default"][2]["insert_vector"]["thread_levels"]
        expected = [1, 8, 16]
        self.assertEqual(value, expected)

    def test_variable_reference_in_list(self):
        """Test ${variable.references} in a list"""
        self.assertEqual(
            self.conf["mongodb_setup"]["validate"]["primaries"][0],
            "10.2.1.1:27017")

    def test_variable_reference_value_error(self):
        """Test ${variable.references} that point to nonexisting value PERF-1705"""
        # PERF-1705 happened when infrastructure_provisioning.out doesn't exist and variable
        # references point to it.
        del self.conf.raw["infrastructure_provisioning"]["out"]

        # ConfigDict is late binding
        # assert_valid_ids() (used in load()) should not raise for such variable references.
        self.conf.assert_valid_ids()

        # Otoh actively accessing a field with such a variable reference must raise
        with self.assertRaises(ValueError):
            _ = self.conf["mongodb_setup"]["meta"]["mongodb_url"]

        # As must other methods where user causes entire ConfigDict to be traversed
        with self.assertRaises(ValueError):
            _ = self.conf.as_dict()
        with self.assertRaises(ValueError):
            _ = str(self.conf)

    def test_per_node_mongod_config(self):
        """Test magic per_node_mongod_config() (merging the common mongod_config_file with per node config_file)"""
        mycluster = self.conf["mongodb_setup"]["topology"][0]
        mongod = mycluster["shard"][2]["mongod"][0]
        self.assert_equal_dicts(
            mycluster["shard"][0]["mongod"][0]["config_file"],
            {
                "replication": {
                    "replSetName": "override-rs"
                },
                "systemLog": {
                    "path": "data/logs/mongod.log",
                    "destination": "file"
                },
                "setParameter": {
                    "enableTestCommands": True,
                    "foo": True
                },
                "net": {
                    "port": 27017,
                    "bindIp": "0.0.0.0"
                },
                "processManagement": {
                    "fork": True
                },
                "storage": {
                    "engine": "wiredTiger",
                    "dbPath": "data/dbs"
                },
            },
        )
        self.assert_equal_dicts(
            mycluster["shard"][2]["mongod"][0]["config_file"],
            {
                "replication": {
                    "replSetName": "override-rs"
                },
                "systemLog": {
                    "path": "data/logs/mongod.log",
                    "destination": "file"
                },
                "setParameter": {
                    "enableTestCommands": True,
                    "foo": True
                },
                "net": {
                    "port": 27017,
                    "bindIp": "0.0.0.0"
                },
                "processManagement": {
                    "fork": True
                },
                "storage": {
                    "engine": "inMemory",
                    "dbPath": "data/dbs"
                },
            },
        )
        self.assert_equal_dicts(
            mycluster["shard"][2]["mongod"][0]["config_file"].overrides,
            {"storage": {
                "engine": "inMemory"
            }},
        )
        self.assertEqual(
            mycluster["shard"][2]["mongod"][0]["config_file"]["storage"]
            ["engine"], "inMemory")
        self.assertEqual(
            mycluster["shard"][2]["mongod"][0]["config_file"]["net"]["port"],
            27017)
        self.assertEqual(
            mycluster["shard"][2]["mongod"][0]["config_file"]["net"]["bindIp"],
            "0.0.0.0")
        self.assertEqual(
            mycluster["shard"][2]["mongod"][0]["config_file"]
            ["processManagement"]["fork"], True)
        self.assertEqual(
            mongod.raw,
            {
                "public_ip":
                "${infrastructure_provisioning.out.mongod.6.public_ip}",
                "mongodb_binary_archive":
                "${mongodb_setup.mongodb_binary_archive}",
                "config_file": {
                    "storage": {
                        "engine": "inMemory"
                    }
                },
                "private_ip":
                "${infrastructure_provisioning.out.mongod.6.private_ip}",
            },
        )
        # Standalone node
        self.assert_equal_dicts(
            self.conf["mongodb_setup"]["topology"][2]["config_file"],
            {
                "replication": {
                    "replSetName": "override-rs"
                },
                "systemLog": {
                    "path": "data/logs/mongod.log",
                    "destination": "file"
                },
                "setParameter": {
                    "enableTestCommands": True,
                    "foo": True
                },
                "net": {
                    "port": 27017,
                    "bindIp": "0.0.0.0"
                },
                "processManagement": {
                    "fork": True
                },
                "storage": {
                    "engine": "wiredTiger",
                    "dbPath": "data/dbs"
                },
            },
        )
        # self.keys() should return a 'config_file' key
        self.assertTrue(
            "config_file" in mycluster["shard"][0]["mongod"][0].keys())
        self.assertTrue(
            "config_file" in mycluster["shard"][2]["mongod"][0].keys())
        self.assertTrue(
            "config_file" in self.conf["mongodb_setup"]["topology"][2].keys())
        self.assertFalse(
            "config_file" in self.conf["mongodb_setup"]["topology"][0].keys())

    def test_replset_rs_conf(self):
        """Test magic rs_conf for a replset"""
        mycluster = self.conf["mongodb_setup"]["topology"][0]
        rs_conf = mycluster["shard"][2]["rs_conf"]
        self.assertEqual(rs_conf["protocolVersion"], 1)
        myreplset = self.conf["mongodb_setup"]["topology"][1]
        rs_conf = myreplset["rs_conf"]
        self.assertEqual(rs_conf["settings"]["chainingAllowed"], False)
        self.assertEqual(rs_conf["protocolVersion"], 1)

        # conf.keys() should return an 'rs_conf' key for replsets, not otherwise
        self.assertTrue("rs_conf" in mycluster["shard"][0].keys())
        self.assertTrue("rs_conf" in mycluster["shard"][2].keys())
        self.assertTrue("rs_conf" in myreplset.keys())
        self.assertFalse("rs_conf" in mycluster.keys())
        self.assertFalse(
            "rs_conf" in self.conf["mongodb_setup"]["topology"][2].keys())
        self.assertFalse(
            "rs_conf" in self.conf["infrastructure_provisioning"].keys())

    def test_set_some_values(self):
        """Set some values and write out file"""
        self.conf["mongodb_setup"]["out"] = {"foo": "bar"}
        # Read the value multiple times, because once upon a time that didn't work (believe it or not)
        self.assert_equal_dicts(self.conf["mongodb_setup"]["out"],
                                {"foo": "bar"})
        self.assert_equal_dicts(self.conf["mongodb_setup"]["out"],
                                {"foo": "bar"})
        self.assert_equal_dicts(self.conf["mongodb_setup"]["out"],
                                {"foo": "bar"})
        self.conf["mongodb_setup"]["out"]["zoo"] = "zar"
        self.assert_equal_dicts(self.conf["mongodb_setup"]["out"], {
            "foo": "bar",
            "zoo": "zar"
        })
        with self.assertRaises(KeyError):
            self.conf["foo"] = "bar"
        # Write the out file only if it doesn't already exist, and delete it when done
        file_name = os.path.join(
            whereami.dsi_repo_path("docs", "config-specs"),
            "mongodb_setup.out.yml")
        if os.path.exists(file_name):
            self.fail(
                "Cannot test writing docs/config-specs/mongodb_setup.out.yml file, file already exists."
            )
        else:
            self.conf.save()
            file_handle = open(file_name)
            saved_out_file = yaml.safe_load(file_handle)
            file_handle.close()
            self.assert_equal_dicts({"out": self.conf["mongodb_setup"]["out"]},
                                    saved_out_file)
            os.remove(file_name)

    def test_iterators(self):
        """Test that iterators .keys() and .values() work"""
        mycluster = self.conf["mongodb_setup"]["topology"][0]
        self.assertEqual(
            set(self.conf.keys()),
            {
                "test_control",
                "workload_setup",
                "runtime_secret",
                "bootstrap",
                "mongodb_setup",
                "analysis",
                "infrastructure_provisioning",
                "runtime",
            },
        )

        tfvars_dict = dict(
            zip(
                self.conf["infrastructure_provisioning"]["tfvars"].keys(),
                self.conf["infrastructure_provisioning"]["tfvars"].values(),
            ))

        self.assert_equal_dicts(
            tfvars_dict,
            {
                "cluster_name": "shard",
                "mongos_instance_type": "c3.8xlarge",
                "availability_zone": "us-west-2a",
                "workload_instance_count": 1,
                "image": "amazon2",
                "region": "us-west-2",
                "mongod_instance_count": 9,
                "configsvr_instance_count": 3,
                "mongos_instance_count": 3,
                "ssh_key_file": "~/.ssh/linustorvalds.pem",
                "ssh_user": "******",
                "mongod_instance_type": "c3.8xlarge",
                "ssh_key_name": "linus.torvalds",
                "workload_instance_type": "c3.8xlarge",
                "tags": {
                    "Project": "sys-perf",
                    "owner": "*****@*****.**",
                    "Variant": "Linux 3-shard cluster",
                    "expire-on-delta": 2,
                },
                "configsvr_instance_type": "t1.micro",
                "expire-on-delta": 24,
            },
        )

        # Order doesn't matter, but this can't be a set because dicts aren't hashable.
        expect = [
            "53.1.1.7",
            "https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-amazon-3.4.6.tgz",
            "10.2.1.7",
        ]
        actual = mycluster["shard"][2]["mongod"][0].values()
        # There is a 4th item in the list that's not in expect because we can't easily assert on its value.
        self.assertEqual(4, len(actual))
        for val in expect:
            self.assertTrue(val in actual, f"Val {val} not present")

    def test_lookup_path(self):
        """check that the lookup_path works as expected."""

        conf = self.conf["infrastructure_provisioning"]["out"]

        self.assertIsInstance(conf.lookup_path("mongod"), list)
        self.assertIsInstance(conf.lookup_path("mongod.0"), ConfigDict)
        self.assertIsInstance(conf.lookup_path("mongod.0.public_ip"), str)

        # hard coded but quick and easy
        mongod = ["53.1.1.{}".format(i) for i in range(1, 10)]
        mongos = ["53.1.1.{}".format(i) for i in range(100, 102)]
        configsvr = ["53.1.1.{}".format(i) for i in range(51, 54)]
        workload_client = ["53.1.1.101"]

        self.assertEqual(conf.lookup_path("mongod.0.public_ip"), mongod[0])

        self.assertEqual(conf.lookup_path("mongod.1.public_ip"), mongod[1])
        self.assertEqual(conf.lookup_path("mongod.4.public_ip"), mongod[4])

        self.assertEqual(conf.lookup_path("mongos.0.public_ip"), mongos[0])
        self.assertEqual(conf.lookup_path("configsvr.0.public_ip"),
                         configsvr[0])
        self.assertEqual(conf.lookup_path("workload_client.0.public_ip"),
                         workload_client[0])

        # document that this is the current behavior
        self.assertEqual(conf.lookup_path("mongod.-1.public_ip"), mongod[-1])

    def test_lookup_path_ex(self):
        """check that lookup_path throws exceptions for the correct portion of the pathspec."""

        conf = self.conf["infrastructure_provisioning"]["out"]
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: Key not found: MONGOD in path \\['MONGOD'\\]",
            conf.lookup_path,
            "MONGOD",
        )
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: Key not found: MONGOD in path \\['MONGOD', 50\\]",
            conf.lookup_path,
            "MONGOD.50",
        )
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: list index out of range: mongod.50 in path \\['mongod', 50\\]",
            conf.lookup_path,
            "mongod.50",
        )
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: Key not found: mongod.50e-1 in path \\['mongod', '50e-1'\\]",
            conf.lookup_path,
            "mongod.50e-1",
        )
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: list index out of range: mongod.50 in path \\['mongod', 50, 'public_ip'\\]",
            conf.lookup_path,
            "mongod.50.public_ip",
        )
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: Key not found: mongod.0.0 in path \\['mongod', 0, 0\\]",
            conf.lookup_path,
            "mongod.0.0",
        )
        self.assertRaisesRegex(
            KeyError,
            "ConfigDict: list index out of range: mongod.50 in path \\['mongod', 50, 'public_ip', 0\\]",
            conf.lookup_path,
            "mongod.50.public_ip.0",
        )

    # Helpers
    def assert_equal_dicts(self, dict1, dict2):
        """Compare 2 dicts element by element for equal values."""
        dict1keys = list(dict1.keys())
        dict2keys = list(dict2.keys())
        self.assertEqual(len(dict1keys), len(dict2keys))
        self.assertEqual(set(dict1keys), set(dict2keys))
        for dict1key in dict1keys:
            # Pop the corresponding key from dict2, note that they won't be in the same order.
            dict2key = dict2keys.pop(dict2keys.index(dict1key))
            self.assertEqual(
                dict1key,
                dict2key,
                "assert_equal_dicts failed: mismatch in keys: " +
                str(dict1key) + "!=" + str(dict2key),
            )
            if isinstance(dict1[dict1key], dict):
                self.assert_equal_dicts(dict1[dict1key], dict2[dict2key])
            elif isinstance(dict1[dict1key], list):
                self.assert_equal_lists(dict1[dict1key], dict2[dict2key])
            else:
                self.assertEqual(
                    dict1[dict1key],
                    dict2[dict2key],
                    "assert_equal_dicts failed: mismatch in values.",
                )
        self.assertEqual(len(dict2keys), 0)

    def assert_equal_lists(self, list1, list2):
        """Compare 2 lists element by element for equal values."""
        self.assertEqual(len(list1), len(list2))
        for list1value in list1:
            list2value = list2.pop(0)
            if isinstance(list1value, dict):
                self.assert_equal_dicts(list1value, list2value)
            elif isinstance(list1value, list):
                self.assert_equal_lists(list1value, list2value)
            else:
                self.assertEqual(list1value, list2value,
                                 "{} != {}".format(list1, list2))
        self.assertEqual(len(list2), 0)

class TestTerraformConfiguration(unittest.TestCase):
    """To test terraform configuration class."""
    def setUp(self):
        """ Load self.config (ConfigDict) and set some other common values """
        self.config = ConfigDict(
            "infrastructure_provisioning",
            whereami.dsi_repo_path("docs", "config-specs"))
        self.config.load()

        cookiejar = requests.cookies.RequestsCookieJar()
        request = requests.Request("GET", "http://ip.42.pl/raw")
        request.prepare()
        self.response_state = {
            "cookies": cookiejar,
            "_content": b"ip.42.hostname",
            "encoding": "UTF-8",
            "url": "http://ip.42.pl/raw",
            "status_code": 200,
            "request": request,
            "elapsed": datetime.timedelta(0, 0, 615501),
            "headers": {
                "Content-Length": "14",
                "X-Powered-By": "PHP/5.6.27",
                "Keep-Alive": "timeout=5, max=100",
                "Server":
                "Apache/2.4.23 (FreeBSD) OpenSSL/1.0.1l-freebsd PHP/5.6.27",
                "Connection": "Keep-Alive",
                "Date": "Tue, 25 Jul 2017 14:20:06 GMT",
                "Content-Type": "text/html; charset=UTF-8",
            },
            "reason": "OK",
            "history": [],
        }
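
        # Canned state for a requests.models.Response: the tests below tweak fields such as
        # status_code, _content, and url, then apply the state via response.__setstate__().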

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_generate_runner_timeout_hostname(self, mock_requests_get,
                                              mock_gethostname):
        """ Test generate runner and error cases. Fall back to gethostname """
        mock_requests_get.side_effect = requests.exceptions.Timeout()
        mock_requests_get.return_value = "MockedNotRaise"
        mock_gethostname.return_value = b"HostName"
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(terraform_config.generate_runner_hostname(),
                             "HostName")
            log_output.check(
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access AWS"
                    "meta-data. Falling back to other methods",
                ),
                ("dsi.common.terraform_config", "INFO", "Timeout()"),
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access"
                    " ip.42.pl to get public IP. Falling back to gethostname",
                ),
                ("dsi.common.terraform_config", "INFO", "Timeout()"),
            )

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_generate_runner_awsmeta(self, mock_requests_get,
                                     mock_gethostname):
        """ Test generate runner, successfully getting data from aws """
        request = requests.Request(
            "GET", "http://169.254.169.254/latest/meta-data/public-hostname")
        request.prepare()
        response = requests.models.Response()
        self.response_state["request"] = request
        self.response_state["_content"] = b"awsdata"
        response.__setstate__(self.response_state)
        mock_requests_get.return_value = response
        mock_gethostname.return_value = b"HostName"
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(terraform_config.generate_runner_hostname(),
                             "awsdata")
            log_output.check()

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_generate_runner_timeout_ip42(self, mock_requests_get,
                                          mock_gethostname):
        """ Test generate runner and error cases. Fall back to ip.42 call """
        mock_gethostname.return_value = b"HostName"
        response = requests.models.Response()
        response.__setstate__(self.response_state)
        mock_requests_get.side_effect = [
            requests.exceptions.Timeout(), response
        ]
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(terraform_config.generate_runner_hostname(),
                             "ip.42.hostname")
            log_output.check(
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access AWS"
                    "meta-data. Falling back to other methods",
                ),
                ("dsi.common.terraform_config", "INFO", "Timeout()"),
            )

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_generate_runner_timeout_ip42_404(self, mock_requests_get,
                                              mock_gethostname):
        """ Test generate runner and error cases. Timeout on aws, and 404 on ip42 """
        mock_gethostname.return_value = b"HostName"
        response = requests.models.Response()
        self.response_state["status_code"] = 404
        response.__setstate__(self.response_state)
        mock_requests_get.side_effect = [
            requests.exceptions.Timeout(), response
        ]
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(terraform_config.generate_runner_hostname(),
                             "HostName")
            log_output.check(
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access AWS"
                    "meta-data. Falling back to other methods",
                ),
                ("dsi.common.terraform_config", "INFO", "Timeout()"),
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access ip.42.pl to"
                    " get public IP. Falling back to gethostname",
                ),
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "HTTPError('404 Client Error: OK for url: http://ip.42.pl/raw')",
                ),
            )

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_generate_runner_404_and_timeout(self, mock_requests_get,
                                             mock_gethostname):
        """ Test generate runner and error cases. 404 on aws and timeout on ip42.
        Fall back to gethostname """
        request = requests.Request(
            "GET", "http://169.254.169.254/latest/meta-data/public-hostname")
        request.prepare()
        self.response_state["request"] = request
        self.response_state["status_code"] = 404
        self.response_state[
            "url"] = "http://169.254.169.254/latest/meta-data/public-hostname"
        response = requests.models.Response()
        response.__setstate__(self.response_state)
        mock_requests_get.side_effect = [
            response, requests.exceptions.Timeout()
        ]
        mock_gethostname.return_value = b"HostName"
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(terraform_config.generate_runner_hostname(),
                             "HostName")
            log_output.check(
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access AWSmeta-data."
                    " Falling back to other methods",
                ),
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "HTTPError('404 Client Error: OK for url: "
                    "http://169.254.169.254/latest/meta-data/public-hostname')",
                ),
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py _do_generate_runner could not access ip.42.pl to get"
                    " public IP. Falling back to gethostname",
                ),
                ("dsi.common.terraform_config", "INFO", "Timeout()"),
            )

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_retrieve_runner_instance_id_awsmeta(self, mock_requests_get,
                                                 mock_gethostname):
        """ Test retrieve runner instance id, successfully getting data from aws """
        request = requests.Request(
            "GET", "http://169.254.169.254/latest/meta-data/instance-id")
        request.prepare()
        response = requests.models.Response()
        self.response_state["request"] = request
        self.response_state["_content"] = b"awsdata"
        response.__setstate__(self.response_state)
        mock_requests_get.return_value = response
        mock_gethostname.return_value = b"HostName"
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(terraform_config.generate_runner_hostname(),
                             "awsdata")
            log_output.check()

    @patch("socket.gethostname")
    @patch("requests.get")
    def test_retrieve_runner_instance_id_timeout(self, mock_requests_get,
                                                 mock_gethostname):
        """ Test retrieve runner instance id error case."""
        mock_gethostname.return_value = b"HostName"
        response = requests.models.Response()
        response.__setstate__(self.response_state)
        mock_requests_get.side_effect = [
            requests.exceptions.Timeout(), response
        ]
        with LogCapture(level=logging.INFO) as log_output:
            self.assertEqual(
                terraform_config.retrieve_runner_instance_id(),
                "deploying host is not an EC2 instance",
            )
            log_output.check(
                (
                    "dsi.common.terraform_config",
                    "INFO",
                    "Terraform_config.py retrieve_runner_instance_id could not access AWS"
                    "instance id.",
                ),
                ("dsi.common.terraform_config", "INFO", "Timeout()"),
            )

    @patch("dsi.common.terraform_config.generate_expire_on_tag")
    @patch("dsi.common.terraform_config.uuid4")
    @patch("dsi.common.terraform_config.generate_runner_hostname")
    @patch("dsi.common.terraform_config.retrieve_runner_instance_id")
    # pylint: disable=invalid-name
    def test_default(
        self,
        mock_retrieve_runner_instance_id,
        mock_generate_runner_hostname,
        mock_uuid4,
        mock_generate_expire_on_tag,
    ):
        """Test default terraform configuration."""
        # pylint: disable=line-too-long
        mock_uuid4.return_value = "mock-uuid-1234"
        mock_retrieve_runner_instance_id.return_value = "i-0c2aad81dfac5ca6e"
        mock_generate_runner_hostname.return_value = "111.111.111.111"
        mock_generate_expire_on_tag.return_value = "2018-10-13 14:19:51"

        expected_string = '{"Project":"sys-perf","Variant":"Linux 3-shard cluster","availability_zone":"us-west-2a","cluster_name":"shard","configsvr_instance_count":3,"configsvr_instance_type":"t1.micro","expire_on":"2018-10-13 14:19:51","image":"amazon2","mongod_instance_count":9,"mongod_instance_type":"c3.8xlarge","mongod_placement_group":"shard-mock-uuid-1234","mongos_instance_count":3,"mongos_instance_type":"c3.8xlarge","mongos_placement_group":"shard-mock-uuid-1234","owner":"*****@*****.**","placement_group":"shard-mock-uuid-1234","region":"us-west-2","runner_hostname":"111.111.111.111","ssh_key_file":"~/.ssh/linustorvalds.pem","ssh_key_name":"linus.torvalds","ssh_user":"******","status":"running","task_id":"123...","workload_instance_count":1,"workload_instance_type":"c3.8xlarge","workload_placement_group":"shard-mock-uuid-1234"}'
        tf_config = terraform_config.TerraformConfiguration(self.config)
        json_string = tf_config.to_json(compact=True)

        self.assertEqual(json_string, expected_string)

    def test_generate_expire_on_tag(self):
        """Test expire-on tag generator."""
        def fake_datetime_utcnow():
            return datetime.datetime(2018, 10, 13, 14, 19, 51)

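        # The calls below appear to add the given number of hours (default 2) to the faked
        # utcnow: +2h -> 16:19:51, +1h -> 15:19:51, and +100h (4 days 4 hours) -> Oct 17 18:19:51.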
        tag = terraform_config.generate_expire_on_tag(
            _datetime_utcnow=fake_datetime_utcnow)
        self.assertEqual(tag, "2018-10-13 16:19:51")

        tag = terraform_config.generate_expire_on_tag(
            1, _datetime_utcnow=fake_datetime_utcnow)
        self.assertEqual(tag, "2018-10-13 15:19:51")

        tag = terraform_config.generate_expire_on_tag(
            100, _datetime_utcnow=fake_datetime_utcnow)
        self.assertEqual(tag, "2018-10-17 18:19:51")

    def test_is_placement_group_needed(self):
        """Test is_placement_group_needed()"""
        tfvars = {
            "mongod_instance_type": "c3.8xlarge",
            "mongod_instance_count": 3,
            "mongos_instance_type": "c3.8xlarge",
            "mongos_instance_count": 0,
            "configsvr_instance_type": "t1.micro",
            "configsvr_instance_count": 3,
        }

        self.assertEqual(
            True, terraform_config.is_placement_group_needed("mongod", tfvars))
        self.assertEqual(
            False,
            terraform_config.is_placement_group_needed("mongos", tfvars))

        self.assertEqual(
            False,
            terraform_config.is_placement_group_needed("configsvr", tfvars))
Example #21
class HostUtilsTestCase(unittest.TestCase):
    """ Unit Tests for Host Utils library """

    def _delete_fixtures(self):
        """ delete FIXTURE_FILES path and set filename attribute """
        local_host_path = os.path.join(FIXTURE_FILES.fixture_file_path(), "fixtures")
        self.filename = os.path.join(local_host_path, "file")
        shutil.rmtree(os.path.dirname(self.filename), ignore_errors=True)

    def setUp(self):
        """ Init a ConfigDict object and load the configuration files from docs/config-specs/ """
        self.config = ConfigDict("mongodb_setup", whereami.dsi_repo_path("docs", "config-specs"))
        self.config.load()
        self.parent_dir = os.path.join(os.path.expanduser("~"), "checkout_repos_test")

        self._delete_fixtures()

    def tearDown(self):
        self._delete_fixtures()

    def test_never_timeout(self):
        """ test never_timeout"""
        self.assertFalse(host_utils.never_timeout())
        self.assertFalse(host_utils.never_timeout())

    def test_check_timed_out(self):
        """ test check_timed_out"""
        start = datetime.now()
        self.assertFalse(host_utils.check_timed_out(start, 50))
        time.sleep(51 / 1000.0)
        self.assertTrue(host_utils.check_timed_out(start, 50))

    def test_create_timer(self):
        """ test create_timer """
        start = datetime.now()
        self.assertEqual(host_utils.create_timer(start, None), host_utils.never_timeout)
        with patch("dsi.common.host_utils.partial") as mock_partial:
            self.assertTrue(host_utils.create_timer(start, 50))
            mock_partial.assert_called_once_with(host_utils.check_timed_out, start, 50)
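
        # Reading the two branches above (an inference from this test, not from the
        # host_utils source): create_timer(start, None) hands back the never_timeout
        # sentinel, while a concrete max_time_ms yields
        # functools.partial(check_timed_out, start, max_time_ms).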

    def test_extract_hosts(self):
        """ Test extract hosts using config info """

        default_host_info = host_utils.HostInfo(
            public_ip=None,
            # These are the user and key files used by this test.
            ssh_user="******",
            ssh_key_file=os.path.join(os.path.expanduser("~"), ".ssh", "linustorvalds.pem"),
            category=None,
            offset=-1,
        )

        def customize_host_info(new_ip, new_category, offset):
            new_host_info = copy.copy(default_host_info)
            new_host_info.public_ip = new_ip
            new_host_info.category = new_category
            new_host_info.offset = offset
            return new_host_info

        mongods = [customize_host_info("53.1.1.{}".format(i + 1), "mongod", i) for i in range(0, 9)]
        configsvrs = [
            customize_host_info("53.1.1.{}".format(i + 51), "configsvr", i) for i in range(0, 3)
        ]
        mongos = [
            customize_host_info("53.1.1.{}".format(i + 100), "mongos", i) for i in range(0, 3)
        ]
        workload_clients = [customize_host_info("53.1.1.101", "workload_client", 0)]
        localhost = [host_utils.HostInfo(public_ip="localhost", category="localhost", offset=0)]

        self.assertEqual(host_utils.extract_hosts("localhost", self.config), localhost)
        self.assertEqual(host_utils.extract_hosts("workload_client", self.config), workload_clients)
        self.assertEqual(host_utils.extract_hosts("mongod", self.config), mongods)
        self.assertEqual(host_utils.extract_hosts("mongos", self.config), mongos)
        self.assertEqual(host_utils.extract_hosts("configsvr", self.config), configsvrs)
        self.assertEqual(
            host_utils.extract_hosts("all_servers", self.config), mongods + mongos + configsvrs
        )
        self.assertEqual(
            host_utils.extract_hosts("all_hosts", self.config),
            mongods + mongos + configsvrs + workload_clients,
        )

    def test_stream_lines_timeout_on_first_line(self):
        """ Test stream_lines """
        source = [1]
        destination = MagicMock(name="destination")
        destination.write.side_effect = socket.timeout("args")
        any_lines = host_utils.stream_lines(source, destination)
        self.assertEqual(False, any_lines)
        destination.write.assert_has_calls([call(1)])

    def test_stream_lines_timeout_on_third_line(self):
        destination = MagicMock(name="destination")
        destination.write.side_effect = ["first", "second", socket.timeout("args"), "third"]

        source = [1, 2, 3]
        any_lines = host_utils.stream_lines(source, destination)
        self.assertEqual(True, any_lines)

        destination.write.assert_has_calls([call(1), call(2), call(3)])
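
# Hedged reconstruction of stream_lines based solely on the two tests above; the
# real helper lives in dsi.common.host_utils and may differ. It copies lines from
# source to destination, swallows a socket.timeout, and reports whether at least
# one write succeeded before the timeout.
import socket


def _stream_lines_sketch(source, destination):
    any_lines = False
    try:
        for line in source:
            destination.write(line)
            any_lines = True
    except socket.timeout:
        pass
    return any_lines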
Exemple #22
class LocalHostTestCase(unittest.TestCase):
    """ Unit Test for LocalHost library """
    def _delete_fixtures(self):
        """ delete fixture path and set filename attribute """
        local_host_path = FIXTURE_FILES.fixture_file_path("fixtures")
        self.filename = os.path.join(local_host_path, "file")
        shutil.rmtree(os.path.dirname(self.filename), ignore_errors=True)

    def setUp(self):
        """ Init a ConfigDict object and load the configuration files from docs/config-specs/ """
        self.config = ConfigDict(
            "mongodb_setup", whereami.dsi_repo_path("docs", "config-specs"))
        self.config.load()
        self.parent_dir = os.path.join(os.path.expanduser("~"),
                                       "checkout_repos_test")

        self._delete_fixtures()

    def tearDown(self):
        """ Restore working directory """
        self._delete_fixtures()

    def test_local_host_exec_command(self):
        """ Test LocalHost.exec_command """

        local = local_host.LocalHost()
        utils.mkdir_p(os.path.dirname(self.filename))

        self.assertEqual(local.exec_command("exit 0"), 0)

        # test that the correct warning is issued
        mock_logger = MagicMock(name="LOG")
        local_host.LOG.warning = mock_logger
        self.assertEqual(local.exec_command("exit 1"), 1)
        mock_logger.assert_called_once_with(
            ANY_IN_STRING("Failed with exit status"), ANY, ANY, ANY)

        local.exec_command("touch {}".format(self.filename))
        self.assertTrue(os.path.isfile(self.filename))

        local.exec_command("touch {}".format(self.filename))
        self.assertTrue(os.path.isfile(self.filename))

        with open(self.filename, "w", encoding="utf-8") as the_file:
            the_file.write("Hello\n")
            the_file.write("World\n")
        out = StringIO()
        err = StringIO()
        local.exec_command("cat {}".format(self.filename), out, err)
        self.assertEqual(out.getvalue(), "Hello\nWorld\n")

        out = StringIO()
        err = StringIO()
        self.assertEqual(
            local.exec_command("cat {}; exit 1".format(self.filename), out,
                               err), 1)
        self.assertEqual(out.getvalue(), "Hello\nWorld\n")
        self.assertEqual(err.getvalue(), "")

        out = StringIO()
        err = StringIO()
        local.exec_command("cat {} >&2; exit 1".format(self.filename), out,
                           err)
        self.assertEqual(out.getvalue(), "")
        self.assertEqual(err.getvalue(), "Hello\nWorld\n")

        out = StringIO()
        err = StringIO()
        command = """cat {filename} && cat -n {filename} >&2; \
        exit 1""".format(filename=self.filename)
        local.exec_command(command, out, err)
        self.assertEqual(out.getvalue(), "Hello\nWorld\n")
        self.assertEqual(err.getvalue(), "     1\tHello\n     2\tWorld\n")

        out = StringIO()
        err = StringIO()
        command = "seq 10 -1 1 | xargs  -I % sh -c '{ echo %; sleep .1; }'; \
        echo 'blast off!'"

        local.exec_command(command, out, err)
        self.assertEqual(out.getvalue(),
                         "10\n9\n8\n7\n6\n5\n4\n3\n2\n1\nblast off!\n")
        self.assertEqual(err.getvalue(), "")

        # test timeout and that the correct warning is issued
        out = StringIO()
        err = StringIO()
        command = "sleep 1"

        mock_logger = MagicMock(name="LOG")
        local_host.LOG.warning = mock_logger
        self.assertEqual(
            local.exec_command(command, out, err, max_time_ms=500), 1)
        mock_logger.assert_called_once_with(ANY_IN_STRING("Timeout after"),
                                            ANY, ANY, ANY, ANY)

    def test_local_host_tee(self):
        """ Test run command map retrieve_files """

        local = local_host.LocalHost()
        utils.mkdir_p(os.path.dirname(self.filename))

        expected = "10\n9\n8\n7\n6\n5\n4\n3\n2\n1\nblast off!\n"
        with open(self.filename, "w", encoding="utf-8") as the_file:
            out = StringIO()
            tee = TeeStream(the_file, out)
            err = StringIO()
            command = "seq 10 -1 1 | xargs  -I % sh -c '{ echo %; sleep .1; }'; \
        echo 'blast off!'"

            local.exec_command(command, tee, err)
            self.assertEqual(out.getvalue(), expected)
            self.assertEqual(err.getvalue(), "")

        with open(self.filename) as the_file:
            self.assertEqual(expected, "".join(the_file.readlines()))
Exemple #23
class TestTerraformOutputParser(unittest.TestCase):
    """To test terraform configuration"""

    def setUp(self):
        """Setup so config dict works properly"""
        self.config = ConfigDict(
            "infrastructure_provisioning", whereami.dsi_repo_path("docs", "config-specs")
        )
        self.config.load()

    def test_single_cluster_value(self):
        """Test parsing single cluster value is correct."""
        output = tf_output.TerraformOutputParser(
            config=self.config,
            input_file=FIXTURE_FILES.fixture_file_path("terraform_single_cluster_output.txt"),
        )

        print(output._ips)

        self.assertEqual(["10.2.0.10"], output._ips["private_ip_mc"])
        self.assertEqual(["52.32.13.97"], output._ips["public_ip_mc"])
        self.assertEqual(["52.26.153.91"], output._ips["public_member_ip"])
        self.assertEqual(["10.2.0.100"], output._ips["private_member_ip"])

    def test_replica_ebs_cluster_value(self):
        """Test parsing replica_ebs cluster."""
        output = tf_output.TerraformOutputParser(
            config=self.config,
            input_file=FIXTURE_FILES.fixture_file_path("terraform_replica_with_ebs_output.txt"),
        )

        print(output._ips)

        self.assertEqual("52.33.30.1", output._ips["public_ip_mc"][0])
        self.assertEqual("10.2.0.10", output._ips["private_ip_mc"][0])
        self.assertEqual("52.41.40.0", output._ips["public_member_ip"][0])
        self.assertEqual("52.37.52.162", output._ips["public_member_ip"][1])
        self.assertEqual("52.25.102.16", output._ips["public_member_ip"][2])
        self.assertEqual("52.25.102.17", output._ips["public_member_ip"][3])
        self.assertEqual("10.2.0.100", output._ips["private_member_ip"][0])

    def test_shard_cluster_value(self):
        """Test parsing shard cluster value is correct."""
        output = tf_output.TerraformOutputParser(
            config=self.config,
            input_file=FIXTURE_FILES.fixture_file_path("terraform_shard_cluster_output.txt"),
        )

        print(output._ips)

        # Test ip address is correct for different members
        self.assertEqual("10.2.0.10", output._ips["private_ip_mc"][0])
        self.assertEqual("52.11.198.150", output._ips["public_ip_mc"][0])
        self.assertEqual("52.26.155.122", output._ips["public_member_ip"][0])
        self.assertEqual("52.38.108.78", output._ips["public_member_ip"][4])
        self.assertEqual("10.2.0.100", output._ips["private_member_ip"][0])
        self.assertEqual("10.2.0.106", output._ips["private_member_ip"][6])

        self.assertEqual("52.38.116.84", output._ips["public_config_ip"][0])
        self.assertEqual("52.27.136.80", output._ips["public_config_ip"][1])
        self.assertEqual("10.2.0.81", output._ips["private_config_ip"][0])
        self.assertEqual("10.2.0.83", output._ips["private_config_ip"][2])

        # Test total mongod count
        self.assertEqual(9, len(output._ips["public_member_ip"]))
        self.assertEqual(9, len(output._ips["private_member_ip"]))

        # Test config_server count
        self.assertEqual(3, len(output._ips["public_config_ip"]))
        self.assertEqual(3, len(output._ips["private_config_ip"]))

    def test_single_cluster_yml(self):
        """Test parsing single cluster YML file is correct."""
        output = tf_output.TerraformOutputParser(
            config=self.config,
            input_file=FIXTURE_FILES.fixture_file_path("terraform_single_cluster_output.txt"),
        )
        output._generate_output()
        reference = {}
        with open(FIXTURE_FILES.fixture_file_path("terraform_single.out.yml")) as fread:
            reference = yaml.safe_load(fread)

        print(reference["out"])
        print(output.config_obj["infrastructure_provisioning"]["out"])
        self.assertEqual(
            output.config_obj["infrastructure_provisioning"]["out"].as_dict(), reference["out"]
        )

    def test_shard_cluster_yml(self):
        """Test parsing single cluster YML file is correct."""
        output = tf_output.TerraformOutputParser(
            config=self.config,
            input_file=FIXTURE_FILES.fixture_file_path("terraform_shard_cluster_output.txt"),
        )

        output._generate_output()
        with open(FIXTURE_FILES.fixture_file_path("terraform_shard.out.yml")) as fread:
            reference = yaml.safe_load(fread)

        print(reference["out"])
        print(output.config_obj["infrastructure_provisioning"]["out"])
        self.assertEqual(
            output.config_obj["infrastructure_provisioning"]["out"].as_dict(), reference["out"]
        )
Exemple #24
def load_bootstrap(config, directory):
    """
    Move specified bootstrap.yml file to correct location for read_runtime_values
    """
    # Create directory if it doesn't exist
    if not os.path.exists(directory):
        os.makedirs(directory)

    if "bootstrap_file" in config:
        bootstrap_path = os.path.abspath(
            os.path.expanduser(config["bootstrap_file"]))
        if os.path.isfile(bootstrap_path):
            if not bootstrap_path == os.path.abspath(
                    os.path.join(directory, "bootstrap.yml")):
                if os.path.isfile(
                        os.path.abspath(
                            os.path.join(directory, "bootstrap.yml"))):
                    LOGGER.critical(
                        "Attempting to overwrite existing bootstrap.yml file. Aborting.",
                        directory=directory,
                    )
                    assert False
                shutil.copyfile(bootstrap_path,
                                os.path.join(directory, "bootstrap.yml"))
        else:
            LOGGER.critical("Location specified for bootstrap.yml is invalid.")
            assert False
    else:
        bootstrap_path = os.path.abspath(
            os.path.expanduser(os.path.join(os.getcwd(), "bootstrap.yml")))
        if os.path.isfile(bootstrap_path):
            if not bootstrap_path == os.path.abspath(
                    os.path.join(directory, "bootstrap.yml")):
                if os.path.isfile(
                        os.path.abspath(
                            os.path.join(directory, "bootstrap.yml"))):
                    LOGGER.critical(
                        "Attempting to overwrite existing bootstrap.yml file in %s. "
                        "Aborting.",
                        directory,
                    )
                    assert False
                shutil.copyfile(bootstrap_path,
                                os.path.join(directory, "bootstrap.yml"))

    expansions.write_if_necessary(directory)

    current_path = os.getcwd()
    os.chdir(directory)
    config_dict = ConfigDict("bootstrap")
    config_dict.load()
    for key in config_dict["bootstrap"].keys():
        config[key] = config_dict["bootstrap"][key]

    # terraform required_version must be specified, we fail hard if user has tried to unset
    config["terraform_version_check"] = config_dict[
        "infrastructure_provisioning"]["terraform"]["required_version"]
    config["terraform_linux_download"] = config_dict[
        "infrastructure_provisioning"]["terraform"]["linux_download"]
    config["terraform_mac_download"] = config_dict[
        "infrastructure_provisioning"]["terraform"]["mac_download"]

    os.chdir(current_path)

    return config_dict
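
# Hypothetical usage of load_bootstrap (the paths below are illustrative only):
#
#     config = {"bootstrap_file": "~/bootstrap.yml"}
#     config_dict = load_bootstrap(config, "./work_directory")
#
# Afterwards `config` carries every key from the bootstrap section plus the
# terraform_version_check and terraform_*_download values read from
# infrastructure_provisioning, and the working directory is restored to what it
# was before the call.
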
class TestConfigTestControl(unittest.TestCase):
    """ Test config_test_control.py"""
    def setUp(self):
        """
        Setup basic environment
        """
        # Mocking `ConfigDict.assert_valid_ids` because it enforces structural constraints on yaml
        # files that aren't necessary here.
        with patch("dsi.common.config.ConfigDict.assert_valid_ids"
                   ) as mock_assert_valid_ids:
            self.config = ConfigDict(
                "test_control",
                FIXTURE_FILES.fixture_file_path("config_test_control"))
            self.config.load()
            mock_assert_valid_ids.assert_called_once()

    def tearDown(self):
        file_name = FIXTURE_FILES.fixture_file_path("workloads.yml")
        if os.path.exists(file_name):
            os.remove(file_name)

    def test_benchrun_workload_config(self):
        """
        Test that generate_config_file works with a benchrun workload
        """
        test = self.config["test_control"]["run"][0]
        mock_host = Mock(spec=RemoteHost)
        test_control.generate_config_file(test,
                                          FIXTURE_FILES.fixture_file_path(),
                                          mock_host)
        self.assertEqual(
            FIXTURE_FILES.load_yaml_file("config_test_control",
                                         "workloads.yml"),
            FIXTURE_FILES.load_yaml_file("config_test_control",
                                         "workloads.benchrun.yml.ok"),
            "workloads.yml doesn't match expected for test_control.yml",
        )
        mock_host.upload_file.assert_called_once_with(
            FIXTURE_FILES.fixture_file_path(test["config_filename"]),
            test["config_filename"])

    def test_ycsb_workload_config(self):
        """
        Test that generate_config_file works with a ycsb run
        """
        test = self.config["test_control"]["run"][1]
        mock_host = Mock(spec=RemoteHost)
        test_control.generate_config_file(
            test, FIXTURE_FILES.fixture_file_path("config_test_control"),
            mock_host)
        self.assertEqual(
            FIXTURE_FILES.load_yaml_file("config_test_control",
                                         "workloadEvergreen"),
            FIXTURE_FILES.load_yaml_file("config_test_control",
                                         "workloadEvergreen.ok"),
            "workloadEvergreen doesn't match expected for test_control.yml",
        )
        mock_host.upload_file.assert_called_once_with(
            FIXTURE_FILES.fixture_file_path("config_test_control",
                                            test["config_filename"]),
            test["config_filename"],
        )

    @patch("dsi.test_control.open")
    def test_generate_config_no_config(self, mock_open):
        """
        Test that generate_config_file doesn't create a workload file and logs the correct message
        if there is no config file
        """
        test = self.config["test_control"]["run"][2]
        mock_host = Mock(spec=RemoteHost)
        with LogCapture(level=logging.WARNING) as warning:
            test_control.generate_config_file(
                test, FIXTURE_FILES.repo_root_file_path(), mock_host)
        warning.check(("dsi.test_control", "WARNING",
                       "No workload config in test control"))
        mock_open.assert_not_called()
        mock_host.upload_file.assert_not_called()
Exemple #26
class CommandRunnerTestCase(unittest.TestCase):
    """ Unit Tests for Host Utils library """
    def _delete_fixtures(self):
        """ delete fixture path and set filename attribute """
        local_host_path = os.path.join(FIXTURE_FILES.fixture_file_path(),
                                       "fixtures")
        self.filename = os.path.join(local_host_path, "file")
        shutil.rmtree(os.path.dirname(self.filename), ignore_errors=True)

    def setUp(self):
        """ Init a ConfigDict object and load the configuration files from docs/config-specs/ """
        self.old_dir = os.getcwd()  # Save the old path to restore
        self.config = ConfigDict(
            "mongodb_setup", whereami.dsi_repo_path("docs", "config-specs"))
        self.config.load()
        self.parent_dir = os.path.join(os.path.expanduser("~"),
                                       "checkout_repos_test")

        self._delete_fixtures()
        self.reports_container = os.path.join(
            FIXTURE_FILES.fixture_file_path(), "container")
        self.reports_path = os.path.join(self.reports_container,
                                         "reports_tests")

        mkdir_p(self.reports_path)

    def tearDown(self):
        """ Restore working directory """
        shutil.rmtree(self.reports_container)

        os.chdir(self.old_dir)

        self._delete_fixtures()

    @patch("dsi.common.command_runner._run_host_command_map")
    def test_make_host_runner_str(self, mock_run_host_command_map):
        """ Test run RemoteHost.make_host_runner with str"""
        with patch("dsi.common.host_factory.make_host") as mock_make_host:
            mock_target_host = Mock()
            mock_make_host.return_value = mock_target_host

            dummy_host_info = HostInfo("host_info")

            command_runner.make_host_runner(dummy_host_info, "command",
                                            "test_id")
            mock_make_host.assert_called_once_with(dummy_host_info, None, None)
            mock_target_host.run.assert_called_once_with("command")
            mock_target_host.close.assert_called_once()

    @patch("dsi.common.command_runner._run_host_command_map")
    def test_make_host_runner_map(self, mock_run_host_command_map):
        """ Test run Remotecommand_runner.make_host_runner with map"""

        with patch("dsi.common.host_factory.make_host") as mock_make_host:
            command = {}
            mock_target_host = Mock()
            mock_make_host.return_value = mock_target_host

            dummy_host_info = HostInfo("host_name")

            command_runner.make_host_runner(dummy_host_info, command,
                                            "test_id")

            mock_make_host.assert_called_once_with(dummy_host_info, None, None)

            mock_run_host_command_map.assert_called_once_with(
                mock_target_host, command, "test_id")
            mock_target_host.close.assert_called_once()
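
    # Taken together, the two tests above suggest make_host_runner builds a host via
    # host_factory.make_host, runs a plain string command directly through host.run,
    # dispatches a mapping through _run_host_command_map, and closes the host in
    # either case; this is read off the mocks, not the command_runner source.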

    def test_run_host_commands(self):
        """Test 2-commands common.command_runner.run_host_commands invocation"""
        with patch("dsi.common.host_factory.RemoteSSHHost") as mongod:
            commands = [
                {
                    "on_workload_client": {
                        "upload_files": [{
                            "source": "src1",
                            "target": "dest1"
                        }]
                    }
                },
                {
                    "on_workload_client": {
                        "upload_files": [{
                            "source": "src2",
                            "target": "dest2"
                        }]
                    }
                },
            ]
            command_runner.run_host_commands(commands, self.config, "test_id")
            self.assertEqual(mongod.call_count, 2)

    def test_run_host_command_map(self):
        """ Test run command map not known """

        with self.assertRaises(UserWarning):
            with patch("dsi.common.remote_host.RemoteHost") as mongod:
                command = {"garbage": {"remote_path": "mongos.log"}}
                command_runner._run_host_command_map(mongod, command,
                                                     "test_id")

    def __run_host_command_map_ex(self,
                                  command,
                                  run_return_value=False,
                                  exec_return_value=None):
        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            if run_return_value is not None:
                mongod.run.return_value = run_return_value
            else:
                mongod.exec_mongo_command.return_value = exec_return_value
            command_runner._run_host_command_map(mongod, command, "test_id")

    def test__exec_ex(self):
        """ Test run command map excpetion """

        # test upload_files
        with self.assertRaisesRegex(host_utils.HostException,
                                    r"^\(1, .*cowsay moo"):
            command = {"exec": "cowsay moo"}
            self.__run_host_command_map_ex(command)

    def test__exec_mongo_shell_ex(self):
        """ Test run command map excpetion """

        with self.assertRaisesRegex(host_utils.HostException,
                                    r"^\(1, .*this is a script"):
            command = {
                "exec_mongo_shell": {
                    "script": "this is a script",
                    "connection_string": "connection string",
                }
            }
            self.__run_host_command_map_ex(command,
                                           run_return_value=None,
                                           exec_return_value=1)

    def test_upload_repo_files(self):
        """ Test run command map upload_repo_files """
        root = whereami.dsi_repo_path() + os.sep

        # test upload_repo_files
        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            command = {
                "upload_repo_files": [{
                    "target": "remote_path",
                    "source": "mongos.log"
                }]
            }
            command_runner._run_host_command_map(mongod, command, "test_id")
            mongod.upload_file.assert_called_once_with(root + "mongos.log",
                                                       "remote_path")

        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            command = {
                "upload_repo_files": [
                    {
                        "target": "remote_path",
                        "source": "mongos.log"
                    },
                    {
                        "target": "to",
                        "source": "from"
                    },
                ]
            }
            command_runner._run_host_command_map(mongod, command, "test_id")
            calls = [
                mock.call(root + "mongos.log", "remote_path"),
                mock.call(root + "from", "to")
            ]
            mongod.upload_file.assert_has_calls(calls, any_order=True)

    def test_upload_files(self):
        """ Test run command map upload_files """

        # test upload_files
        with patch("dsi.common.remote_ssh_host.RemoteSSHHost") as mongod:
            command = {
                "upload_files": [{
                    "target": "remote_path",
                    "source": "mongos.log"
                }]
            }
            command_runner._run_host_command_map(mongod, command, "test_id")
            mongod.upload_file.assert_called_once_with("mongos.log",
                                                       "remote_path")

        with patch("dsi.common.remote_ssh_host.RemoteSSHHost") as mongod:
            command = {
                "upload_files": [
                    {
                        "source": "mongos.log",
                        "target": "remote_path"
                    },
                    {
                        "source": "to",
                        "target": "from"
                    },
                ]
            }
            command_runner._run_host_command_map(mongod, command, "test_id")
            calls = [
                mock.call("mongos.log", "remote_path"),
                mock.call("to", "from")
            ]
            mongod.upload_file.assert_has_calls(calls, any_order=True)

    def test_retrieve_files(self):
        """ Test run command map retrieve_files """

        # retrieve_files tests
        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            mock_retrieve_file = Mock()
            mongod.retrieve_path = mock_retrieve_file

            command = {
                "retrieve_files": [{
                    "source": "remote_path",
                    "target": "mongos.log"
                }]
            }
            mongod.alias = "host"
            command_runner._run_host_command_map(mongod, command, "test_id")
            mock_retrieve_file.assert_any_call(
                "remote_path", "reports/test_id/host/mongos.log")

        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            mock_retrieve_file = Mock()
            mongod.retrieve_path = mock_retrieve_file

            command = {
                "retrieve_files": [{
                    "source": "remote_path",
                    "target": "local_path"
                }]
            }
            mongod.alias = "host"
            command_runner._run_host_command_map(mongod, command, "test_id")
            mock_retrieve_file.assert_any_call(
                "remote_path", "reports/test_id/host/local_path")

        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            mock_retrieve_file = Mock()
            mongod.retrieve_path = mock_retrieve_file

            mongod.alias = "mongod.0"
            command_runner._run_host_command_map(mongod, command, "test_id")
            mock_retrieve_file.assert_any_call(
                "remote_path", "reports/test_id/mongod.0/local_path")

        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            mock_retrieve_file = Mock()
            mongod.retrieve_path = mock_retrieve_file

            command = {
                "retrieve_files": [{
                    "source": "remote_path",
                    "target": "./local_path"
                }]
            }
            mongod.alias = "mongos.0"
            command_runner._run_host_command_map(mongod, command, "test_id")
            mock_retrieve_file.assert_any_call(
                "remote_path", "reports/test_id/mongos.0/local_path")

    def test_exec(self):
        """ Test run command map exec """

        # test exec
        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            command = {"exec": "this is a command"}
            mongod.run.return_value = True
            command_runner._run_host_command_map(mongod, command, "test_id")
            mongod.run.assert_called_once_with("this is a command")

    def test_exec_mongo_shell(self):
        """ Test run command map exec mongo shell """

        # test exec_mongo_shell
        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            command = {
                "exec_mongo_shell": {
                    "script": "this is a script",
                    "connection_string": "connection string",
                }
            }
            mongod.exec_mongo_command.return_value = 0
            command_runner._run_host_command_map(mongod, command, "test_id")
            mongod.exec_mongo_command.assert_called_once_with(
                "this is a script", connection_string="connection string")

        with patch("dsi.common.remote_host.RemoteHost") as mongod:
            command = {"exec_mongo_shell": {"script": "this is a script"}}
            mongod.exec_mongo_command.return_value = 0
            command_runner._run_host_command_map(mongod, command, "test_id")
            mongod.exec_mongo_command.assert_called_once_with(
                "this is a script", connection_string="")

    def test_run_upon_error(self):
        """ test run_upon_error """
        @mock.patch("dsi.common.command_runner.prepare_reports_dir")
        @mock.patch("dsi.common.command_runner.run_pre_post_commands")
        def _test_run_upon_error(behave, mock_run_pre_post_commands,
                                 mock_prepare_reports_dir):

            config = {"mongodb_setup": "setup"}
            if behave:
                run_upon_error("mongodb_setup", [], config)
                expected = EXCEPTION_BEHAVIOR.EXIT
            else:
                run_upon_error("mongodb_setup", [], config,
                               EXCEPTION_BEHAVIOR.CONTINUE)
                expected = EXCEPTION_BEHAVIOR.CONTINUE

            mock_prepare_reports_dir.assert_called_once()
            mock_run_pre_post_commands.assert_called_with(
                "upon_error", [], config, expected, "upon_error/mongodb_setup")

        _test_run_upon_error(True)
        _test_run_upon_error(False)

    @patch("dsi.common.command_runner.run_host_command")
    def test_run_pre_post(self, mock_run_host_command):
        """Test test_control.run_pre_post_commands()"""
        command_dicts = [
            self.config["test_control"], self.config["mongodb_setup"]
        ]
        run_pre_post_commands("post_test", command_dicts, self.config,
                              EXCEPTION_BEHAVIOR.EXIT)

        # expected_args = ['on_workload_client', 'on_all_servers', 'on_mongod', 'on_configsvr']
        expected_args = [
            "on_mongod", "on_all_hosts", "on_all_servers", "on_mongod",
            "on_configsvr"
        ]
        observed_args = []
        for args in mock_run_host_command.call_args_list:
            observed_args.append(args[0][0])
        self.assertEqual(observed_args, expected_args)

    def test_prepare_reports_dir(self):
        """Test test_control.run_test where the exec command returns non-zero"""

        previous_directory = os.getcwd()
        reports_dir = os.path.join(self.reports_path, "reports")
        reports_tarball = os.path.join(self.reports_container, "reports.tgz")

        def _test_prepare_reports_dir():
            try:
                os.chdir(self.reports_path)
                prepare_reports_dir(reports_dir=reports_dir)
            finally:
                os.chdir(previous_directory)

            self.assertFalse(os.path.exists(reports_tarball))
            self.assertTrue(os.path.exists(reports_dir))
            self.assertTrue(os.path.islink(reports_dir))

        _test_prepare_reports_dir()

        touch(reports_tarball)
        _test_prepare_reports_dir()

        os.remove(reports_dir)
        mkdir_p(reports_dir)
        self.assertRaises(OSError, _test_prepare_reports_dir)
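
        # Read off the assertions above (not from the command_runner source):
        # prepare_reports_dir appears to remove a stale reports.tgz, recreate the
        # reports location, and expose reports_dir as a symlink; a pre-existing real
        # directory at reports_dir makes the call fail with OSError.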
Exemple #27
class HostTestCase(unittest.TestCase):
    """ Unit Test for Host library """
    def _delete_fixtures(self):
        """ delete fixture path and set filename attribute """
        local_host_path = os.path.join(FIXTURE_FILES.fixture_file_path(),
                                       "fixtures")
        self.filename = os.path.join(local_host_path, "file")
        shutil.rmtree(os.path.dirname(self.filename), ignore_errors=True)

    def setUp(self):
        """ Init a ConfigDict object and load the configuration files from docs/config-specs/ """
        self.old_dir = os.getcwd()  # Save the old path to restore
        self.config = ConfigDict(
            "mongodb_setup", whereami.dsi_repo_path("docs", "config-specs"))
        self.config.load()
        self.parent_dir = os.path.join(os.path.expanduser("~"),
                                       "checkout_repos_test")

        self._delete_fixtures()

    def tearDown(self):
        """ Restore working directory """
        os.chdir(self.old_dir)

        self._delete_fixtures()

    def test_kill_remote_procs(self):
        """ Test kill_remote_procs """

        local = LocalHost()
        local.run = MagicMock(name="run")
        local.run.return_value = False
        self.assertTrue(local.kill_remote_procs("mongo"))

        calls = [
            call(["pkill", "-9", "mongo"], quiet=True),
            call(["pgrep", "mongo"], quiet=True)
        ]

        local.run.assert_has_calls(calls)

        with patch(
                "dsi.common.host_utils.create_timer") as mock_create_watchdog:

            local.run = MagicMock(name="run")
            local.run.return_value = False
            local.kill_remote_procs("mongo", max_time_ms=None)
            mock_create_watchdog.assert_called_once_with(ANY, None)

        with patch(
                "dsi.common.host_utils.create_timer") as mock_create_watchdog:

            local.run = MagicMock(name="run")
            local.run.return_value = False
            local.kill_remote_procs("mongo", max_time_ms=0, delay_ms=99)
            mock_create_watchdog.assert_called_once_with(ANY, 99)

        with patch(
                "dsi.common.host_utils.create_timer") as mock_create_watchdog:
            local = LocalHost()
            local.run = MagicMock(name="run")
            local.run.return_value = True

            mock_is_timed_out = MagicMock(name="is_timed_out")
            mock_create_watchdog.return_value = mock_is_timed_out
            mock_is_timed_out.side_effect = [False, True]
            self.assertFalse(local.kill_remote_procs("mongo", delay_ms=1))

        local = LocalHost()
        local.run = MagicMock(name="run")
        local.run.side_effect = [False, True, False, False]
        self.assertTrue(
            local.kill_remote_procs("mongo", signal_number=15, delay_ms=1))

        calls = [
            call(["pkill", "-15", "mongo"], quiet=True),
            call(["pgrep", "mongo"], quiet=True),
            call(["pkill", "-15", "mongo"], quiet=True),
            call(["pgrep", "mongo"], quiet=True),
        ]

        local.run.assert_has_calls(calls)
        # mock_sleep.assert_not_called()
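
        # The call sequences above suggest kill_remote_procs loops "pkill -<signal>
        # <name>" followed by "pgrep <name>" until pgrep finds no survivors or the
        # watchdog from host_utils.create_timer fires; this is inferred from the
        # mocks rather than from the host implementation itself.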

    def test_kill_mongo_procs(self):
        """ Test kill_mongo_procs """
        local = LocalHost()
        local.kill_remote_procs = MagicMock(name="kill_remote_procs")
        local.kill_remote_procs.return_value = True
        self.assertTrue(local.kill_mongo_procs())
        local.kill_remote_procs.assert_called_once_with("mongo",
                                                        9,
                                                        max_time_ms=30000)

    @patch("paramiko.SSHClient")
    def test_alias(self, mock_ssh):
        """ Test alias """

        remote = RemoteHost("host", "user", "pem_file")
        self.assertEqual(remote.alias, "host")

        remote.alias = ""
        self.assertEqual(remote.alias, "host")

        remote.alias = None
        self.assertEqual(remote.alias, "host")

        remote.alias = "alias"
        self.assertEqual(remote.alias, "alias")

    @patch("paramiko.SSHClient")
    def test_run(self, mock_ssh):
        """Test Host.run on RemoteHost"""
        subject = RemoteHost("test_host", "test_user", "test_pem_file")

        # test string command
        subject.exec_command = MagicMock(name="exec_command")
        subject.exec_command.return_value = 0
        self.assertTrue(subject.run("cowsay Hello World", quiet=True))
        subject.exec_command.assert_called_once_with("cowsay Hello World",
                                                     quiet=True)

        # test string command
        subject.exec_command = MagicMock(name="exec_command")
        subject.exec_command.return_value = 0
        self.assertTrue(subject.run("cowsay Hello World"))
        subject.exec_command.assert_called_once_with("cowsay Hello World",
                                                     quiet=False)

        # Test fail
        subject.exec_command = MagicMock(name="exec_command")
        subject.exec_command.return_value = 1
        self.assertFalse(subject.run("cowsay Hello World"))
        subject.exec_command.assert_called_once_with("cowsay Hello World",
                                                     quiet=False)

        # test list command success
        subject.exec_command = MagicMock(name="exec_command")
        subject.exec_command.return_value = 0
        self.assertTrue(
            subject.run([["cowsay", "Hello", "World"], ["cowsay", "moo"]]))
        subject.exec_command.assert_any_call(["cowsay", "Hello", "World"],
                                             quiet=False)
        subject.exec_command.assert_any_call(["cowsay", "moo"], quiet=False)

        # test list command failure
        subject.exec_command = MagicMock(name="exec_command")
        subject.exec_command.side_effect = [0, 1, 0]
        self.assertFalse(
            subject.run([["cowsay", "Hello", "World"], ["cowsay", "moo"],
                         ["cowsay", "boo"]]))
        calls = [
            mock.call(["cowsay", "Hello", "World"], quiet=False),
            mock.call(["cowsay", "moo"], quiet=False),
        ]
        subject.exec_command.assert_has_calls(calls)

        # test that a list of string commands is passed through as one command
        subject.exec_command = MagicMock(name="exec_command")
        subject.exec_command.return_value = 0
        self.assertTrue(subject.run(["cowsay Hello World", "cowsay moo"]))
        subject.exec_command.assert_called_once_with(
            ["cowsay Hello World", "cowsay moo"], quiet=False)

    @nottest
    def helper_test_checkout_repos(self,
                                   source,
                                   target,
                                   commands,
                                   branch=None,
                                   verbose=True):
        """ test_checkout_repos common test code """
        local = LocalHost()

        # Test with non-existing target
        self.assertFalse(os.path.exists(target))
        with patch("dsi.common.host.mkdir_p") as mock_mkdir_p, patch(
                "dsi.common.local_host.LocalHost.exec_command"
        ) as mock_exec_command:
            local.checkout_repos(source,
                                 target,
                                 verbose=verbose,
                                 branch=branch)
            mock_mkdir_p.assert_called_with(self.parent_dir)
            if len(commands) == 1:
                mock_exec_command.assert_called_once()
                mock_exec_command.assert_called_with(commands[0])
            else:
                for command in commands:
                    mock_exec_command.assert_any_call(command)

    def test_checkout_repos(self):
        """
        Test Host.checkout_repos command
        """
        # Only testing on LocalHost since `checkout_repos` is implemented in the base class and not
        # overridden
        local = LocalHost()

        # Test with existing target that is not a git repository
        source = "[email protected]:mongodb/mongo.git"
        target = os.path.expanduser("~")
        command = ["cd", target, "&&", "git", "status"]
        with patch("dsi.common.host.mkdir_p") as mock_mkdir_p, patch(
                "dsi.common.local_host.LocalHost.exec_command"
        ) as mock_exec_command:
            self.assertRaises(UserWarning, local.checkout_repos, source,
                              target)
            mock_mkdir_p.assert_not_called()
            mock_exec_command.assert_called_once()
            mock_exec_command.assert_called_with(command)

    def test_checkout_repos_non_existing_target(self):

        # Test with non-existing target
        source = "https://github.com/mongodb/stitch-js-sdk.git"
        target = os.path.join(self.parent_dir, "bin.stitch-js-sdk")
        commands = [["git", "clone", "", source, target]]
        self.helper_test_checkout_repos(source, target, commands, verbose=True)

        commands = [["git", "clone", "--quiet", source, target]]
        self.helper_test_checkout_repos(source, target, commands, verbose=None)

    def test_checkout_repos_branch(self):

        # Test with specified branch
        source = "https://github.com/mongodb/stitch-js-sdk.git"
        target = os.path.join(self.parent_dir, "bin.stitch-js-sdk")
        branch = "2.x.x"
        commands = [
            ["git", "clone", "--quiet", source, target],
            ["cd", target, "&&", "git", "checkout", "--quiet", branch],
        ]
        self.helper_test_checkout_repos(source,
                                        target,
                                        commands,
                                        branch=branch,
                                        verbose=None)

    def test_checkout_repos_existing_target(self):

        # Test with existing target that is a git repository
        local = LocalHost()

        source = "https://github.com/mongodb/stitch-js-sdk.git"
        target = os.path.join(self.parent_dir, "stitch-js-sdk")
        command = ["cd", target, "&&", "git", "status"]
        with patch("dsi.common.host.os.path.isdir") as mock_isdir, patch(
                "dsi.common.host.mkdir_p") as mock_mkdir_p, patch(
                    "dsi.common.local_host.LocalHost.exec_command"
                ) as mock_exec_command:
            mock_isdir.return_value = True
            mock_exec_command.return_value = 0
            local.checkout_repos(source, target)
            mock_mkdir_p.assert_not_called()
            mock_exec_command.assert_called_once()
            mock_exec_command.assert_called_with(command)
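
        # Putting the checkout_repos tests together: the base-class helper seems to
        # mkdir_p the parent and run "git clone [--quiet] <source> <target>" (plus an
        # optional "git checkout --quiet <branch>") when the target does not exist,
        # and to fall back to "cd <target> && git status" when it does, raising
        # UserWarning if that status check fails. This summary is inferred from the
        # assertions, not quoted from dsi.common.host.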