Пример #1
0
    def params(self, run_id=None,
               learning_rate=0.01, learning_rate_decay_rate=0.99,
               learning_rate_decay_method=None, beta1=0.9, batch_size=100,
               net_type='VGG16', capacity=64, loss_type='MSE',
               use_l1_norm=False, l1_norm_rate=0.01,
               use_l2_norm=False, l2_norm_rate=0.01, comment=None):
        """Bundle the hyper-parameters of a run via ``to_dict``.

        Besides the default ``'VGG16'``, the alternatives previously listed
        here for ``net_type`` are ``'InceptionV1'``, ``'InceptionV2'``,
        ``'InceptionV4'``, ``'ResNet18'``, ``'ResNet34'``, ``'ResNet50'``,
        ``'ResNet101'`` and ``'ResNet152'`` (NOTE(review): confirm which of
        these the consuming model actually supports).

        :param run_id: run identifier; ``time_stamp()`` is used when None.
        :param learning_rate_decay_rate: forwarded under its own name
            (previously accepted but silently dropped).
        :param learning_rate_decay_method: forwarded under its own name
            (previously accepted but silently dropped).
        :return: the result of ``to_dict`` called with every argument above
            as a keyword; all other arguments are forwarded unchanged.
        """
        if run_id is None:
            run_id = time_stamp()

        return to_dict(
            run_id=run_id,
            batch_size=batch_size,
            net_type=net_type,
            capacity=capacity,
            learning_rate=learning_rate,
            # Forward the decay settings so that passing them has an effect.
            learning_rate_decay_rate=learning_rate_decay_rate,
            learning_rate_decay_method=learning_rate_decay_method,
            beta1=beta1,
            loss_type=loss_type,
            use_l1_norm=use_l1_norm,
            l1_norm_rate=l1_norm_rate,
            use_l2_norm=use_l2_norm,
            l2_norm_rate=l2_norm_rate,
            comment=comment
        )
Пример #2
0
    def plot_image(self, np_img, title=None, path=None, **kwargs):
        """Write a single image array to disk via ``np_image_save``.

        :param np_img: image array handed to ``np_image_save``.
        :param title: base file name; defaults to ``time_stamp()`` followed
            by ``self.finger_print(6)``.
        :param path: destination file; defaults to
            ``./matplot/<title><self.extend>``.
        :param kwargs: accepted for interface compatibility; unused.
        """
        name = (time_stamp() + self.finger_print(6)
                if title is None else title)
        suffix = self.extend
        target = (path_join('.', 'matplot', name + suffix)
                  if path is None else path)

        # NOTE(review): presumably prepares the parent directory / file so the
        # save below cannot fail on a missing folder — confirm setup_file.
        setup_file(target)
        np_image_save(np_img, target)
Пример #3
0
    def __init__(self, source_model, source_scope, verbose=0):
        """Prepare a transfer from an already-built source model.

        :param source_model: model to transfer from; its ``is_built`` flag
            must be truthy.
        :param source_scope: scope associated with the source model (stored
            as-is on ``self.source_scope``).
        :param verbose: forwarded to the parent initializer.
        :raises RuntimeError: if ``source_model.is_built`` is falsy.
        """
        super().__init__(verbose=verbose)

        built = source_model.is_built
        if not built:
            raise RuntimeError('transfer fail, source model must be built')

        self.source_model, self.source_scope = source_model, source_scope

        # Scratch location for the transfer, with a timestamped sub-path.
        self.temp_dir = './temp_transfer'
        self.temp_path = path_join(self.temp_dir, time_stamp())
Пример #4
0
    def save_params(self, path=None):
        """Pickle the output of ``self.export_params()`` and return its path.

        :param path: output path without extension; defaults to
            ``<self.params_save_path>/<time_stamp()>``. ``.pkl`` is appended.
        :return: the path of the written pickle file.
        """
        base = (os.path.join(self.params_save_path, time_stamp())
                if path is None else path)

        snapshot = self.export_params()

        pickle_path = base + '.pkl'
        dump_pickle(snapshot, pickle_path)

        self.log.info('save params at {}'.format([pickle_path]))
        return pickle_path
Пример #5
0
    def params(
            self,
            run_id=None,
            learning_rate=0.01,
            beta1=0.9,
            batch_size=64,
            net_type='InceptionV1',
            n_classes=2,
            capacity=4,
            use_l1_norm=False,
            l1_norm_rate=0.01,
            use_l2_norm=False,
            l2_norm_rate=0.01,
            dropout_rate=0.5,
            fc_capacity=1024,
            fc_depth=2,
            comment=''
    ):
        """Bundle the hyper-parameters of a classification run via ``to_dict``.

        Alternatives noted for ``net_type`` besides the default
        ``'InceptionV1'``: ``'InceptionV2'``, ``'InceptionV4'``,
        ``'ResNet18'``, ``'ResNet34'``, ``'ResNet50'``, ``'ResNet101'``,
        ``'ResNet152'``.

        :param run_id: run identifier; ``time_stamp()`` is used when None.
        :return: the result of ``to_dict`` over every argument, each under
            its own name.
        """
        # Key order mirrors the keyword order the values are handed over in.
        hparams = {
            'run_id': time_stamp() if run_id is None else run_id,
            'batch_size': batch_size,
            'net_type': net_type,
            'capacity': capacity,
            'learning_rate': learning_rate,
            'beta1': beta1,
            'n_classes': n_classes,
            'use_l1_norm': use_l1_norm,
            'l1_norm_rate': l1_norm_rate,
            'use_l2_norm': use_l2_norm,
            'l2_norm_rate': l2_norm_rate,
            'dropout_rate': dropout_rate,
            'fc_depth': fc_depth,
            'fc_capacity': fc_capacity,
            'comment': comment,
        }
        return to_dict(**hparams)
