class SalesforceFetcher(object):
    """
    Class that encapsulates all the fetching logic for Salesforce.
    """
    def __init__(self, config_path):
        """
        Bootstrap a fetcher class
        :param config_path: Path to the configuration file to use for this instance
        """
        # Get settings
        with open(config_path, 'r') as f:
            self.settings = yaml.safe_load(f)

        # Configure the logger
        log_level = logging.DEBUG if self.settings['debug'] else logging.WARNING
        log_format = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        logger = logging.getLogger("salesforce-fetcher")
        logger.setLevel(log_level)

        ch = logging.StreamHandler()
        ch.setFormatter(log_format)
        logger.addHandler(ch)

        logger.debug("Logging is set to DEBUG level")
        # let's not output the password
        #logger.debug("Settings: %s" % self.settings)

        self.logger = logger
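        # Two API clients: simple_salesforce for REST queries and report downloads,
        # salesforce_bulk for Bulk API extracts of large tables.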
        self.salesforce = Salesforce(**self.settings['salesforce']['auth'])
        self.salesforce_bulk = SalesforceBulk(
            **self.settings['salesforce']['auth'], API_version='46.0')

        # Make sure output dir is created
        output_directory = self.settings['output_dir']
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)

    def fetch_all(self, fetch_only, airflow_date, fetch_method, days_lookback):
        """
        Fetch any reports or queries, writing them out as files in the output_dir
        """
        queries = self.load_queries()
        for name, query in queries.items():
            if fetch_only and name != fetch_only:
                self.logger.debug(
                    "'--fetch-only %s' specified. Skipping fetch of %s" %
                    (fetch_only, name))
                continue
            #if name == 'contacts' or name == 'opportunity':
            if fetch_method == 'bulk':
                self.fetch_soql_query_bulk(name, query, airflow_date)
            else:
                self.fetch_soql_query(name, query, airflow_date)

        reports = self.settings['salesforce']['reports']
        for name, report_url in reports.items():
            if fetch_only and name != fetch_only:
                self.logger.debug(
                    "'--fetch-only %s' specified. Skipping fetch of %s" %
                    (fetch_only, name))
                continue
            self.fetch_report(name, report_url, airflow_date)

        if fetch_only:
            if fetch_only == 'contact_deletes':
                self.fetch_contact_deletes(days=days_lookback,
                                           airflow_date=airflow_date)
        else:
            self.fetch_contact_deletes(days=days_lookback,
                                       airflow_date=airflow_date)

        self.logger.info("Job Completed")

    def fetch_contact_deletes(self, days=29, airflow_date=None):
        """
        Fetches all deletes from Contact over the last X days
        :param days: Number of days back from now to fetch deletes for
        :param airflow_date: Optional date string used to name the output directory
        :return:
        """
        path = self.create_output_path('contact_deletes',
                                       airflow_date=airflow_date)
        end = datetime.datetime.now(
            pytz.UTC)  # we need to use UTC as salesforce API requires this!
        records = self.salesforce.Contact.deleted(
            end - datetime.timedelta(days=days), end)
        data_list = records['deletedRecords']
        if len(data_list) > 0:
            fieldnames = list(data_list[0].keys())
            with open(path, 'w') as f:
                writer = DictWriter(f,
                                    fieldnames=fieldnames,
                                    quoting=QUOTE_ALL)
                writer.writeheader()
                for delta_record in data_list:
                    writer.writerow(delta_record)

    def fetch_report(self, name, report_url, airflow_date=None):
        """
        Fetches a single prebuilt Salesforce report via an HTTP request
        :param name: Name of the report to fetch
        :param report_url: Base URL path for the report
        :param airflow_date: Optional date string used to name the output directory
        :return:
        """

        self.logger.info("Fetching report - %s" % name)
        sf_host = self.settings['salesforce']['host']
        url = "%s%s?view=d&snip&export=1&enc=UTF-8&xf=csv" % (sf_host,
                                                              report_url)

        resp = requests.get(url,
                            headers=self.salesforce.headers,
                            cookies={'sid': self.salesforce.session_id},
                            stream=True)

        path = self.create_output_path(name, airflow_date=airflow_date)
        with open(path, 'w+') as f:
            # Write the full contents
            f.write(resp.text.replace("\"", ""))

            # Remove the Salesforce footer (last 7 lines)
            f.seek(0, os.SEEK_END)
            pos = f.tell() - 1

            count = 0
            while pos > 0 and count < 7:
                pos -= 1
                f.seek(pos, os.SEEK_SET)
                if f.read(1) == "\n":
                    count += 1

            # If we didn't reach the start of the file, delete everything after this position
            if pos > 0:
                # keep the newline at the cut point, then truncate the rest
                pos += 1
                f.seek(pos, os.SEEK_SET)
                f.truncate()

    def fetch_soql_query_bulk(self, name, query, airflow_date=None):
        self.logger.info("BULK Executing %s" % name)
        self.logger.info("BULK Query is: %s" % query)
        if name == 'contacts' or name == 'contact_updates':
            table_name = 'Contact'
        elif name == 'opportunity' or name == 'opportunity_updates':
            table_name = 'Opportunity'
        else:
            raise ValueError("No bulk table mapping for query '%s'" % name)
        job = self.salesforce_bulk.create_query_job(table_name,
                                                    contentType='CSV',
                                                    pk_chunking=True,
                                                    concurrency='Parallel')
        self.logger.info("job: %s" % job)
        batch = self.salesforce_bulk.query(job, query)
        #        job = '7504O00000LUxuCQAT'
        #        batch = '7514O00000TvapeQAB'
        self.logger.info("Bulk batch created: %s" % batch)

        while True:
            batch_state = self.salesforce_bulk.batch_state(
                batch, job_id=job, reload=True).lower()
            if batch_state == 'notprocessed':
                self.logger.info("master batch is done")
                break
            elif batch_state == 'aborted' or batch_state == 'failed':
                self.logger.error("master batch failed")
                self.logger.error(
                    self.salesforce_bulk.batch_status(batch_id=batch,
                                                      job_id=job,
                                                      reload=True))
                raise Exception("master batch failed")
            self.logger.info("waiting for batch to be done. status=%s" %
                             batch_state)
            time.sleep(10)

        downloaded = {}

        pool = mp.Pool(5)
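        # Completed result sets are downloaded in parallel; get_and_write_bulk_results
        # (a module-level helper defined elsewhere) is expected to stream each result
        # id to its own output file.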

        while True:
            stats = {}
            batch_count = 0
            all_batches = self.salesforce_bulk.get_batch_list(job)
            for batch_info in all_batches:
                batch_count += 1

                batch_state = batch_info['state'].lower()
                if batch_state in stats:
                    stats[batch_state] += 1
                else:
                    stats[batch_state] = 1

                if batch_info['id'] == batch:
                    #self.logger.debug("skipping the master batch id")
                    continue
                elif batch_info['id'] in downloaded:
                    #self.logger.debug("batch %s already downloaded" % batch_info['id'])
                    continue

                if batch_state == 'completed':
                    self.logger.debug(
                        "batch %s (%s of %s)" %
                        (batch_info['id'], batch_count, len(all_batches)))

                    for result_id in self.salesforce_bulk.get_query_batch_result_ids(
                            batch_info['id'], job_id=job):
                        self.logger.debug("result_id: %s" % result_id)
                        path = self.create_output_path(
                            name, result_id, airflow_date=airflow_date)
                        pool.apply_async(
                            get_and_write_bulk_results,
                            args=(batch_info['id'], result_id, job,
                                  self.salesforce_bulk.endpoint,
                                  self.salesforce_bulk.headers(), path))

                    downloaded[batch_info['id']] = 1

                elif batch_state == 'failed':
                    downloaded[batch_info['id']] = 1
                    self.logger.error("batch %s failed!" % batch_info['id'])
                    self.logger.error(
                        self.salesforce_bulk.batch_status(
                            batch_id=batch_info['id'], job_id=job,
                            reload=True))

            # batch_count includes the master batch, which ends up 'notprocessed'
            # rather than 'completed' or 'failed' -- hence the "+ 1" below.
            if 'completed' in stats and stats['completed'] + 1 == batch_count:
                self.logger.info("all batches retrieved")
                break
            elif 'failed' in stats and stats['failed'] + 1 == batch_count:
                self.logger.error("NO batches retrieved")
                self.logger.error(
                    self.salesforce_bulk.batch_status(batch_id=batch,
                                                      job_id=job,
                                                      reload=True))
                raise Exception("NO batches retrieved")
            elif 'failed' in stats and stats['failed'] + stats.get(
                    'completed', 0) + 1 == batch_count:
                self.logger.warning("all batches retrieved WITH SOME FAILURES")
                break
            else:
                self.logger.info(stats)
                time.sleep(5)

        try:
            self.salesforce_bulk.close_job(job)
        except Exception:
            self.logger.warning("could not close bulk job %s" % job)
        pool.close()
        pool.join()

    def fetch_soql_query(self, name, query, airflow_date=None):
        self.logger.info("Executing %s" % name)
        self.logger.info("Query is: %s" % query)
        path = self.create_output_path(name, airflow_date=airflow_date)
        result = self.salesforce.query(query)
        self.logger.info("First result set received")
        batch = 0
        count = 0
        if result['records']:
            fieldnames = list(result['records'][0].keys())
            if 'attributes' in fieldnames:
                fieldnames.remove('attributes')  # drop the record metadata key
            with open(path, 'w') as f:
                writer = DictWriter(f,
                                    fieldnames=fieldnames,
                                    quoting=QUOTE_ALL)
                writer.writeheader()
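                # The REST API returns results in pages (2,000 records by default);
                # the loop below follows nextRecordsUrl via query_more() until 'done'.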

                while True:
                    batch += 1
                    for row in result['records']:
                        # each row has a strange attributes key we don't want
                        row.pop('attributes', None)
                        out_dict = {}
                        for key, value in row.items():
                            if isinstance(value, dict):
                                out_dict[key] = json.dumps(value)
                            else:
                                out_dict[key] = value
                        writer.writerow(out_dict)
                        count += 1
                        if count % 100000 == 0:
                            self.logger.debug("%s rows fetched" % count)

                    # fetch next batch if we're not done else break out of loop
                    if not result['done']:
                        result = self.salesforce.query_more(
                            result['nextRecordsUrl'], True)
                    else:
                        break

        else:
            self.logger.warning("No results returned for %s" % name)

    def create_output_path(self, name, filename='output', airflow_date=None):
        output_dir = self.settings['output_dir']
        if airflow_date:
            date = airflow_date
        else:
            date = time.strftime("%Y-%m-%d")
        child_dir = os.path.join(output_dir, name, date)
        if not os.path.exists(child_dir):
            os.makedirs(child_dir)

        filename = filename + ".csv"
        file_path = os.path.join(child_dir, filename)
        self.logger.info("Writing output to %s" % file_path)
        return file_path

    def create_custom_query(self,
                            table_name='Contact',
                            dir='/usr/local/salesforce_fetcher/queries',
                            updates_only=False):
        """
        The intention is to have Travis upload the "contact_fields.yaml" file
        to a bucket where it can be pulled down dynamically by this script
        and others (instead of having to rebuild the image on each change)
        """

        fields_file_name = table_name.lower() + '_fields.yaml'
        fields_file = os.path.join(dir, fields_file_name)
        if not os.path.exists(fields_file):
            return
        with open(fields_file, 'r') as stream:
            columns = yaml.safe_load(stream)

        query = "SELECT "
        for field in columns['fields']:
            query += next(iter(field)) + ', '

        query = query[:-2] + " FROM " + table_name
        if updates_only:
            query += " WHERE LastModifiedDate >= LAST_N_DAYS:3"

        return query

    def load_queries(self):
        """
        load queries from an external directory
        :return: a dict containing all the SOQL queries to be executed
        """
        queries = {}

        query_dir = self.settings['salesforce']['query_dir']
        for file in os.listdir(query_dir):
            if file.endswith(".soql"):
                name, ext = os.path.splitext(file)
                query_file = os.path.join(query_dir, file)
                with open(query_file, 'r') as f:
                    queries[name] = f.read().strip().replace('\n', ' ')

        # explicitly add the non-file queries
        queries['contacts'] = self.create_custom_query(table_name='Contact',
                                                       dir=query_dir)
        queries['contact_updates'] = self.create_custom_query(
            table_name='Contact', dir=query_dir, updates_only=True)
        queries['opportunity'] = self.create_custom_query(
            table_name='Opportunity', dir=query_dir)
        queries['opportunity_updates'] = self.create_custom_query(
            table_name='Opportunity', dir=query_dir, updates_only=True)

        return queries
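

# Illustrative usage of SalesforceFetcher (the config path and dates below are
# placeholders, not values from this repo):
#
#     fetcher = SalesforceFetcher("/path/to/config.yaml")
#     fetcher.fetch_all(fetch_only=None, airflow_date="2024-01-01",
#                       fetch_method=None, days_lookback=29)
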
class SFInteraction(object):
    """Simple class that interacts with Salesforce"""
    def __init__(self,
                 username,
                 password,
                 token,
                 uat,
                 project_id='datacoco3.db',
                 session_id=None,
                 instance=None,
                 version=None):
        """Instantiate a Salesforce interaction manager.

        UAT mode is explicitly set to a boolean value in case a string is provided.

        If Salesforce session credentials do not exist, attempt to retrieve.

        :param username: Username
        :param password: Password
        :param token: API token
        :param uat: Whether or not in UAT mode
        :param project_id: to identify project source api calls
        :param session_id:  Access token for existing session
        :param instance: Domain of Salesforce instance
        :param version: Salesforce API version to use
        """
        if not username or not password or not token or uat is None:
            raise RuntimeError("%s requires all __init__ arguments" % __name__)

        self.username = username
        self.password = password
        self.token = token
        self.session_id = session_id
        self.instance = instance
        self.project_id = project_id
        self.version = version

        self.conn = None
        self.bulk = None
        self.job_id = None
        self.batch_max_attempts = None
        self.batch_timeout = None
        self.batch_sleep_interval = None
        self.temp_file = "sf_temp_results.txt"

        self.redis_conn = None
        self.session_credentials_key = "sf_session_credentials"

        # Normalize uat to a boolean: "true"/"t" strings become True, any other
        # string becomes False, and non-string values are coerced with bool()
        if isinstance(uat, str):
            self.uat = uat.lower() in ("true", "t")
        else:
            self.uat = bool(uat)

        # Retrieve session_id and/or instance if they do not exist
        if not self.session_id or not self.instance:
            self._get_session_credentials()

    def connect(self):
        """Connect to the Salesforce API client.

        Only executes if there is not an existing open Salesforce connection.

        If there are a session_id and an instance, attempt to connect to
        the existing session. The existing session connection is verified with a Salesforce API describe call.

        If that fails, create a new connection.

        Up to 3 connection attempts are made before falling back to a new connection.
        """
        if self.session_id and self.instance:
            retry_count = 1
            while True:
                if retry_count > 3:
                    LOG.l(
                        "Could not connect to Salesforce in the specified number of retries."
                    )
                    LOG.l("Starting a new connection...")
                    break
                else:
                    LOG.l(
                        f"Connecting to Salesforce: attempt {retry_count} of 3..."
                    )

                try:
                    self.conn = Salesforce(session_id=self.session_id,
                                           instance=self.instance,
                                           client_id=self.project_id,
                                           version=self.version)
                    self.conn.describe()  # Connection health check
                    return  # Success; leave this function
                except SalesforceError as sfe:
                    LOG.l(
                        f"Encountered error connecting to Salesforce:\n{sfe}")
                    retry_count += 1
                    sleep(5)
                    continue

        # If reconnecting didn't work or session_id is not set, start a new connection
        try:
            self._create_new_connection()
        except Exception as e:
            raise Exception("Could not initiate connection to Salesforce!") from e

    def fetch_soql(
        self,
        db_table,
        soql,
        batch=True,
        batch_timeout=600,
        batch_sleep_int=10,
        batch_max_attempts=1,
    ):
        """Fetch results from Salesforce soql queries.

        Bulk (batch) query results are written to a temporary file and read back
        from it, which avoids bulk query connection timeouts on large result sets.

        :param db_table: Database table name
        :param soql: Soql queries
        :param batch: Whether to use Salesforce Batch or Simple API
        :param batch_sleep_int: Salesforce Bulk query sleep interval
        :param batch_timeout: Batch job timeout in seconds
        :param batch_max_attempts: Maximum number of batch query creation attempts
        :return: If success, List of result dictionaries; Else empty list
        """
        try:
            if batch:
                # Set batch operation attributes
                self.batch_timeout = batch_timeout
                self.batch_sleep_interval = batch_sleep_int
                self.batch_max_attempts = batch_max_attempts

                results = self.get_query_records_dict(db_table, soql)

                # save to and read from file to avoid connection timeout
                self._save_results_to_file(results)
                records = self._get_results_from_file()
            else:
                result = self.conn.query(soql)

                # if there isn't a result return an empty list
                if result["records"]:
                    salesforce_records = json.loads(
                        json.dumps(result["records"][0]))
                    parsed_records = parse_sf_records(salesforce_records)
                    records = [parsed_records
                               ]  # put result in a list object for consistency
                else:
                    records = []
        except BulkApiError as e:
            self.bulk.abort_job(self.job_id)
            # TODO Handle failed bulk API transaction better
            raise e

        return records

    def get(self, object_name: str, object_id: str):
        """
        To get a dictionary with all the information regarding that record
        """
        return self.conn.__getattr__(object_name).get(object_id)

    def get_by_custom_id(self, object_name: str, field: str, id: str):
        """
        To get a dictionary with all the information regarding that record
        using a **custom** field that was defined as External ID:
        """
        return self.conn.__getattr__(object_name).get_by_custom_id(field, id)

    def upsert(self, object_name: str, field: str, id: str, data: dict):
        """
        To insert or update (upsert) a record using an external ID
        """
        return self.conn.__getattr__(object_name).upsert(f'{field}/{id}', data)

    def get_query_records_dict(self, db_table, soql_query):
        """Execute bulk Salesforce soql queries and return results as generator of dictionaries.

        :param db_table: Database table name
        :param soql_query: Soql queries
        :return: If success, List of result record dictionaries; Else empty list
        """
        self.bulk = SalesforceBulk(sessionId=self.session_id,
                                   host=self.instance)
        job = self.bulk.create_query_job(db_table, contentType="JSON")
        self.job_id = job  # recorded so a failed job can be aborted in fetch_soql
        batch = self.bulk.query(job, soql_query)
        self.bulk.close_job(job)
        while not self.bulk.is_batch_done(batch):
            print("Waiting for batch query to complete")
            sleep(10)

        dict_records = []
        rec_count = 0
        print("Iterating through batch result set")
        for result in self.bulk.get_all_results_for_query_batch(batch):
            result = json.load(IteratorBytesIO(result))
            for row in result:
                rec_count += 1
                dict_records.append(row)
            print("Current fetched record count: ", rec_count)

        return dict_records

    def batch_query_records_dict(self,
                                 db_table,
                                 soql_query,
                                 concurrency='Serial'):
        """Execute bulk Salesforce soql queries and return results as generator of dictionaries.

        works only for PK CHUNKING enabled SF tables.

        Allows millions of record read.

        :param db_table: Database table name
        :param soql_query: Soql queries
        :return: If success, List of result record dictionaries; Else empty list
        """
        self.bulk = SalesforceBulk(sessionId=self.session_id,
                                   host=self.instance)
        job = self.bulk.create_query_job(db_table,
                                         contentType="JSON",
                                         pk_chunking=True,
                                         concurrency=concurrency)
        try:
            batch = self.bulk.query(job, soql_query)
            batch_list = self.bulk.get_batch_list(job)
            print('first batch', batch_list[0])
            batch_id = batch_list[0]['id']
            job_id = batch_list[0]['jobId']
            state = batch_list[0]['state']
            while state in ('Queued', 'InProgress'):
                print("Waiting for master batch to leave state " + state)
                sleep(10)
                state = self.bulk.batch_state(batch_id, job_id)

            batch_list = self.bulk.get_batch_list(job)
            print(f'number of batches: {len(batch_list)}')
            for item in batch_list:
                print('item', item)
                batch_id = item['id']
                job_id = item['jobId']
                state = item['state']

                if state == 'NotProcessed':
                    continue

                while not self.bulk.is_batch_done(batch_id, job_id):
                    print(
                        f"Waiting for batch query to complete batch_id:{batch_id}, job_id: {job_id}, state: {state}"
                    )
                    sleep(10)
                    state = self.bulk.batch_state(batch_id, job_id)

                total_retry_count = len(batch_list)
                retry = len(batch_list)
                last_index = 0
                while retry > 0:
                    print(f'retry {retry} times left')
                    try:
                        # Resume from the last successfully yielded result set
                        for result in list(
                                self.bulk.get_all_results_for_query_batch(
                                    batch_id, job_id))[last_index:]:
                            result = json.load(IteratorBytesIO(result))
                            last_index += 1
                            yield result
                        break
                    except requests.exceptions.ChunkedEncodingError:
                        print('Chunking failed; reconnecting and retrying')
                        retry -= 1
                        self.connect()
                        self.bulk = SalesforceBulk(sessionId=self.session_id,
                                                   host=self.instance)
                    except Exception:
                        print('There was an error; reconnecting and retrying')
                        traceback.print_exc()
                        retry -= 1
                        self.connect()
                        self.bulk = SalesforceBulk(sessionId=self.session_id,
                                                   host=self.instance)
                if retry <= 0:
                    raise Exception(
                        f'Retried {total_retry_count} times and it still failed'
                    )
        except BulkApiError as e:
            self.bulk.abort_job(job)
            raise e

    def upload_records_to_s3(self, records, s3_bucket, s3_key, aws_access_key,
                             aws_secret_key):
        """Upload records to s3.

        :param records: Records filename
        """
        self._save_results_to_file(records)
        datetime_today = datetime.today().strftime("%Y-%m-%d-%X")

        s3_dest_key = s3_key + datetime_today

        s3_interaction = S3Interaction(aws_access_key, aws_secret_key)
        s3_interaction.put_file_to_s3(s3_bucket, s3_dest_key, self.temp_file)

        return s3_dest_key

    def get_description(self, object_name):
        """Retrieves object description

        :param object_name: Salesforce object/table name
        """
        retry = True
        while retry:
            try:
                return self.conn.__getattr__(object_name).describe()
            except SalesforceError as sfe:
                retry = self._sf_except_reconnect(sfe)

    def _sf_except_reconnect(self, e):
        """ Used in try/catch blocks to reinit the connection

        returns true if the code should be retried, false if no connection could be made
        """
        LOG.l(f"Encountered error:\n{e}")
        try:
            self.connect()
            return True
        except Exception:
            return False

    def _create_new_connection(self):
        """Create a new Salesforce API client connection.

        After the connection is created, the Salesforce session credentials are stored externally.
        """
        self.conn = Salesforce(username=self.username,
                               password=self.password,
                               security_token=self.token,
                               sandbox=self.uat,
                               client_id=self.project_id)
        self.session_id = str(self.conn.session_id)
        self.instance = str(self.conn.sf_instance)
        self._set_session_credentials()

    def _save_results_to_file(self, records):
        """Save Salesforce Bulk API results to a temp file.

        :param records: Records to save
        """
        with open(self.temp_file, "w") as f:
            for r in records:
                f.write("\n")
                f.write(str(str(r).encode("utf-8")))

    def _get_results_from_file(self):
        """Get Salesforce Bulk API results from a temp file.

        The records must be parsed. After the results are retrieved. The file is deleted.

        :return: Iterator with records.
        """
        results = []
        with open(self.temp_file, "r") as f:
            records = f.read()[1:].splitlines()
            for r in records:
                r = ast.literal_eval(r)
                results.append(r)
        os.remove(self.temp_file)
        return results

    def _get_session_credentials(self):
        """Get Salesforce session credentials stored in Redis.

        If the credentials variables do not exist, set the credentials as None.
        """
        # Establish connection to Redis
        self._connect_to_redis()
        # Get Salesforce session credentials if they exist
        if self.redis_conn.conn.exists(self.session_credentials_key):
            self.session_id = self.redis_conn.fetch_by_key_name(
                self.session_credentials_key, "session_id")
            self.instance = self.redis_conn.fetch_by_key_name(
                self.session_credentials_key, "instance")
        else:
            self.session_id = None
            self.instance = None

    def _set_session_credentials(self):
        """Set Salesforce session credentials in Redis.

        """
        sf_session_credentials = {
            "session_id": self.session_id,
            "instance": self.instance,
        }
        self.redis_conn.set_key(self.session_credentials_key,
                                sf_session_credentials)

    def _connect_to_redis(self):
        """Connect to Redis.

        """
        CONF = config()
        host = CONF["redis"]["server"]
        port = CONF["redis"]["port"]
        db = CONF["redis"]["db"]
        self.redis_conn = RedisInteraction(host, port, db)
        self.redis_conn.connect()