Example #1
0
def train(data, class_weights, flags, net: Net, framework: Framework,
          manager: tf.train.CheckpointManager):
    """Run the eager training loop, logging progress and saving checkpoints.

    Args:
        data: Annotation data; ``len(data)`` is used to compute progress.
        class_weights: Passed through to ``framework.shuffle``.
        flags: Fallback flags, used only if the shared flag IO channel has
            none to offer.
        net: Callable model; called as ``net(x_batch, training=True,
            **loss_feed)`` and expected to expose ``step`` and
            ``optimizer.learning_rate``.
        framework: Provides ``shuffle(data, class_weights)``, which yields
            ``(x_batch, loss_feed)`` pairs.
        manager: Checkpoint manager whose ``save()`` writes a checkpoint.

    A checkpoint is saved every ``flags.save`` steps (except on the very
    first step), and once more at the end unless the last step already saved.
    """
    log = get_logger()
    io = SharedFlagIO(flags, subprogram=True)
    # Prefer flags published on the shared IO channel; read them only once.
    shared = io.read_flags()
    flags = shared if shared is not None else flags
    log.info('Building {} train op'.format(flags.model))
    goal = len(data) * flags.epoch
    line = 'step: {} loss: {:f} lr: {:.2e} progress: {:.2f}%'
    first = True
    # Tracks whether the latest step already wrote a checkpoint. It starts
    # False so that the final save below still runs when no batch was
    # processed, or when the only step was skipped for being the first one.
    saved_last = False
    batches = framework.shuffle(data, class_weights)
    for i, (x_batch, loss_feed) in enumerate(batches):
        loss = net(x_batch, training=True, **loss_feed)
        step = net.step.numpy()
        lr = net.optimizer.learning_rate.numpy()
        if not first:
            flags.progress = i * flags.batch / goal * 100
            log.info(line.format(step, loss, lr, flags.progress))
        else:
            log.info(f"Following gradient from step {step}...")
        # Publish our flags and pick up any changes made by the other side.
        # Keep the current flags if nothing new is available.
        io.send_flags()
        shared = io.read_flags()
        if shared is not None:
            flags = shared
        ckpt = bool(not step % flags.save)
        saved_last = ckpt and not first
        if saved_last:
            save = manager.save()
            log.info(f"Saved checkpoint: {save}")
        first = False
    if not saved_last:
        save = manager.save()
        log.info(f"Finished training at checkpoint: {save}")
Example #2
0
def loss(self, y_pred, y_true):
    """Compute the loss selected by ``self.meta['type']``.

    Args:
        y_pred: Network output (presumably a tensor; TODO confirm).
        y_true: Target values; ``y_true - y_pred`` must be valid.

    Returns:
        The loss tensor for the configured type. If a type is accepted by
        ``self.type`` but has no branch below, ``None`` is returned.

    Raises:
        AssertionError: If the loss type is not a key of ``self.type``. The
            error is also recorded in ``self.flags.error``, logged, and
            published via ``SharedFlagIO.send_flags`` before being raised.
    """
    losses = self.type.keys()
    # Strip surrounding brackets, so a type such as '[sse]' becomes 'sse'.
    loss_type = self.meta['type'].strip('[]')
    out_size = self.meta['out_size']
    H, W, _ = self.meta['inp_size']
    HW = H * W
    # Use an explicit check rather than ``assert`` so validation survives
    # ``python -O``. Still raise AssertionError so existing handlers work.
    if loss_type not in losses:
        msg = f'Loss type {loss_type} not implemented'
        self.flags.error = msg
        self.logger.error(msg)
        SharedFlagIO.send_flags(self)
        raise AssertionError(msg)

    if self.first:
        self.logger.info('{} loss hyper-parameters:'.format(
            self.meta['model']))
        self.logger.info('Input Grid Size   = {}'.format(HW))
        self.logger.info('Number of Outputs = {}'.format(out_size))
        self.first = False

    diff = y_true - y_pred
    # 'l2' is the intended name; '12' (a historical typo) is kept for
    # backward compatibility with existing configs.
    if loss_type in ('sse', 'l2', '12'):
        return tf.nn.l2_loss(diff)
    if loss_type == 'mse':
        return tf.keras.losses.MSE(y_true, y_pred)
    if loss_type == 'smooth':
        # NOTE(review): smooth-L1 normally switches on |diff| < 1; this
        # compares the signed diff. Kept as-is — confirm intent.
        small = tf.cast(diff < 1, tf.float32)
        large = 1. - small
        return L1L2(tf.multiply(diff, large), tf.multiply(diff, small))
    if loss_type in ('sparse', 'l1'):
        return l1(diff)
    if loss_type == 'softmax':
        # Both ``labels`` and ``logits`` are required arguments.
        _loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
                                                        logits=y_pred)
        return tf.reduce_mean(_loss)
    # Accepted by self.type but not implemented here.
    return None