Пример #6
0
    def plot_image_tile(self,
                        np_imgs,
                        column=10,
                        path=None,
                        title=None,
                        padding=3,
                        padding_value=0,
                        **kwargs):
        """Tile several images into one grid image and save it.

        :param np_imgs: images handed to ``np_img_to_tile``.
        :param column: forwarded to ``np_img_to_tile`` as ``column_size``.
        :param path: destination file; defaults to
            ``./matplot/<title><self.extend>``.
        :param title: base file name; defaults to ``time_stamp()`` followed
            by ``self.finger_print(6)``.
        :param padding: forwarded to ``np_img_to_tile``.
        :param padding_value: forwarded to ``np_img_to_tile``.
        :param kwargs: accepted for interface compatibility; unused.
        """
        if title is None:
            title = time_stamp() + self.finger_print(6)

        extend = self.extend
        if path is None:
            path = path_join('.', 'matplot', title + extend)

        np_img_tile = np_img_to_tile(np_imgs,
                                     column_size=column,
                                     padding=padding,
                                     padding_value=padding_value)

        # Prepare the destination exactly like plot_image / teardown_matplot
        # do; previously this step was missing, so saving into a not yet
        # existing ./matplot directory could fail.
        # NOTE(review): assumes setup_file creates the parent directory.
        setup_file(path)
        np_image_save(np_img_tile, path)
Пример #7
0
    def __init__(self, verbose=10, **kwargs):
        """Create an instance of AbstractModel.

        :param verbose: verbosity level; forwarded to ``LoggerMixIn`` and
            stored on ``self.verbose``.
        :param kwargs: optional settings:

            * ``sess`` -- passed to ``SessionManager`` (default None)
            * ``config`` -- passed to ``SessionManager`` (default None)
            * ``run_id`` -- run identifier; defaults to ``time_stamp()``
            * ``id`` -- id for ``ModelMetadata``; defaults to
              ``"<AUTHOR>_<ClassName>_<run_id>"``
        """
        LoggerMixIn.__init__(self, verbose=verbose)
        paramsMixIn.__init__(self)
        loss_packMixIn.__init__(self)

        self.sessionManager = SessionManager(
            sess=kwargs.get('sess'),
            config=kwargs.get('config'),
        )

        self._is_input_shape_built = False
        self._is_graph_built = False

        self.verbose = verbose

        # Explicit membership test (rather than kwargs.get with a default) so
        # time_stamp() is only evaluated when no run_id was supplied.
        if 'run_id' in kwargs:
            self.run_id = kwargs['run_id']
        else:
            self.run_id = time_stamp()

        # Generate the instance id unless the caller supplied one.
        if 'id' in kwargs:
            id_ = kwargs['id']
        else:
            id_ = "_".join([
                "%s_%s" % (self.AUTHOR, self.__class__.__name__), self.run_id
            ])
        self.metadata = ModelMetadata(id_=id_)

        self.singleton_dropout = SingletonDropout()
        self.singleton_bn = SingletonBN()
        self.is_train_phase = True
Пример #8
0
    def teardown_matplot(self,
                         fig,
                         path=None,
                         show=None,
                         title=None,
                         extend=None,
                         dpi=None,
                         save=None):
        """Finish a figure: set its title, optionally save/show it, then close it.

        Arguments left as None fall back to the instance defaults
        ``self.extend``, ``self.save``, ``self.show`` and ``self.dpi``. The
        title defaults to ``time_stamp()`` + ``self.finger_print(6)`` and the
        path to ``./matplot/<title><extend>``.

        The figure is closed in a ``finally`` clause, so a failing save or
        show no longer leaves the figure open.
        """
        if extend is None:
            extend = self.extend

        if save is None:
            save = self.save

        if show is None:
            show = self.show

        if title is None:
            title = time_stamp() + self.finger_print(6)

        if path is None:
            path = path_join('.', 'matplot', title + extend)

        if dpi is None:
            dpi = self.dpi

        try:
            self.plt.title(title)
            self.plt.tight_layout()

            if save:
                # Only touch the file system when something is actually written.
                setup_file(path)
                fig.savefig(path, dpi=dpi)

            if show:
                self.plt.show()
        finally:
            # TODO(review): original note "check bug fix" — cla() clears the
            # *current* axes, which may not belong to `fig`; confirm it is
            # still needed before close(fig).
            self.plt.cla()
            self.plt.close(fig)
