def test_cli_connections_import_should_return_error_if_file_does_not_exist(
        self, mock_exists):
    """The import command must exit when pointed at a nonexistent connections file."""
    mock_exists.return_value = False
    missing_file = '/does/not/exist.json'
    with pytest.raises(SystemExit, match=r"Missing connections file."):
        connection_command.connections_import(
            self.parser.parse_args(["connections", "import", missing_file]))
    def test_cli_connections_import_should_load_connections(
            self, mock_exists, mock_load_connections_dict):
        """A successful import should persist every connection from the parsed file."""
        mock_exists.return_value = True

        # Connections the mocked file parser will hand back to the command.
        expected_connections = {
            "new0": {
                "conn_type": "postgres",
                "description": "new0 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "******",
                "port": 5432,
                "schema": "airflow",
            },
            "new1": {
                "conn_type": "mysql",
                "description": "new1 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "******",
                "port": 3306,
                "schema": "airflow",
            },
        }

        # load_connections_dict's own parsing (JSON, YAML or env) is covered elsewhere.
        mock_load_connections_dict.return_value = expected_connections

        connection_command.connections_import(
            self.parser.parse_args(["connections", "import", 'sample.json']))

        # Attributes on which stored connections are compared to the sample payload.
        comparable_attrs = [
            "conn_type",
            "description",
            "host",
            "is_encrypted",
            "is_extra_encrypted",
            "login",
            "port",
            "schema",
        ]

        # Every imported connection should now be queryable with matching attributes.
        with create_session() as session:
            stored = {}
            for conn in session.query(Connection).all():
                stored[conn.conn_id] = {
                    name: getattr(conn, name) for name in comparable_attrs
                }
            assert stored == expected_connections
 def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
         self, filepath, mock_exists):
     """Files without a supported extension raise AirflowException."""
     mock_exists.return_value = True
     expected_error = (
         r"Unsupported file format. The file must have the extension .env or .json or .yaml"
     )
     with pytest.raises(AirflowException, match=expected_error):
         connection_command.connections_import(
             self.parser.parse_args(["connections", "import", filepath]))
    def test_connections_import(self, file_content, expected_connection_uris):
        """Test connections_import command.

        Imports the parametrized ``file_content`` through a mocked local file and
        verifies each expected connection was persisted with matching attributes.
        """
        with mock_local_file(json.dumps(file_content), 'a.json'):
            connection_command.connections_import(self.parser.parse_args(['connections', 'import', "a.json"]))
            with create_session() as session:
                for conn_id in expected_connection_uris:
                    current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
                    expected_attrs = expected_connection_uris[conn_id]
                    actual_attrs = {attr: getattr(current_conn, attr) for attr in expected_attrs}
                    # Plain assert keeps this consistent with the pytest-style tests in this file,
                    # replacing the unittest-style self.assertEqual used previously.
                    assert expected_attrs == actual_attrs
    def test_connections_import_disposition_overwrite(self, file_content, expected_connection_uris):
        """Test connections_import command with --conflict-disposition overwrite.

        Imports two payloads for the same conn_id in turn; with the overwrite
        disposition the second import must replace the first, so the stored URI
        matches the second payload.
        """
        with mock_local_file(json.dumps(file_content[0]), 'a.json'):
            connection_command.connections_import(self.parser.parse_args([
                'connections', 'import', 'a.json', '--conflict-disposition', 'overwrite']))

        with mock_local_file(json.dumps(file_content[1]), 'a.json'):
            connection_command.connections_import(self.parser.parse_args([
                'connections', 'import', 'a.json', '--conflict-disposition', 'overwrite']))

        conn_id = 'CONN_ID3'
        with create_session() as session:
            current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
            # Plain assert keeps this consistent with the pytest-style tests in this file,
            # replacing the unittest-style self.assertEqual used previously.
            assert expected_connection_uris == current_conn.get_uri()
    def test_cli_connections_import_should_not_overwrite_existing_connections(
            self, mock_exists, mock_load_connections_dict, session=None):
        """Without a conflict disposition, importing must skip existing connections."""
        mock_exists.return_value = True

        # Add a pre-existing connection "new1" that collides with the import payload.
        merge_conn(
            Connection(
                conn_id="new1",
                conn_type="mysql",
                description="mysql description",
                host="mysql",
                login="******",
                password="",
                schema="airflow",
            ),
            session=session,
        )

        # Sample connections to import, including a collision with "new1"
        expected_connections = {
            "new0": {
                "conn_type": "postgres",
                "description": "new0 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "******",
                "port": 5432,
                "schema": "airflow",
            },
            "new1": {
                "conn_type": "mysql",
                "description": "new1 description",
                "host": "host",
                "is_encrypted": False,
                "is_extra_encrypted": False,
                "login": "******",
                "port": 3306,
                "schema": "airflow",
            },
        }

        # We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
        mock_load_connections_dict.return_value = expected_connections

        with redirect_stdout(io.StringIO()) as stdout:
            connection_command.connections_import(
                self.parser.parse_args(
                    ["connections", "import", 'sample.json']))

            # The collision must be reported rather than silently overwritten.
            assert 'Could not import connection new1: connection already exists.' in stdout.getvalue()

        # Verify the resulting DB state: the new connection imported, the old one untouched.
        current_conns = session.query(Connection).all()

        comparable_attrs = [
            "conn_type",
            "description",
            "host",
            "is_encrypted",
            "is_extra_encrypted",
            "login",
            "port",
            "schema",
        ]

        current_conns_as_dicts = {
            conn.conn_id: {attr: getattr(conn, attr) for attr in comparable_attrs}
            for conn in current_conns
        }

        # The brand-new connection was imported verbatim.
        assert current_conns_as_dicts['new0'] == expected_connections['new0']

        # BUG FIX: the existing connection's description must remain the value it
        # was seeded with ("mysql description"). The previous assertion checked
        # 'new1 description' — the value an overwrite would produce — which
        # contradicts both the test name and the skip message asserted above.
        assert current_conns_as_dicts['new1']['description'] == 'mysql description'