Example #3
0
class Darknet(object):
    """Parse a Darknet config into ``meta`` and a list of layer ops.

    The weights source (``self.src_bin``) and config source
    (``self.src_cfg``) are resolved from ``flags`` by ``get_weight_src``.
    """

    def __init__(self, flags):
        self.io = SharedFlagIO(subprogram=True)
        self.get_weight_src(flags)
        self.modify = False
        self.io.logger.info('Parsing {}'.format(self.src_cfg))
        src_parsed = self.create_ops()
        self.meta, self.layers = src_parsed
        # uncomment for v1 behavior
        # self.load_weights()

    def get_weight_src(self, flags):
        """Resolve ``self.src_bin`` / ``self.src_cfg`` from ``flags.load``.

        ``flags.load`` may be ``''`` (treated as ``0``), an int, or a path
        to a weights/checkpoint file.

        * int: the config is ``flags.model``. The binary is
          ``flags.binary + flags.model + WGT_EXT`` (made absolute) when the
          int is 0 and that file exists; otherwise it is ``None``.
        * path: the binary is that path. The config is
          ``<flags.config>/<model_name(path)><CFG_EXT>`` if that file exists,
          else ``flags.model``. ``flags.load`` is then reset to ``0``.
        """
        # Plain string concatenation: flags.binary is expected to carry its
        # own trailing separator.
        self.src_bin = flags.model + WGT_EXT
        self.src_bin = flags.binary + self.src_bin
        self.src_bin = os.path.abspath(self.src_bin)
        exist = os.path.isfile(self.src_bin)

        if flags.load == str():
            flags.load = int()
        # Exact type check on purpose: bool values are not treated as ints.
        if type(flags.load) is int:
            self.src_cfg = flags.model
            if flags.load:
                self.src_bin = None
            elif not exist:
                self.src_bin = None
        else:
            self.src_bin = flags.load
            name = self.model_name(flags.load)
            cfg_path = os.path.join(flags.config, name + CFG_EXT)
            if not os.path.isfile(cfg_path):
                # ``warn`` is a deprecated alias of ``warning``.
                self.io.logger.warning(
                    f'{cfg_path} not found, use {flags.model} instead')
                cfg_path = flags.model
            self.src_cfg = cfg_path
            flags.load = int()

    @staticmethod
    def model_name(file_path):
        """Return the model name encoded in a weights or checkpoint path.

        ``'x/yolo.weights'`` -> ``'yolo'``; ``'x/yolo-100'`` and
        ``'x/yolo-100.meta'`` -> ``'yolo'`` (the ``-<step>`` suffix is
        dropped). Any other extension returns ``None``.

        Raises:
            ValueError: For an extension-less or ``.meta`` path whose last
                ``-`` separated part is not an integer step.
        """
        file_name = os.path.basename(file_path)
        ext = str()
        if '.' in file_name:  # exclude extension
            file_name = file_name.split('.')
            ext = file_name[-1]
            file_name = '.'.join(file_name[:-1])
        if ext == str() or ext == 'meta':  # ckpt file
            file_name = file_name.split('-')
            # Validate that the suffix is an integer step (raises otherwise).
            int(file_name[-1])
            return '-'.join(file_name[:-1])
        if ext == 'weights':
            return file_name
        return None

    def create_ops(self):
        """Return ``(meta, layers)`` parsed from ``self.src_cfg``.

        The first parsed entry is ``meta``. Each later entry is unpacked into
        ``create_darkop`` to build a layer. A ``TypeError`` during parsing is
        recorded in ``self.io.flags.error``, logged, published via
        ``self.io.send_flags()`` and re-raised.
        """
        cfg_layers = ConfigParser.create(self.src_cfg)

        meta = dict()
        layers = list()
        try:
            for i, info in enumerate(cfg_layers):
                if i == 0:
                    meta = info
                    continue
                layers.append(create_darkop(*info))
        except TypeError as e:
            self.io.flags.error = str(e)
            self.io.logger.error(str(e))
            self.io.send_flags()
            raise
        return meta, layers

    def load_weights(self):
        """Load ``self.src_bin`` into every layer using a ``Loader``."""
        self.io.logger.info(f'Loading {self.src_bin} ...')
        start = time.time()

        wgts_loader = Loader.create(self.src_bin, self.layers)
        for layer in self.layers:
            layer.load(wgts_loader)

        stop = time.time()
        self.io.logger.info('Finished in {}s'.format(stop - start))