Пример #9
0
    def params(self,
               run_id=None,
               verbose=10,
               learning_rate=0.01,
               beta1=0.9,
               batch_size=32,
               stage=4,
               loss_type='BCE+dice_soft',
               n_classes=1,
               net_type='FusionNet',
               capacity=16,
               depth=2,
               dropout_rate=0.5,
               comment=''):
        """Bundle the hyper-parameters of a segmentation run via ``to_dict``.

        Alternatives noted for ``loss_type``: ``'pixel_wise_softmax'``,
        ``'iou'``, ``'dice_soft'`` and the default ``'BCE+dice_soft'``.

        :param run_id: run identifier; ``time_stamp()`` is used when None.
        :return: the result of ``to_dict`` over every argument, each under
            its own name.
        """
        resolved_run_id = time_stamp() if run_id is None else run_id
        return to_dict(
            run_id=resolved_run_id,
            verbose=verbose,
            learning_rate=learning_rate,
            beta1=beta1,
            batch_size=batch_size,
            stage=stage,
            net_type=net_type,
            loss_type=loss_type,
            capacity=capacity,
            n_classes=n_classes,
            depth=depth,
            dropout_rate=dropout_rate,
            comment=comment,
        )
Пример #10
0
    def save_fail_list(self, path=None):
        """Dump ``self.fail_list`` as a pickle and as a JSON list of strings.

        :param path: output path without extension; defaults to
            ``./fail_list/<time_stamp()>``. ``.pkl`` and ``.json`` are
            appended for the two files.
        """
        base = (os.path.join('.', 'fail_list', time_stamp())
                if path is None else path)

        failures = self.fail_list
        dump_pickle(failures, base + ".pkl")
        # Entries are stringified for the JSON copy, presumably because they
        # are not all JSON-serializable.
        dump_json([str(item) for item in failures], base + ".json")
Пример #11
0
    def __init__(self,
                 name,
                 path=None,
                 file_name=None,
                 level='INFO',
                 with_file=True,
                 empty_stdout_format=True,
                 rotating_file=True):
        """Create a logger that writes to stdout and, optionally, to a file.

        :param name: logger name; the underlying ``logging.Logger`` is
            obtained as ``logging.getLogger(name + time_stamp())``.
        :param path: directory of the log file; default ``./log``.
        :param file_name: log file name; default ``<name>.log``.
        :param level: level of the stdout handler (the logger itself and the
            file handler are set to DEBUG).
        :param with_file: also log to a file (default True); when False,
            messages go to stdout only.
        :param empty_stdout_format: use ``EMPTY_FORMAT`` for stdout instead of
            ``STDOUT_LOGGER_FORMAT``.
        :param rotating_file: use a ``RotatingFileHandler`` (16 MiB per file,
            2 backups) instead of a plain ``logging.FileHandler``.
        """
        self.name = name
        # NOTE(review): two instances with the same name created within the
        # same time_stamp() resolution get the same logging.Logger and both
        # attach handlers, duplicating output — confirm time_stamp granularity.
        self.logger = logging.getLogger(name + time_stamp())
        self.logger.setLevel('DEBUG')

        self.file_handler = None
        if with_file:
            if path is None:
                path = os.path.join('.', 'log')
            if file_name is None:
                file_name = "{name}.log".format(name=name)

            check_path(path)
            file_path = os.path.join(path, file_name)
            if rotating_file:
                max_byte = 16 * 1024 * 1024
                self.file_handler = RotatingFileHandler(file_path,
                                                        maxBytes=max_byte,
                                                        backupCount=2)
            else:
                self.file_handler = logging.FileHandler(file_path)

            # Both handler kinds share the same format, level and wiring.
            self.file_handler.setFormatter(
                logging.Formatter(self.FILE_LOGGER_FORMAT))
            self.file_handler.setLevel('DEBUG')
            self.logger.addHandler(self.file_handler)

        if empty_stdout_format:
            format_ = self.EMPTY_FORMAT
        else:
            format_ = self.STDOUT_LOGGER_FORMAT
        formatter = logging.Formatter(format_)
        self.stream_handler = logging.StreamHandler()
        self.stream_handler.setFormatter(formatter)
        self.stream_handler.setLevel(level)
        self.logger.addHandler(self.stream_handler)

        self._fatal = deco_args_to_str(self.logger.fatal)
        self._error = deco_args_to_str(self.logger.error)
        # Logger.warn is a deprecated alias (since Python 3.3) that emits a
        # DeprecationWarning on every call; bind the supported method.
        self._warn = deco_args_to_str(self.logger.warning)
        self._info = deco_args_to_str(self.logger.info)
        self._debug = deco_args_to_str(self.logger.debug)
        self._critical = deco_args_to_str(self.logger.critical)