class PortalRawDataNotifier(object):
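    """Notify the data join masters of downstream data sources whenever the
    data portal commits new hourly raw data files."""
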
    class NotifyCtx(object):
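        """Per-master notify context: caches the downstream data source
        meta and the latest raw data updated datetime per partition."""
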
        def __init__(self, master_addr):
            self._master_addr = master_addr
            channel = make_insecure_channel(master_addr, ChannelType.INTERNAL)
            self._master_cli = dj_grpc.DataJoinMasterServiceStub(channel)
            self._data_source = None
            self._raw_data_ctl = None
            self._raw_data_updated_datetime = {}

        @property
        def data_source(self):
            if self._data_source is None:
                self._data_source = \
                        self._master_cli.GetDataSource(empty_pb2.Empty())
            return self._data_source

        def get_raw_data_updated_datetime(self, partition_id):
            if partition_id not in self._raw_data_updated_datetime:
                ts = self._raw_data_controller.get_raw_data_latest_timestamp(
                    partition_id)
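                # roll the timestamp back one hour so the hour it falls in
                # is re-scanned; _add_raw_data_impl resumes from this
                # datetime plus one hour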
                if ts.seconds > 3600:
                    ts.seconds -= 3600
                else:
                    ts.seconds = 0
                self._raw_data_updated_datetime[partition_id] = \
                        common.convert_timestamp_to_datetime(
                                common.trim_timestamp_by_hourly(ts)
                            )
            return self._raw_data_updated_datetime[partition_id]

        def add_raw_data(self, partition_id, fpaths, timestamps, end_ts):
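            # fpaths[i] was produced at hour timestamps[i]; end_ts is the
            # last scanned datetime, cached so the next round resumes there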
            assert len(fpaths) == len(timestamps), \
                "the number of raw data paths and timestamps should be equal"
            if len(fpaths) > 0:
                self._raw_data_controller.add_raw_data(partition_id, fpaths,
                                                       True, timestamps)
                self._raw_data_updated_datetime[partition_id] = end_ts

        @property
        def data_source_master_addr(self):
            return self._master_addr

        @property
        def _raw_data_controller(self):
            if self._raw_data_ctl is None:
                self._raw_data_ctl = RawDataController(self.data_source,
                                                       self._master_cli)
            return self._raw_data_ctl

    def __init__(self, etcd, portal_name, downstream_data_source_masters):
        self._lock = threading.Lock()
        self._etcd = etcd
        self._portal_name = portal_name
        assert len(downstream_data_source_masters) > 0, \
            "PortalRawDataNotifier should only be launched with at least "\
            "one downstream data source master to notify"
        self._master_notify_ctx = {}
        for addr in downstream_data_source_masters:
            self._master_notify_ctx[addr] = \
                    PortalRawDataNotifier.NotifyCtx(addr)
        self._notify_worker = None
        self._started = False

    def start_notify_worker(self):
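        # idempotent: the routine worker is created on the first call; later
        # calls only wake it up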
        with self._lock:
            if not self._started:
                assert self._notify_worker is None, \
                    "notify worker should be None if not started"
                self._notify_worker = RoutineWorker('portal-raw_data-notifier',
                                                    self._raw_data_notify_fn,
                                                    self._raw_data_notify_cond,
                                                    5)
                self._notify_worker.start_routine()
                self._started = True
            self._notify_worker.wakeup()

    def stop_notify_worker(self):
        notify_worker = None
        with self._lock:
            notify_worker = self._notify_worker
            self._notify_worker = None
        if notify_worker is not None:
            notify_worker.stop_routine()

    def _check_partition_num(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        data_source = notify_ctx.data_source
        ds_partition_num = data_source.data_source_meta.partition_num
        if portal_manifest.output_partition_num % ds_partition_num != 0:
            raise ValueError(
                    "the output partition number({}) of portal should be "\
                    "divisible by the partition number({}) of downstream "\
                    "data source {}".format(
                        portal_manifest.output_partition_num,
                        ds_partition_num,
                        data_source.data_source_meta.name)
                )

    def _add_raw_data_impl(self, notify_ctx, portal_manifest, ds_pid):
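        # resume one hour after the last updated datetime and walk forward
        # hour by hour up to the committed hour, collecting portal output
        # files that map to downstream partition ds_pid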
        dt = notify_ctx.get_raw_data_updated_datetime(ds_pid) + \
                timedelta(hours=1)
        begin_dt = common.convert_timestamp_to_datetime(
            common.trim_timestamp_by_hourly(portal_manifest.begin_timestamp))
        if dt < begin_dt:
            dt = begin_dt
        committed_dt = common.convert_timestamp_to_datetime(
            portal_manifest.committed_timestamp)
        fpaths = []
        timestamps = []
        ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
        while dt <= committed_dt:
            for pt_pid in range(ds_pid, portal_manifest.output_partition_num,
                                ds_ptnum):
                fpath = common.encode_portal_hourly_fpath(
                    portal_manifest.output_data_base_dir, dt, pt_pid)
                if gfile.Exists(fpath):
                    fpaths.append(fpath)
                    timestamps.append(common.convert_datetime_to_timestamp(dt))
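            # stop once enough files are batched or the committed hour is hit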
            if len(fpaths) > 32 or dt == committed_dt:
                break
            dt += timedelta(hours=1)
        notify_ctx.add_raw_data(ds_pid, fpaths, timestamps, dt)
        logging.info("add %d raw data file for partition %d of data "\
                     "source %s. latest updated datetime %s",
                      len(fpaths), ds_pid,
                      notify_ctx.data_source.data_source_meta.name, dt)
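        # True means this downstream partition has caught up to the
        # committed datetime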
        return dt >= committed_dt

    def _notify_one_data_source(self, notify_ctx, portal_manifest):
        assert isinstance(notify_ctx, PortalRawDataNotifier.NotifyCtx)
        try:
            self._check_partition_num(notify_ctx, portal_manifest)
            ds_ptnum = notify_ctx.data_source.data_source_meta.partition_num
            pt_ptnum = portal_manifest.output_partition_num
            add_finished = False
            while not add_finished:
                add_finished = True
                for ds_pid in range(ds_ptnum):
                    if not self._add_raw_data_impl(notify_ctx, portal_manifest,
                                                   ds_pid):
                        add_finished = False
        except Exception as e:  # pylint: disable=broad-except
            logging.error("Failed to notify data source[master-addr: %s] "\
                          "new raw data added, reason %s",
                          notify_ctx.data_source_master_addr, e)

    def _raw_data_notify_fn(self):
        portal_manifest = common.retrieve_portal_manifest(
            self._etcd, self._portal_name)
        for notify_ctx in self._master_notify_ctx.values():
            self._notify_one_data_source(notify_ctx, portal_manifest)

    def _raw_data_notify_cond(self):
        return True


class RawDataPartitioner(object):
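    """Partition raw data items into output partitions by the CityHash32
    of their raw_id.

    Expected usage (options fields are those referenced below):
        partitioner = RawDataPartitioner(options)
        partitioner.partition()
    """
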
    class OutputFileWriter(object):
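        """Rolling writer for one output partition; rotates to a new output
        file once output_item_threshold items have been written."""
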
        def __init__(self, options, partition_id):
            self._options = options
            self._partition_id = partition_id
            self._process_index = 0
            self._writer = None
            self._dumped_item = 0
            self._output_fpaths = []
            self._output_dir = os.path.join(
                    self._options.output_dir,
                    common.partition_repr(self._partition_id)
                )
            if not gfile.Exists(self._output_dir):
                gfile.MakeDirs(self._output_dir)
            assert gfile.IsDirectory(self._output_dir)

        def append_item(self, index, item):
            writer = self._get_output_writer()
            if self._options.output_builder == 'TF_RECORD':
                writer.write(item.tf_record)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                writer.write(item.csv_record)
            self._dumped_item += 1
            if self._dumped_item >= self._options.output_item_threshold:
                self._finish_writer()
                if self._process_index % 16 == 0:
                    logging.info("Output partition %d dump %d files, "\
                                 "last index %d", self._partition_id,
                                 self._process_index, index)

        def finish(self):
            self._finish_writer()

        def get_output_files(self):
            return self._output_fpaths

        def _get_output_writer(self):
            if self._writer is None:
                self._new_writer()
            return self._writer

        def _new_writer(self):
            assert self._writer is None
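            # output file name: <partitioner_rank_id>-<process_index>.rd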
            fname = "{:04}-{:08}.rd".format(
                    self._options.partitioner_rank_id,
                    self._process_index
                )
            fpath = os.path.join(self._output_dir, fname)
            self._output_fpaths.append(fpath)
            if self._options.output_builder == 'TF_RECORD':
                self._writer = tf.io.TFRecordWriter(fpath)
            else:
                assert self._options.output_builder == 'CSV_DICT'
                self._writer = CsvDictWriter(fpath)
            self._dumped_item = 0

        def _finish_writer(self):
            if self._writer is not None:
                self._writer.close()
                self._writer = None
            self._dumped_item = 0
            self._process_index += 1

    def __init__(self, options):
        self._options = options
        self._raw_data_batch_fetcher = RawDataBatchFetcher(options)
        self._fetch_worker = RoutineWorker('raw_data_batch_fetcher',
                                           self._raw_data_batch_fetch_fn,
                                           self._raw_data_batch_fetch_cond, 5)
        self._next_part_index = 0
        self._cond = threading.Condition()
        self._fetch_worker.start_routine()

    def partition(self):
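        """Drain item batches from the fetcher and append every item to the
        writer of its hash partition; dump a finished tag when all input has
        been consumed."""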
        if self._check_finished_tag():
            logging.warning("partition has finished for partitioner with "\
                            "rank id %d", self._options.partitioner_rank_id)
            return
        next_index = 0
        hint_index = 0
        fetch_finished = False
        fetcher = self._raw_data_batch_fetcher
        writers = [RawDataPartitioner.OutputFileWriter(self._options, pid)
                   for pid in range(self._options.output_partition_num)]
        iter_round = 0
        bp_options = self._options.batch_processor_options
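        # evict consumed batches roughly once per 1/8 of the in-flight batch
        # window; floor the threshold at 1 to avoid a zero modulus below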
        signal_round_threshold = max(1, bp_options.max_flying_item /
                                     bp_options.batch_size // 8)
        while not fetch_finished:
            fetch_finished, batch, hint_index = \
                    fetcher.fetch_item_batch_by_index(next_index, hint_index)
            iter_round += 1
            if batch is not None:
                for index, item in enumerate(batch):
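                    # route each item to its output partition by the
                    # CityHash32 of its raw_id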
                    raw_id = item.raw_id
                    partition_id = CityHash32(raw_id) % \
                            self._options.output_partition_num
                    writer = writers[partition_id]
                    writer.append_item(batch.begin_index+index, item)
                next_index = batch.begin_index + len(batch)
                if iter_round % signal_round_threshold == 0:
                    hint_index = self._evict_stale_batch(hint_index,
                                                         next_index-1)
                    logging.info("consumed %d items", next_index-1)
                self._set_next_part_index(next_index)
                self._wakeup_raw_data_fetcher()
            elif not fetch_finished:
                hint_index = self._evict_stale_batch(hint_index,
                                                     next_index-1)
                with self._cond:
                    self._cond.wait(1)
        for partition_id, writer in enumerate(writers):
            writer.finish()
            fpaths = writer.get_output_files()
            logging.info("part %d output %d files by partitioner",
                          partition_id, len(fpaths))
            for fpath in fpaths:
                logging.info("%s", fpath)
            logging.info("-----------------------------------")
        self._dump_finished_tag()
        self._fetch_worker.stop_routine()

    def _evict_stale_batch(self, hint_index, stale_index):
        evict_cnt = self._raw_data_batch_fetcher.evict_staless_item_batch(
                stale_index
            )
        if hint_index <= evict_cnt:
            return 0
        return hint_index-evict_cnt

    def _set_next_part_index(self, next_part_index):
        with self._cond:
            self._next_part_index = next_part_index

    def _get_next_part_index(self):
        with self._cond:
            return self._next_part_index

    def _raw_data_batch_fetch_fn(self):
        next_part_index = self._get_next_part_index()
        fetcher = self._raw_data_batch_fetcher
        for batch in fetcher.make_processor(next_part_index):
            logging.debug("fetched batch begins at %d, len %d; wake up "\
                          "partitioner", batch.begin_index, len(batch))
            self._wakeup_partitioner()

    def _raw_data_batch_fetch_cond(self):
        next_part_index = self._get_next_part_index()
        return self._raw_data_batch_fetcher.need_process(next_part_index)

    def _wakeup_partitioner(self):
        with self._cond:
            self._cond.notify_all()

    def _wakeup_raw_data_fetcher(self):
        self._fetch_worker.wakeup()

    def _dump_finished_tag(self):
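        # an empty _SUCCESS.<rank_id> file marks this partitioner as done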
        finished_tag_fpath = self._get_finished_tag_fpath()
        with gfile.GFile(finished_tag_fpath, 'w') as fh:
            fh.write('')

    def _check_finished_tag(self):
        return gfile.Exists(self._get_finished_tag_fpath())

    def _get_finished_tag_fpath(self):
        return os.path.join(
                self._options.output_dir,
                '_SUCCESS.{:08}'.format(self._options.partitioner_rank_id)
            )