Example #1
class ModelClassLoader:
    def __init__(self):
        self.logger = Logger(self.__class__.__name__, stdout_only=True)
        self.log = self.logger.get_log()

    @staticmethod
    def load_model_class(model_name):
        module_path = module_path_finder(MODEL_MODULE_PATH, model_name)
        model = import_class_from_module_path(module_path, model_name)
        return model
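
A minimal usage sketch for the loader above; the model name and the constructor arguments are illustrative assumptions, and the project helpers (MODEL_MODULE_PATH, module_path_finder, import_class_from_module_path) are assumed to be importable.

# sketch only: 'ExampleGAN' is a hypothetical model name
ModelClass = ModelClassLoader.load_model_class('ExampleGAN')
model = ModelClass(metadata, input_shapes)  # hypothetical metadata and input_shapes dicts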
Example #2
class AbstractPrintLog(AbstractVisualizer):
    def __init__(self, path=None, iter_cycle=None, name=None):
        super().__init__(path, iter_cycle, name)
        self.logger = Logger(self.__class__.__name__, self.visualizer_path)
        self.log = self.logger.get_log()

    def __del__(self):
        super().__del__()
        del self.log
        del self.logger

    def task(self, sess=None, iter_num=None, model=None, dataset=None):
        raise NotImplementedError
Example #3
class BaseDatasetPack:
    def __init__(self):
        self.logger = Logger(self.__class__.__name__)
        self.log = self.logger.get_log()

        self.set = {}

    def load(self, path, **kwargs):
        for k in self.set:
            self.set[k].load(path, **kwargs)

    def shuffle(self):
        for key in self.set:
            self.set[key].shuffle()

    def split(self, from_key, a_key, b_key, rate):
        from_set = self.set[from_key]
        self.set.pop(from_key)

        a_set, b_set = from_set.split(rate)
        self.set[a_key] = a_set
        self.set[b_key] = b_set
        return a_set, b_set

    def merge_shuffle(self, a_key, b_key, rate):
        a_set = self.set[a_key]
        b_set = self.set[b_key]

        merge_set = a_set.merge(a_set, b_set)
        merge_set.shuffle()
        a_set, b_set = merge_set.split(rate, shuffle=True)

        self.set[a_key] = a_set
        self.set[b_key] = b_set
        return a_set, b_set

    def merge(self, a_key, b_key, merge_set_key):
        a_set = self.set[a_key]
        b_set = self.set[b_key]
        self.set.pop(a_key)
        self.set.pop(b_key)

        merge_set = a_set.merge(a_set, b_set)
        self.set[merge_set_key] = merge_set

    def sort(self, sort_key=None):
        for key in self.set:
            self.set[key].sort(sort_key)
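
A hedged sketch of driving the pack's split and merge helpers; MyDatasetPack and the 'train'/'holdout' keys are hypothetical.

# sketch only: MyDatasetPack is a hypothetical BaseDatasetPack subclass that fills self.set
pack = MyDatasetPack()
pack.load('./data/my_dataset')                                               # load every registered dataset
pack.shuffle()                                                               # shuffle each dataset in place
train_set, holdout_set = pack.split('train', 'train', 'holdout', rate=0.9)   # 90/10 split of the 'train' entry
pack.merge_shuffle('train', 'holdout', rate=0.9)                             # merge, shuffle, and re-split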
class AbstractDataset(metaclass=MetaTask):
    def __init__(self,
                 preprocess=None,
                 batch_after_task=None,
                 before_load_task=None):
        """
        init dataset attrs

        *** the attributes below must be initialized with other values ***
        self._SOURCE_URL: (str) url for downloading the dataset
        self._SOURCE_FILE: (str) file name of the zipped dataset
        self._data_files: (str) file names in the dataset
        self.batch_keys: (str) feature labels of the dataset,
            managing batch keys in dict_keys.dataset_batch_keys is recommended

        :param preprocess: injected function to preprocess the dataset
        :param batch_after_task: injected function executed after each mini batch iteration
        :param before_load_task: hookable function for AbstractDataset.before_load
        """
        self._SOURCE_URL = None
        self._SOURCE_FILE = None
        self._data_files = None
        self.batch_keys = None
        self.logger = Logger(self.__class__.__name__, stdout_only=True)
        self.log = self.logger.get_log()
        self.preprocess = preprocess
        self.batch_after_task = batch_after_task
        self.data = {}
        self.cursor = {}
        self.data_size = 0
        self.before_load_task = before_load_task

    def __del__(self):
        del self.data
        del self.cursor
        del self.logger
        del self.log
        del self.batch_after_task
        del self.batch_keys

    def __repr__(self):
        return self.__class__.__name__

    def before_load(self, path):
        """
        check that the dataset is valid; if not, download it

        :param path: dataset path
        :return:
        """
        try:
            os.makedirs(path)
        except FileExistsError:
            pass

        is_Invalid = False
        files = glob(os.path.join(path, '*'))
        names = list(map(lambda file: os.path.split(file)[1], files))
        for data_file in self._data_files:
            if data_file not in names:
                is_Invalid = True

        if is_Invalid:
            head, _ = os.path.split(path)
            download_file = os.path.join(head, self._SOURCE_FILE)

            self.log('download %s at %s ' % (self._SOURCE_FILE, download_file))
            download_data(self._SOURCE_URL, download_file)

            self.log("extract %s at %s" % (self._SOURCE_FILE, head))
            extract_data(download_file, head)

    def after_load(self, limit=None):
        """
        after-load tasks for the dataset, including preprocessing

        init cursor for each batch_key
        limit dataset size
        execute preprocess

        :param limit: limit size of dataset
        :return:
        """
        for key in self.batch_keys:
            self.cursor[key] = 0

        for key in self.batch_keys:
            self.data[key] = self.data[key][:limit]

        for key in self.batch_keys:
            self.data_size = max(len(self.data[key]), self.data_size)
            self.log("batch data '%s' %d item(s) loaded" %
                     (key, len(self.data[key])))
        self.log('%s fully loaded' % self.__str__())

        if self.preprocess is not None:
            self.preprocess(self)
            self.log('%s preprocess end' % self.__str__())

    def load(self, path, limit=None):
        pass

    def save(self):
        raise NotImplementedError

    def _append_data(self, batch_key, data):
        if batch_key not in self.data:
            self.data[batch_key] = data
        else:
            self.data[batch_key] = np.concatenate((self.data[batch_key], data))

    def __next_batch(self, batch_size, key, lookup=False):
        data = self.data[key]
        cursor = self.cursor[key]
        data_size = len(data)

        # if batch size exceeds the size of data set
        over_data = batch_size // data_size
        if over_data > 0:
            whole_data = np.concatenate((data[cursor:], data[:cursor]))
            batch_to_append = np.repeat(whole_data, over_data, axis=0)
            batch_size -= data_size * over_data
        else:
            batch_to_append = None

        begin, end = cursor, (cursor + batch_size) % data_size

        if begin < end:
            batch = data[begin:end]
        else:
            first, second = data[begin:], data[:end]
            batch = np.concatenate((first, second))

        if batch_to_append is not None:
            batch = np.concatenate((batch_to_append, batch))

        if not lookup:
            self.cursor[key] = end

        return batch

    def next_batch(self, batch_size, batch_keys=None, lookup=False):
        """
        return iter mini batch

        :param batch_size: size of mini batch
        :param batch_keys: (iterable type) selected keys,
            if batch_keys has length 1, a single mini batch is returned,
            otherwise a list of mini batches is returned
        :param lookup: if lookup is True, the cursor will not be updated
        :return: (numpy array type) list of mini batch, order is same with batch_keys



        ex)
        dataset.next_batch(3, ["train_x", "train_label"]) =
            [[train_x1, train_x2, train_x3], [train_label1, train_label2, train_label3]]

        dataset.next_batch(3, ["train_x", "train_label"], lookup=True) =
            [[train_x4, train_x5, train_x6], [train_label4, train_label5, train_label6]]

        dataset.next_batch(3, ["train_x", "train_label"]) =
            [[train_x4, train_x5, train_x6], [train_label4, train_label5, train_label6]]

        """
        if batch_keys is None:
            batch_keys = self.batch_keys

        batches = []
        for key in batch_keys:
            batches += [self.__next_batch(batch_size, key, lookup)]

        if self.batch_after_task is not None:
            batches = self.batch_after_task(batches)

        if len(batches) == 1:
            batches = batches[0]

        return batches
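
The circular-cursor behaviour documented in next_batch can be shown with a small standalone snippet; this toy version mirrors the wraparound and lookup semantics of __next_batch for a single key, and omits the repeat handling used when batch_size exceeds the dataset size.

import numpy as np

data = np.arange(5)   # toy dataset of size 5
cursor = 0

def toy_next_batch(batch_size, lookup=False):
    # same wraparound idea as AbstractDataset.__next_batch, for one key
    global cursor
    end = (cursor + batch_size) % len(data)
    if cursor < end:
        batch = data[cursor:end]
    else:
        batch = np.concatenate((data[cursor:], data[:end]))
    if not lookup:
        cursor = end
    return batch

print(toy_next_batch(3))               # [0 1 2], cursor advances to 3
print(toy_next_batch(3, lookup=True))  # [3 4 0], cursor stays at 3
print(toy_next_batch(3))               # [3 4 0], cursor advances to 1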
Example #5
class DatasetLoader:
    """
    Todo
    """
    def __init__(self, root_path=ROOT_PATH):
        """create DatasetManager
        todo
        """
        self.root_path = root_path
        self.logger = Logger(self.__class__.__name__, self.root_path)
        self.log = self.logger.get_log()
        self.datasets = {}

    def __repr__(self):
        return self.__class__.__name__

    def load_dataset(self, dataset_name, limit=None):
        """load dataset, return dataset, input_shapes

        :type dataset_name: str
        :type limit: int
        :param dataset_name: dataset name to load
        :param limit: limit dataset_size

        :return: dataset, input_shapes

        :raise KeyError
        if dataset_name is invalid
        """
        try:
            if dataset_name not in self.datasets:
                self.import_dataset_and_helper(dataset_name=dataset_name)
            data_loader, data_helper = self.datasets[dataset_name]
            dataset, input_shapes = data_helper.load_dataset(limit=limit)
        except KeyError:
            raise KeyError("dataset_name %s not found" % dataset_name)

        return dataset, input_shapes

    def import_dataset_and_helper(self, dataset_name):
        """ import dataset_and_helper

        :type dataset_name: str
        :param dataset_name:
        """
        self.log('load %s dataset module' % dataset_name)
        paths = glob(os.path.join(DATA_HANDLER_PATH, '**', '*.py'),
                     recursive=True)

        dataset_path = None
        for path in paths:
            _, file_name = os.path.split(path)
            dataset_name_ = file_name.replace('.py', '')
            if dataset_name != dataset_name_:
                continue
            dataset_path = path

        if dataset_path is None:
            raise ModuleNotFoundError("dataset %s not found" % dataset_name)

        module_ = import_module_from_module_path(dataset_path)
        dataset = None
        helper = None
        for key in module_.__dict__:
            value = module_.__dict__[key]
            try:
                if issubclass(value, AbstractDataset):
                    dataset = value
                if issubclass(value, AbstractDatasetHelper):
                    helper = value
            except TypeError:
                pass

        if dataset is None:
            raise ModuleNotFoundError("dataset class %s not found" %
                                      dataset_name)
        if helper is None:
            raise ModuleNotFoundError("dataset helper class %s not found" %
                                      dataset_name)

        self.datasets[dataset_name] = (dataset, helper)
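
A minimal sketch of the intended call flow for DatasetLoader; the dataset name 'MNIST' is only an illustrative assumption about what lives under DATA_HANDLER_PATH.

# sketch only: assumes a dataset module named 'MNIST' exists under DATA_HANDLER_PATH
loader = DatasetLoader()
dataset, input_shapes = loader.load_dataset('MNIST', limit=10000)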
Example #6
class InstanceManager:
    def __init__(self, root_path):
        self.root_path = root_path
        self.logger = Logger(self.__class__.__name__, self.root_path)
        self.log = self.logger.get_log()
        self.model = None
        self.visualizers = []
        self.sub_process = {}

    def __del__(self):
        # reset tensorflow graph
        tf.reset_default_graph()

        for process_name in self.sub_process:
            if self.sub_process[process_name].poll() is None:
                self.close_subprocess(process_name)

        del self.root_path
        del self.log
        del self.logger
        del self.model
        del self.visualizers

    def gen_instance(self, model=None, input_shapes=None):
        # gen instance id
        model_name = "%s_%s_%.1f" % (model.AUTHOR, model.__name__,
                                     model.VERSION)
        instance_id = model_name + '_' + strftime("%Y-%m-%d_%H-%M-%S",
                                                  localtime())
        self.log('gen instance id : %s' % instance_id)

        # init instance directory
        instance_path = os.path.join(self.root_path, INSTANCE_FOLDER)
        if not os.path.exists(instance_path):
            os.mkdir(instance_path)

        # init user instance directory
        instance_path = os.path.join(self.root_path, INSTANCE_FOLDER,
                                     instance_id)
        if not os.path.exists(instance_path):
            os.mkdir(instance_path)

        instance_visual_result_folder_path = os.path.join(
            instance_path, VISUAL_RESULT_FOLDER)
        if not os.path.exists(instance_visual_result_folder_path):
            os.mkdir(instance_visual_result_folder_path)

        instance_source_folder_path = os.path.join(instance_path, 'src_code')
        if not os.path.exists(instance_source_folder_path):
            os.mkdir(instance_source_folder_path)

        instance_summary_folder_path = os.path.join(instance_path, 'summary')
        if not os.path.exists(instance_summary_folder_path):
            os.mkdir(instance_summary_folder_path)
        self.log('init instance directory')

        # copy model's module file to instance/src/"model_id.py"
        instance_source_path = os.path.join(instance_source_folder_path,
                                            instance_id + '.py')
        try:
            copy(inspect.getsourcefile(model), instance_source_path)
        except IOError as e:
            print(e)
        self.log('dump model source code')

        # init and dump metadata
        metadata = {
            MODEL_METADATA_KEY_INSTANCE_ID:
            instance_id,
            MODEL_METADATA_KEY_INSTANCE_PATH:
            instance_path,
            MODEL_METADATA_KEY_INSTANCE_VISUAL_RESULT_FOLDER_PATH:
            instance_visual_result_folder_path,
            MODEL_METADATA_KEY_INSTANCE_SOURCE_FOLDER:
            instance_source_folder_path,
            MODEL_METADATA_KEY_INSTANCE_SOURCE_PATH:
            instance_source_path,
            MODEL_METADATA_KEY_INSTANCE_CLASS_NAME:
            model.__name__,
            MODEL_METADATA_KEY_README:
            self.gen_readme(),
            MODEL_METADATA_KEY_INSTANCE_SUMMARY_FOLDER_PATH:
            instance_summary_folder_path
        }
        metadata_path = os.path.join(instance_path, 'instance.meta')
        self.dump_json(metadata, metadata_path)
        self.log('dump metadata')

        # build model
        self.model = model(metadata, input_shapes)
        self.log('build model')

        self.metadata_path = metadata_path
        self.metadata = metadata

    def load_model(self, metadata_path, input_shapes):
        metadata = self.load_json(metadata_path)
        self.log('load metadata')

        instance_class_name = metadata[MODEL_METADATA_KEY_INSTANCE_CLASS_NAME]
        instance_source_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_SOURCE_PATH]
        model = load_class_from_source_path(instance_source_path,
                                            instance_class_name)
        self.log('model source code load')

        self.model = model(metadata, input_shapes)
        self.log('build model')

        instance_id = metadata[MODEL_METADATA_KEY_INSTANCE_ID]
        self.log('load instance id : %s' % instance_id)

    def load_visualizer(self, visualizers):
        visualizer_path = self.model.instance_visual_result_folder_path
        for visualizer, iter_cycle in visualizers:
            if not os.path.exists(visualizer_path):
                os.mkdir(visualizer_path)

            self.visualizers += [
                visualizer(visualizer_path, iter_cycle=iter_cycle)
            ]
            self.log('visualizer %s loaded' % visualizer.__name__)

        self.log('visualizers fully loaded')

    def train_model(self,
                    dataset,
                    epoch_time=TOTAL_EPOCH,
                    check_point_interval=CHECK_POINT_INTERVAL,
                    is_restore=False):
        saver = tf.train.Saver()
        save_path = os.path.join(self.model.instance_path, 'check_point')
        check_point_path = os.path.join(save_path, 'model.ckpt')
        if not os.path.exists(save_path):
            os.mkdir(save_path)
            self.log('make save dir')

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            self.log('init global variables')

            summary_writer = tf.summary.FileWriter(
                self.model.instance_summary_folder_path, sess.graph)
            self.log('init summary_writer')

            if is_restore:
                saver.restore(sess, check_point_path)
                self.log('restore check point')

            batch_size = self.model.batch_size
            iter_per_epoch = int(dataset.data_size / batch_size)
            self.log('total Epoch: %d, total iter: %d, iter per epoch: %d' %
                     (epoch_time, epoch_time * iter_per_epoch, iter_per_epoch))

            iter_num, loss_val_D, loss_val_G = 0, 0, 0
            for epoch in range(epoch_time):
                for _ in range(iter_per_epoch):
                    iter_num += 1
                    self.model.train_model(sess=sess,
                                           iter_num=iter_num,
                                           dataset=dataset)
                    self.__visualizer_task(sess, iter_num, dataset)

                    self.model.write_summary(sess=sess,
                                             iter_num=iter_num,
                                             dataset=dataset,
                                             summary_writer=summary_writer)

                    if iter_num % check_point_interval == 0:
                        saver.save(sess, check_point_path)

        self.log('train end')

        tf.reset_default_graph()
        self.log('reset default graph')

    def sample_model(self, is_restore=False):
        self.log('start sample_model')
        saver = tf.train.Saver()

        save_path = os.path.join(self.model.instance_path, 'check_point')
        check_point_path = os.path.join(save_path, 'model.ckpt')
        if not os.path.exists(save_path):
            os.mkdir(save_path)
            self.log('make save dir')

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if is_restore:
                saver.restore(sess, check_point_path)
                self.log('restore check point')

            self.__visualizer_task(sess)

        self.log('sampling end')

        tf.reset_default_graph()
        self.log('reset default graph')

    def __visualizer_task(self, sess, iter_num=None, dataset=None):
        for visualizer in self.visualizers:
            if iter_num is None or iter_num % visualizer.iter_cycle == 0:
                try:
                    visualizer.task(sess, iter_num, self.model, dataset)
                except Exception as err:
                    self.log('at visualizer %s \n %s' % (visualizer, err))

    def open_subprocess(self, args, process_name):
        if process_name in self.sub_process and self.sub_process[
                process_name].poll() is None:
            # TODO better error class
            raise AssertionError(
                "process '%s'(pid:%s) already exist and still running" %
                (process_name, self.sub_process[process_name].pid))

        self.sub_process[process_name] = subprocess.Popen(args)
        str_args = " ".join(map(str, args))
        pid = self.sub_process[process_name].pid
        self.log("open subprocess '%s',  pid: %s" % (str_args, pid))

    def close_subprocess(self, process_name):
        if process_name in self.sub_process:
            self.log("kill subprocess '%s', pid: %s" %
                     (process_name, self.sub_process[process_name].pid))
            self.sub_process[process_name].kill()
        else:
            raise KeyError("fail close subprocess, '%s' not found" %
                           process_name)

    def open_tensorboard(self):
        python_path = sys.executable
        option = '--logdir=' + self.model.instance_summary_folder_path
        args = [python_path, tensorboard_dir(), option]
        self.open_subprocess(args=args, process_name="tensorboard")

    def close_tensorboard(self):
        self.close_subprocess('tensorboard')

    @staticmethod
    def dump_json(obj, path):
        with open(path, 'w') as f:
            json.dump(obj, f)

    @staticmethod
    def load_json(path):
        with open(path, 'r') as f:
            metadata = json.load(f)
        return metadata

    @staticmethod
    def gen_readme():
        # TODO implement
        return {}
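
A sketch of the driving sequence this manager appears to expect; ExampleModel, ExampleImageVisualizer, dataset and input_shapes are stand-ins, not project classes.

# sketch only: ExampleModel is a hypothetical concrete model class,
# ExampleImageVisualizer a hypothetical AbstractVisualizer subclass
manager = InstanceManager(root_path='./env')
manager.gen_instance(model=ExampleModel, input_shapes=input_shapes)
manager.load_visualizer([(ExampleImageVisualizer, 500)])   # (visualizer class, iter_cycle) pairs
manager.train_model(dataset, epoch_time=20, check_point_interval=1000)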
Example #7
class AbstractModel:
    """Abstract class of model for tensorflow graph

    TODO add docstring

    """
    VERSION = 1.0
    AUTHOR = 'demetoir'

    def __str__(self):
        return "%s_%s_%.1f" % (self.AUTHOR, self.__class__.__name__,
                               self.VERSION)

    def __init__(self, logger_path=None):
        """create instance of AbstractModel

        :type logger_path: str
        :param logger_path: path for log file
        if logger_path is None, log only to stdout
        """
        if logger_path is None:
            self.logger = Logger(self.__class__.__name__, with_file=True)
        else:
            self.logger = Logger(self.__class__.__name__, logger_path)
        self.log = self.logger.get_log()

    def load_model(self, metadata=None, input_shapes=None, params=None):
        """load tensor graph of entire model

        load model instance and inject metadata and input_shapes

        :type metadata: dict
        :type input_shapes: dict
        :param metadata: metadata for model
        :param input_shapes: input shapes for tensorflow placeholders
        :param params: hyper parameters for the model

        :raise FailLoadModelError
        if any error is raised while loading the model
        """
        try:
            self.log("load metadata")
            self.load_metadata(metadata)

            with tf.variable_scope("misc_ops"):
                self.log('load misc ops')
                self.load_misc_ops()

            with tf.variable_scope("hyper_parameter"):
                if params is None:
                    params = self.params
                self.log('load hyper parameter')
                self.load_hyper_parameter(params)

            if input_shapes is None:
                input_shapes = self.input_shapes
            self.log("load input shapes")
            self.load_input_shapes(input_shapes)

            self.log('load main tensor graph')
            self.load_main_tensor_graph()

            with tf.variable_scope('loss'):
                self.log('load loss')
                self.load_loss_function()

            with tf.variable_scope('train_ops'):
                self.log('load train ops')
                self.load_train_ops()

            with tf.variable_scope('summary_ops'):
                self.log('load summary ops')
                self.load_summary_ops()
        except Exception:
            exc_type, exc_value, exc_traceback = sys.exc_info()
            self.log("\n", "".join(traceback.format_tb(exc_traceback)))
            raise FailLoadModelError("fail to load model")
        else:
            self.log("load model complete")

    def load_metadata(self, metadata=None):
        """load metadata

        :type metadata: dict
        :param metadata: metadata for model
        """
        if metadata is None:
            self.log('skip to load metadata')
            return

        self.instance_id = metadata[MODEL_METADATA_KEY_INSTANCE_ID]
        self.instance_path = metadata[MODEL_METADATA_KEY_INSTANCE_PATH]
        self.instance_visual_result_folder_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_VISUAL_RESULT_FOLDER_PATH]
        self.instance_source_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_SOURCE_PATH]
        self.instance_class_name = metadata[
            MODEL_METADATA_KEY_INSTANCE_CLASS_NAME]
        self.readme = metadata[MODEL_METADATA_KEY_README]
        self.instance_summary_folder_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_SUMMARY_FOLDER_PATH]
        self.params = metadata[MODEL_METADATA_KEY_PARAMS]
        self.input_shapes = metadata[MODEL_METADATA_KEY_INPUT_SHAPES]

    def load_input_shapes(self, input_shapes):
        """load input shapes for tensor placeholder

        :type input_shapes: dict
        :param input_shapes: input shapes for tensor placeholder

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def load_hyper_parameter(self, params=None):
        """load hyper parameter for model

        :param params:
        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def load_main_tensor_graph(self):
        """load main tensor graph

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def load_loss_function(self):
        """load loss function of model

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def load_misc_ops(self):
        """load misc operation of model

        a default implementation is provided: it creates global_step and op_inc_global_step
        """
        with tf.variable_scope('misc_ops'):
            self.global_step = tf.get_variable(
                "global_step", shape=[1], initializer=tf.zeros_initializer)
            with tf.variable_scope('op_inc_global_step'):
                self.op_inc_global_step = self.global_step.assign(
                    self.global_step + 1)

    def load_train_ops(self):
        """Load train operation of model

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def load_summary_ops(self):
        """load summary operation for tensorboard

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def train_model(self, sess=None, iter_num=None, dataset=None):
        """train model

        :type sess: Session object for tensorflow.Session
        :type iter_num: int
        :type dataset: AbstractDataset
        :param sess: session object for tensorflow
        :param iter_num: current iteration number
        :param dataset: dataset for train model

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError

    def write_summary(self,
                      sess=None,
                      iter_num=None,
                      dataset=None,
                      summary_writer=None):
        """write summary of model for tensorboard

        :type sess: Session object for tensorflow.Session
        :type iter_num: int
        :type dataset: dataset_handler.AbstractDataset
        :type summary_writer: tensorflow.summary.FileWriter

        :param sess: session object for tensorflow
        :param iter_num: current iteration number
        :param dataset: dataset for train model
        :param summary_writer: file writer for tensorboard summary

        :raise NotImplementedError
        if not implemented
        """
        raise NotImplementedError
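
A skeletal subclass sketch showing which hooks load_model expects a concrete model to fill in; every tensor definition and batch key below is a placeholder, not the project's actual model code.

class ExampleModel(AbstractModel):
    # sketch only: load_model() wires these hooks together in order
    VERSION = 1.0
    AUTHOR = 'example'

    def load_input_shapes(self, input_shapes):
        self.x_shape = input_shapes['x']  # hypothetical key

    def load_hyper_parameter(self, params=None):
        self.batch_size = 64
        self.learning_rate = 1e-4

    def load_main_tensor_graph(self):
        self.x = tf.placeholder(tf.float32, [None] + list(self.x_shape), name='x')
        self.logit = tf.layers.dense(self.x, 1)

    def load_loss_function(self):
        self.loss = tf.reduce_mean(tf.square(self.logit))

    def load_train_ops(self):
        self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)

    def load_summary_ops(self):
        self.summary_op = tf.summary.scalar('loss', self.loss)

    def train_model(self, sess=None, iter_num=None, dataset=None):
        batch_x = dataset.next_batch(self.batch_size, ['x'])  # hypothetical batch key
        sess.run([self.train_op, self.op_inc_global_step],
                 feed_dict={self.x: batch_x})

    def write_summary(self, sess=None, iter_num=None, dataset=None,
                      summary_writer=None):
        batch_x = dataset.next_batch(self.batch_size, ['x'], lookup=True)
        summary = sess.run(self.summary_op, feed_dict={self.x: batch_x})
        summary_writer.add_summary(summary, iter_num)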
Example #8
class AbstractModel:
    VERSION = 1.0
    AUTHOR = 'demetoir'

    def __str__(self):
        return "%s_%s_%.1f" % (self.AUTHOR, self.__class__.__name__,
                               self.VERSION)

    def __init__(self, metadata, input_shapes):
        self.logger = Logger(self.__class__.__name__,
                             metadata[MODEL_METADATA_KEY_INSTANCE_PATH])
        self.log = self.logger.get_log()

        self.load_meta_data(metadata)
        self.input_shapes(input_shapes)
        self.hyper_parameter()
        self.network()
        self.log('network build')
        self.loss()
        self.log('loss build')
        self.train_ops()
        self.log('train ops build')
        self.misc_ops()
        self.log('misc ops build')
        self.summary_op()
        self.log('summary build')

    def load_meta_data(self, metadata):
        self.instance_id = metadata[MODEL_METADATA_KEY_INSTANCE_ID]
        self.instance_path = metadata[MODEL_METADATA_KEY_INSTANCE_PATH]
        self.instance_visual_result_folder_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_VISUAL_RESULT_FOLDER_PATH]
        self.instance_source_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_SOURCE_PATH]
        self.instance_class_name = metadata[
            MODEL_METADATA_KEY_INSTANCE_CLASS_NAME]
        self.readme = metadata[MODEL_METADATA_KEY_README]
        self.instance_summary_folder_path = metadata[
            MODEL_METADATA_KEY_INSTANCE_SUMMARY_FOLDER_PATH]

    def input_shapes(self, input_shapes):
        raise NotImplementedError

    def hyper_parameter(self):
        raise NotImplementedError

    def network(self):
        raise NotImplementedError

    def loss(self):
        raise NotImplementedError

    def train_ops(self):
        raise NotImplementedError

    def misc_ops(self):
        # TODO scope problem ..?
        with tf.variable_scope('misc_ops'):
            self.global_step = tf.get_variable(
                "global_step", shape=[1], initializer=tf.zeros_initializer)
            with tf.variable_scope('op_inc_global_step'):
                self.op_inc_global_step = self.global_step.assign(
                    self.global_step + 1)

    def train_model(self, sess=None, iter_num=None, dataset=None):
        raise NotImplementedError

    def summary_op(self):
        raise NotImplementedError

    def write_summary(self,
                      sess=None,
                      iter_num=None,
                      dataset=None,
                      summary_writer=None):
        raise NotImplementedError
Example #9
class AbstractVisualizer:
    """abstract class for visualizer for instance

    """
    def __init__(self, path=None, execute_interval=None, name=None):
        """create Visualizer

        :type path: str
        :type execute_interval: int
        :type name: str
        :param path: path for saving visualized result
        :param execute_interval: interval for execute
        :param name: naming for visualizer
        """

        self.execute_interval = execute_interval
        self.name = name
        self.visualizer_path = os.path.join(path, self.__str__())

        if not os.path.exists(path):
            os.mkdir(path)

        if not os.path.exists(self.visualizer_path):
            os.mkdir(self.visualizer_path)

        files = glob(os.path.join(self.visualizer_path, '*'))
        self.output_count = len(files)

        self.logger = Logger(self.__class__.__name__, self.visualizer_path)
        self.log = self.logger.get_log()

    def __str__(self):
        if self.name is not None:
            return self.name
        else:
            return self.__class__.__name__

    def __del__(self):
        del self.execute_interval
        del self.name
        del self.visualizer_path
        del self.log
        del self.logger

    def task(self, sess=None, iter_num=None, model=None, dataset=None):
        """visualizing task

        :type sess: tensorflow.Session
        :type iter_num: int
        :type model: AbstractModel
        :type dataset: AbstractDataset
        :param sess: current tensorflow session
        :param iter_num: current iteration number
        :param model: current visualizing model
        :param dataset: current visualizing dataset
        """
        raise NotImplementedError

    def save_np_img(self, np_img, file_name=None):
        """save np_img file in visualizer path

        :type np_img: numpy.ndarray
        :type file_name: str
        :param np_img: np_img to save
        :param file_name: save file name
        default None
        if file_name is None, the file name will be the zero-padded output_count plus '.png'
        """
        if file_name is None:
            file_name = '{}.png'.format(str(self.output_count).zfill(8))

        pil_img = np_img_to_PIL_img(np_img)
        with open(os.path.join(self.visualizer_path, file_name), 'wb') as fp:
            pil_img.save(fp)
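
A minimal sketch of a concrete visualizer; how a sample image is obtained from the model (model.sample) is an assumption, not part of the project's model API.

class ExampleImageVisualizer(AbstractVisualizer):
    # sketch only: assumes the model exposes a hypothetical sample(sess) returning a numpy image
    def task(self, sess=None, iter_num=None, model=None, dataset=None):
        np_img = model.sample(sess)
        self.save_np_img(np_img)
        self.output_count += 1  # keep the default file names unique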
class InstanceManager:
    """ manager class for Instance

    steps for managing an instance
    1. build instance (if the instance is already built, skip this step)
    2. load_instance and visualizers
    3. train_instance or sampling_instance

    ex) build an instance and train it
    manager = InstanceManager(env_path)
    instance_path = manager.build_instance(model)
    manager.load_instance(instance_path, input_shapes)
    manager.load_visualizer(visualizers)
    manager.train_instance(dataset, epoch, check_point_interval)

    ex) resume from training
    manager = InstanceManager(env_path)
    manager.load_instance(built_instance_path, input_shapes)
    manager.load_visualizer(visualizers)
    manager.train_instance(dataset, epoch, check_point_interval, is_restore=True)
    """

    def __init__(self, root_path=ROOT_PATH):
        """ create a 'InstanceManager' at env_path

        :type root_path: str
        :param root_path: env path for manager
        """
        self.root_path = root_path
        self.logger = Logger(self.__class__.__name__, self.root_path)
        self.log = self.logger.get_log()
        self.instance = None
        self.visualizers = {}
        self.subprocess = {}

    def __del__(self):
        """ destructor of InstanceManager

        clean up memory, subprocesses, logging, and the tensorflow graph
        """
        # reset tensorflow graph
        tf.reset_default_graph()

        for process_name in self.subprocess:
            if self.subprocess[process_name].poll() is None:
                self.close_subprocess(process_name)

        del self.root_path
        del self.log
        del self.logger
        del self.instance
        del self.visualizers

    def build_instance(self, model=None, input_shapes=None, param=None):
        """build instance for model class and return instance path

        * model must be subclass of AbstractModel

        generate unique id to new instance and initiate folder structure
        dump model's script
        generate and save metadata for new instance
        return built instance's path

        :param input_shapes:
        :param param:
        :type model: class
        :param model: subclass of AbstractModel

        :return: built instance's path
        """
        if not issubclass(model, AbstractModel):
            raise TypeError("argument model expect subclass of AbstractModel")

        # gen instance id
        model_name = "%s_%s_%.1f" % (model.AUTHOR, model.__name__, model.VERSION)
        instance_id = model_name + '_' + strftime("%Y-%m-%d_%H-%M-%S", localtime())
        self.log('build instance: %s' % instance_id)

        # init new instance directory
        self.log('init instance directory')
        instance_path = os.path.join(self.root_path, INSTANCE_FOLDER, instance_id)
        if not os.path.exists(instance_path):
            os.mkdir(instance_path)

        instance_visual_result_folder_path = os.path.join(instance_path, VISUAL_RESULT_FOLDER)
        if not os.path.exists(instance_visual_result_folder_path):
            os.mkdir(instance_visual_result_folder_path)

        instance_source_folder_path = os.path.join(instance_path, 'src_code')
        if not os.path.exists(instance_source_folder_path):
            os.mkdir(instance_source_folder_path)

        instance_summary_folder_path = os.path.join(instance_path, 'summary')
        if not os.path.exists(instance_summary_folder_path):
            os.mkdir(instance_summary_folder_path)

        self.log('dump instance source code')
        instance_source_path = os.path.join(instance_source_folder_path, instance_id + '.py')
        try:
            copy(inspect.getsourcefile(model), instance_source_path)
        except IOError as e:
            print(e)

        self.log("build_metadata")
        metadata_path = os.path.join(instance_path, 'instance.meta')
        metadata = {
            MODEL_METADATA_KEY_INSTANCE_ID: instance_id,
            MODEL_METADATA_KEY_INSTANCE_PATH: instance_path,
            MODEL_METADATA_KEY_INSTANCE_VISUAL_RESULT_FOLDER_PATH: instance_visual_result_folder_path,
            MODEL_METADATA_KEY_INSTANCE_SOURCE_FOLDER_PATH: instance_source_folder_path,
            MODEL_METADATA_KEY_INSTANCE_SOURCE_PATH: instance_source_path,
            MODEL_METADATA_KEY_INSTANCE_SUMMARY_FOLDER_PATH: instance_summary_folder_path,
            MODEL_METADATA_KEY_INSTANCE_CLASS_NAME: model.__name__,
            MODEL_METADATA_KEY_README: None,
            MODEL_METADATA_KEY_METADATA_PATH: metadata_path,
            MODEL_METADATA_KEY_PARAMS: param,
            MODEL_METADATA_KEY_INPUT_SHAPES: input_shapes,
        }

        self.log('dump metadata')
        dump_json(metadata, metadata_path)

        self.log('build complete')
        return instance_path

    def load_instance(self, instance_path):
        """ load built instance into InstanceManager

        import model class from dumped script in instance_path
        inject metadata and input_shapes into model
        load tensorflow graph from model
        load instance into InstanceManager

        * for more information about input_shapes, see dict_keys/input_shape_keys.py

        :type instance_path: str
        :param instance_path: instance path to load
        """
        metadata = load_json(os.path.join(instance_path, 'instance.meta'))
        self.log('load metadata')

        instance_class_name = metadata[MODEL_METADATA_KEY_INSTANCE_CLASS_NAME]
        instance_source_path = metadata[MODEL_METADATA_KEY_INSTANCE_SOURCE_PATH]
        model = import_class_from_module_path(instance_source_path, instance_class_name)
        self.log('instance source code load')

        self.instance = model(metadata[MODEL_METADATA_KEY_INSTANCE_PATH])
        self.instance.load_model(metadata)
        self.log('load instance')

        instance_id = metadata[MODEL_METADATA_KEY_INSTANCE_ID]
        self.log('load instance id : %s' % instance_id)

    @deco_handle_exception
    def train_instance(self, epoch, dataset=None, check_point_interval=None, is_restore=False, with_tensorboard=True):
        """training loaded instance with dataset for epoch and loaded visualizers will execute

        * if you want to use visualizer call load_visualizer function first

        every check point interval, tensor variables will save at check point
        check_point_interval's default is one epoch, but scale of interval is number of iteration
        so if check_point_interval=3000, tensor variable save every 3000 per iter
        option is_restore=False is default
        if you want to restore tensor variables from check point, use option is_restore=True

        InstanceManager may open subprocess like tensorboard, raising error may cause some issue
        like subprocess still alive, while InstanceManager process exit
        so any error raise while training wrapper @log_exception will catch error
        KeyboardInterrupt raise, normal exit for abort training and return
        any other error will print error message and return

        :param epoch: total epoch for train
        :param dataset: dataset for train
        :param check_point_interval: interval for check point to save train tensor variables
        :param is_restore: option for restoring from check point
        :param with_tensorboard: option for open child process for tensorboard to monitor summary
        """

        if with_tensorboard:
            self.open_tensorboard()

        self.log("current loaded visualizers")
        for key in self.visualizers:
            self.log(key)

        with tf.Session() as sess:
            saver = tf.train.Saver()
            save_path = os.path.join(self.instance.instance_path, 'check_point')
            check_point_path = os.path.join(save_path, 'instance.ckpt')
            if not os.path.exists(save_path):
                os.mkdir(save_path)
                self.log('make save dir')

            self.log('init global variables')
            sess.run(tf.global_variables_initializer())

            self.log('init summary_writer')
            summary_writer = tf.summary.FileWriter(self.instance.instance_summary_folder_path, sess.graph)

            if is_restore:
                self.log('restore check point')
                saver.restore(sess, check_point_path)

            batch_size = self.instance.batch_size
            iter_per_epoch = int(dataset.train_set.data_size / batch_size)
            self.log('train set size: %d, total Epoch: %d, total iter: %d, iter per epoch: %d'
                     % (dataset.train_set.data_size, epoch, epoch * iter_per_epoch, iter_per_epoch))

            iter_num = 0
            for epoch_ in range(epoch):
                # TODO need concurrency
                dataset.shuffle()
                for _ in range(iter_per_epoch):
                    iter_num += 1
                    self.instance.train_model(sess=sess, iter_num=iter_num, dataset=dataset)
                    self.__visualizer_task(sess, iter_num, dataset)

                    self.instance.write_summary(sess=sess, iter_num=iter_num, dataset=dataset,
                                                summary_writer=summary_writer)

                    if iter_num % check_point_interval == 0:
                        saver.save(sess, check_point_path)
                # self.log("epoch %s end" % (epoch_ + 1))

            saver.save(sess, check_point_path)
        self.log('train end')

        if with_tensorboard:
            self.close_tensorboard()

    @deco_handle_exception
    def sampling_instance(self, dataset=None, is_restore=True):
        """sampling result from trained instance by running loaded visualizers

        * if you want to use a visualizer, call load_visualizer first

        InstanceManager may open subprocesses such as tensorboard, so an uncaught error could
        leave a subprocess alive after the InstanceManager process exits;
        the @deco_handle_exception wrapper therefore catches any error raised while sampling:
        a KeyboardInterrupt aborts sampling and returns normally,
        any other error prints an error message and returns

        :param dataset:
        :param is_restore: option for restoring from check point
        """
        self.log('start sampling_model')
        saver = tf.train.Saver()

        self.log("current loaded visualizers")
        for key in self.visualizers:
            self.log(key)

        save_path = os.path.join(self.instance.instance_path, 'check_point')
        check_point_path = os.path.join(save_path, 'instance.ckpt')
        if not os.path.exists(save_path):
            os.mkdir(save_path)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if is_restore:
                saver.restore(sess, check_point_path)
                self.log('restore check point')

            iter_num = 0
            for visualizer in self.visualizers.values():
                try:
                    visualizer.task(sess=sess, iter_num=iter_num, model=self.instance, dataset=dataset)
                except Exception as err:
                    log_error_trace(self.log, err, head='while execute %s' % visualizer)

        self.log('sampling end')

    def load_visualizer(self, visualizer, execute_interval, key=None):
        """load visualizer for training and sampling result of instance

        :type visualizer: class
        :param visualizer: AbstractVisualizer subclass to load
        :type execute_interval: int
        :param execute_interval: interval to execute visualizer per iteration
        :param key: key of visualizer dict
        """
        visualizer_path = self.instance.instance_visual_result_folder_path
        if not os.path.exists(visualizer_path):
            os.mkdir(visualizer_path)

        if key is None:
            key = visualizer.__name__

        self.visualizers[key] = visualizer(visualizer_path, execute_interval=execute_interval)
        self.log('visualizer %s loaded key=%s' % (visualizer.__name__, key))
        return key

    def unload_visualizer(self, key):
        if key not in self.visualizers:
            raise KeyError("fail to unload visualizer, key '%s' not found" % key)
        self.visualizers.pop(key, None)

    def unload_all_visualizer(self):
        self.visualizers.clear()

    def __visualizer_task(self, sess, iter_num=None, dataset=None):
        """execute loaded visualizers

        :type iter_num: int
        :type dataset: AbstractDataset
        :param sess: tensorflow.Session object
        :param iter_num: current iteration number
        :param dataset: feed for visualizers
        """
        for visualizer in self.visualizers.values():
            if iter_num is None or iter_num % visualizer.execute_interval == 0:
                try:
                    visualizer.task(sess, iter_num, self.instance, dataset)
                except Exception as err:
                    log_error_trace(self.log, err, head='while execute %s' % visualizer)

    def open_subprocess(self, args_, subprocess_key=None):
        """open subprocess with args and return pid

        :type args_: list
        :type subprocess_key: str
        :param args_: list of argument for new subprocess
        :param subprocess_key: key for self.subprocess of opened subprocess
        if subprocess_key is None, pid will be subprocess_key

        :raise ChildProcessError
        if same process name is already opened

        :return: pid for opened subprocess
        """

        if subprocess_key in self.subprocess and self.subprocess[subprocess_key].poll() is None:
            # TODO better error class

            raise AssertionError("process '%s'(pid:%s) already exist and still running" % (
                subprocess_key, self.subprocess[subprocess_key].pid))

        child_process = subprocess.Popen(args_)
        if subprocess_key is None:
            subprocess_key = str(child_process.pid)
        self.subprocess[subprocess_key] = child_process
        str_args = " ".join(map(str, args_))
        self.log("open subprocess pid:%s, cmd='%s'" % (child_process.pid, str_args))

        return child_process.pid

    def close_subprocess(self, subprocess_key):
        """close subprocess

        close opened subprocess of process_name

        :type subprocess_key: str
        :param subprocess_key: key for closing subprocess

        :raises KeyError
        if subprocess_key is not key for self.subprocess
        """
        if subprocess_key in self.subprocess:
            self.log("kill subprocess pid:%s, '%s'" % (self.subprocess[subprocess_key].pid, subprocess_key))
            self.subprocess[subprocess_key].kill()
        else:
            raise KeyError("fail close subprocess, '%s' not found" % subprocess_key)

    def open_tensorboard(self):
        """open tensorboard for current instance"""
        python_path = sys.executable
        option = '--logdir=' + self.instance.instance_summary_folder_path
        # option += ' --port 6006'
        # option += ' --debugger_port 6064'
        args_ = [python_path, tensorboard_dir(), option]
        self.open_subprocess(args_=args_, subprocess_key="tensorboard")

    def close_tensorboard(self):
        """close tensorboard for current instance"""
        self.close_subprocess('tensorboard')

    def get_tf_values(self, fetches, feed_dict):
        return self.instance.get_tf_values(self.sess, fetches, feed_dict)
class PbarPooling:
    def __init__(self,
                 func=None,
                 n_parallel=4,
                 initializer=None,
                 initargs=(),
                 child_timeout=30):
        self.logger = Logger(self.__class__.__name__)
        self.log = self.logger.get_log()

        self.func = func

        self.n_parallel = n_parallel

        if initializer is None:
            self.initializer = init_worker
        else:
            self.initializer = initializer
        self.initargs = initargs
        self.child_timeout = child_timeout

        self.pools = [
            Pool(1, initializer=self.initializer, initargs=initargs)
            for _ in range(n_parallel)
        ]
        self.queues = [Queue() for _ in range(n_parallel)]
        self.pbar = None
        self.fail_list = []

    def map(self, func=None, jobs=None):
        self.map_async(func, jobs)
        return self.get()

    def map_async(self, func=None, jobs=None):
        if func is not None:
            self.func = func

        self.log('start pooling queue {} jobs'.format(len(jobs)))

        self.pbar = tqdm(total=len(jobs))

        def update_pbar(args):
            self.pbar.update(1)

        self.update_pbar = update_pbar

        self.jobs = jobs
        for i in range(len(jobs)):
            pool_id = i % self.n_parallel
            job = jobs[i]
            pool = self.pools[pool_id]
            child = pool.apply_async(self.func, job, callback=update_pbar)
            self.queues[pool_id].put((child, job))

    def get(self):
        rets = []
        while sum([q.qsize() for q in self.queues]) > 0:
            for pool_id in range(self.n_parallel):
                ret = None
                if self.queues[pool_id].qsize() == 0:
                    continue
                child, job = self.queues[pool_id].get()
                try:
                    ret = child.get(timeout=self.child_timeout)
                except KeyboardInterrupt:
                    self.log("KeyboardInterrupt terminate pools\n"
                             "{fail}/{total} fail".format(
                                 fail=len(self.fail_list),
                                 total=len(self.jobs)))
                    self.terminate()
                    raise KeyboardInterrupt
                except BaseException as e:
                    log_error_trace(self.log, e)
                    self.log("job fail, kill job={job}, child={child}".format(
                        child=str(None), job=str(job[3])))
                    self.pbar.update(1)
                    self.fail_list += [job]
                    self.pools[pool_id].terminate()
                    self.pools[pool_id].join()
                    self.pools[pool_id] = Pool(1,
                                               initializer=self.initializer,
                                               initargs=self.initargs)

                    new_queue = Queue()
                    while self.queues[pool_id].qsize() > 0:
                        _, job = self.queues[pool_id].get()
                        child = self.pools[pool_id].apply_async(
                            self.func, job, callback=self.update_pbar)
                        new_queue.put((child, job))

                    self.queues[pool_id] = new_queue

                finally:
                    rets += [ret]

        self.pbar.close()
        self.log("{fail}/{total} fail".format(fail=len(self.fail_list),
                                              total=len(self.jobs)))
        self.log('end pooling queue')
        return rets

    def save_fail_list(self, path=None):
        if path is None:
            path = os.path.join('.', 'fail_list', time_stamp())

        dump_pickle(self.fail_list, path + ".pkl")
        dump_json(list(map(str, self.fail_list)), path + ".json")

    def terminate(self):
        for pool_id in range(self.n_parallel):
            self.pools[pool_id].terminate()
            self.pools[pool_id].join()

        if self.pbar is not None:
            self.pbar.close()

    def close(self):
        for pool_id in range(self.n_parallel):
            self.pools[pool_id].close()
            self.pools[pool_id].join()

        if self.pbar is not None:
            self.pbar.close()
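
A usage sketch for the pool wrapper above, assuming the module's own imports (Pool, Queue, tqdm, Logger, init_worker) are available; square() is a stand-in worker and each job is an argument tuple, since apply_async(func, job) unpacks it positionally.

# sketch only
def square(x):
    return x * x

if __name__ == '__main__':
    pooling = PbarPooling(n_parallel=4)
    results = pooling.map(square, [(i,) for i in range(100)])
    pooling.close()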