Example #4
0
class TFNet:
    """Graph-mode (TF1 compat) network built from a parsed ``Darknet``."""

    # Interface Methods:
    def __init__(self, flags, darknet=None):
        """Build the forward graph, meta ops and session.

        Args:
            flags: Fallback flags, used only when the shared flag IO
                channel offers none.
            darknet: Pre-parsed ``Darknet``. Built from the flags when
                ``None``.
        """
        self.io = SharedFlagIO(subprogram=True)
        # disable eager mode for TF1-dependent code
        tf.compat.v1.disable_eager_execution()
        # Read the shared flags once. Every later use goes through
        # self.flags so all decisions use the same flags.
        shared_flags = self.io.read_flags()
        self.flags = shared_flags if shared_flags is not None else flags
        self.io_flags = self.io.io_flags
        self.logger = get_logger()
        darknet = Darknet(self.flags) if darknet is None else darknet
        self.darknet = darknet
        self.ntrain = len(darknet.layers)
        self.num_layer = len(darknet.layers)
        self.framework = Framework.create(darknet.meta, self.flags)
        self.annotation_data = self.framework.parse()
        self.meta = darknet.meta
        self.graph = tf.Graph()
        device_name = self.flags.gpu_name if self.flags.gpu > 0.0 else None
        start = time.time()
        with tf.device(device_name):
            with self.graph.as_default():
                self.build_forward()
                self.setup_meta_ops()
        self.logger.info('Finished in {}s'.format(time.time() - start))

    def raise_error(self, error: Exception, traceback=None):
        """Record ``error`` in the flags, log and publish it, then raise it.

        If ``traceback`` is given, its ``.message`` is appended to the stored
        error text (presumably a TF error object; TODO confirm).
        """
        form = "{}\nOriginal Tensorflow Error: {}"
        try:
            raise error
        except Exception as e:
            if traceback:
                oe = traceback.message
                self.flags.error = form.format(str(e), oe)
            else:
                self.flags.error = str(e)
            self.logger.error(str(e))
            self.io.send_flags()
            raise

    def build_forward(self):
        """Build the forward pass from ``self.darknet.layers``.

        Sets ``self.inp`` (input tensor), ``self.feed`` (extra placeholders
        handed to ``op_create``), ``self.top`` (the final op) and
        ``self.out`` (its output, named ``'output'``).
        """
        # Placeholders
        inp_size = self.meta['inp_size']
        self.inp = tf.keras.layers.Input(dtype=tf.float32,
                                         shape=tuple(inp_size),
                                         name='input')
        self.feed = dict()  # other placeholders

        # Build the forward pass
        state = identity(self.inp)
        # Passed to op_create; its meaning is defined there.
        roof = self.num_layer - self.ntrain
        self.logger.info(LINE)
        self.logger.info(HEADER)
        self.logger.info(LINE)
        for i, layer in enumerate(self.darknet.layers):
            state = op_create(layer, state, i, roof, self.feed)
            mess = state.verbalise()
            self.logger.info(mess if mess else LINE)

        self.top = state
        self.out = tf.identity(state.out, name='output')

    def setup_meta_ops(self):
        """Configure devices, the train op, summaries, the session and the saver.

        Restores from a checkpoint when ``self.flags.load != 0``. A
        ``NotFoundError`` during saver setup or restore is recorded in the
        flags, published, and re-raised.
        """
        tf.config.set_soft_device_placement(False)
        tf.debugging.set_log_device_placement(False)
        utility = min(self.flags.gpu, 1.)
        if utility > 0.0:
            tf.config.set_soft_device_placement(True)
        else:
            self.logger.info('Running entirely on CPU')

        if self.flags.train:
            self.build_train_op()

        if self.flags.summary:
            self.summary_op = tf.compat.v1.summary.merge_all()
            self.writer = tf.compat.v1.summary.FileWriter(
                self.flags.summary + self.flags.project_name)

        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())

        if not self.ntrain:
            return

        try:
            self.saver = tf.compat.v1.train.Saver(
                tf.compat.v1.global_variables())

            if self.flags.load != 0:
                self.load_from_ckpt()

        except tf.errors.NotFoundError as e:
            self.flags.error = str(e.message)
            # TFNet has no send_flags of its own; publish via the IO channel.
            self.io.send_flags()
            raise

        if self.flags.summary:
            self.writer.add_graph(self.sess.graph)

    def load_from_ckpt(self):
        """Restore ``<backup>/<meta name>-<flags.load>`` into the session.

        A negative ``flags.load`` is first replaced by the step number parsed
        from the last line of ``<backup>/checkpoint``. If the saver raises
        ``ValueError``, restoring falls back to ``load_old_graph``.
        """
        if self.flags.load < 0:  # load latest ckpt
            with open(os.path.join(self.flags.backup, 'checkpoint'), 'r') as f:
                # Expected form: 'all_model_checkpoint_paths: "<name>-<step>"'
                last = f.readlines()[-1].strip()
                load_point = last.split(' ')[1]
                load_point = load_point.split('"')[1]
                load_point = load_point.split('-')[-1]
                self.flags.load = int(load_point)

        load_point = os.path.join(self.flags.backup, self.meta['name'])
        load_point = '{}-{}'.format(load_point, self.flags.load)
        self.logger.info('Loading from {}'.format(load_point))
        try:
            self.saver.restore(self.sess, load_point)
        except ValueError:
            self.load_old_graph(load_point)

    def load_old_graph(self, ckpt):
        """Assign every global variable from ``ckpt`` via a ``Loader``.

        Calls ``raise_error`` with ``VariableIsNone`` when the loader has no
        value for a variable.
        """
        ckpt_loader = Loader.create(ckpt)
        self.logger.info(old_graph_msg.format(ckpt))

        for var in tf.compat.v1.global_variables():
            name = var.name.split(':')[0]
            val = ckpt_loader(name, var.get_shape())
            if val is None:
                self.raise_error(VariableIsNone(var))
            plh = tf.compat.v1.placeholder(tf.float32, val.shape)
            op = tf.compat.v1.assign(var, plh)
            self.sess.run(op, {plh: val})

    def build_train_op(self):
        """Build the loss, optimizer (with cyclic LR) and ``self.train_op``."""
        loss_out = self.framework.loss(self.out)
        # Some frameworks store the loss tensor on ``framework.loss`` and
        # return None; others return it. Support both conventions.
        loss = self.framework.loss if loss_out is None else loss_out
        self.logger.info('Building {} train op'.format(self.meta['model']))
        self.global_step = tf.Variable(0, trainable=False)
        # setup kwargs for trainer
        kwargs = dict()
        if self.flags.trainer in ['momentum', 'rmsprop', 'nesterov']:
            kwargs.update({'momentum': self.flags.momentum})
        if self.flags.trainer == 'nesterov':
            kwargs.update({self.flags.trainer: True})
        if self.flags.trainer == 'AMSGrad':
            kwargs.update({self.flags.trainer.lower(): True})
        if self.flags.clip:
            kwargs.update({'clipnorm': self.flags.clip_norm})

        # setup cyclic_learning_rate args
        ssc = self.flags.step_size_coefficient
        step_size = int(ssc * (len(self.annotation_data) // self.flags.batch))
        clr_kwargs = {
            'global_step': self.global_step,
            'mode': self.flags.clr_mode,
            'step_size': step_size,
            'learning_rate': self.flags.lr,
            'max_lr': self.flags.max_lr,
            'name': 'learning-rate'
        }

        # setup trainer
        self.optimizer = TRAINERS[self.flags.trainer](clr(**clr_kwargs),
                                                      **kwargs)

        # Gradients for all globals except global_step. Exclude it by
        # identity rather than assuming it is the last global variable.
        train_vars = [v for v in tf.compat.v1.global_variables()
                      if v is not self.global_step]
        grads = self.optimizer.get_gradients(loss, train_vars)
        self.train_op = self.optimizer.apply_gradients(zip(grads, train_vars))