Example #1
# Imports assumed by this snippet; DTN, Error, leaf_l1_loss and plotResults
# are project-local definitions.
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


class Model:
    def __init__(self, config):
        self.config = config
        # model
        self.dtn = DTN(32, config)
        # model optimizer
        self.dtn_op = tf.compat.v1.train.AdamOptimizer(config.LEARNING_RATE,
                                                       beta1=0.5)
        # model losses
        self.depth_map_loss = Error()
        self.class_loss = Error()
        self.route_loss = Error()
        self.uniq_loss = Error()
        # model saving setting
        self.last_epoch = 0
        self.checkpoint_manager = None  # created in compile()

    def compile(self):
        checkpoint_dir = self.config.LOG_DIR
        checkpoint = tf.train.Checkpoint(dtn=self.dtn,
                                         dtn_optimizer=self.dtn_op)
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                             checkpoint_dir,
                                                             max_to_keep=10000)
        last_checkpoint = self.checkpoint_manager.latest_checkpoint
        checkpoint.restore(last_checkpoint)
        if last_checkpoint:
            self.last_epoch = int(last_checkpoint.split('-')[-1])
            print("Restored from {}".format(last_checkpoint))
        else:
            print("Initializing from scratch.")

    def train(self, train, val=None):
        config = self.config
        step_per_epoch = config.STEPS_PER_EPOCH
        step_per_epoch_val = config.STEPS_PER_EPOCH_VAL
        epochs = config.MAX_EPOCH

        # data stream
        it = train.feed
        global_step = self.last_epoch * step_per_epoch
        if val is not None:
            it_val = val.feed
        for epoch in range(self.last_epoch, epochs):
            start = time.time()
            # re-create the optimizer at the start of each epoch (resets Adam state)
            self.dtn_op = tf.compat.v1.train.AdamOptimizer(
                config.LEARNING_RATE, beta1=0.5)
            ''' train phase'''
            for step in range(step_per_epoch):
                depth_map_loss, class_loss, route_loss, uniq_loss, spoof_counts, eigenvalue, trace, _to_plot =\
                    self.train_one_step(next(it), global_step, True)
                # display loss
                global_step += 1
                print(
                    'Epoch {:d}-{:d}/{:d}: Map:{:.3g}, Cls:{:.3g}, Route:{:.3g}({:3.3f}, {:3.3f}), Uniq:{:.3g}, '
                    'Counts:[{:d},{:d},{:d},{:d},{:d},{:d},{:d},{:d}]     '.
                    format(
                        epoch + 1,
                        step + 1,
                        step_per_epoch,
                        self.depth_map_loss(depth_map_loss),
                        self.class_loss(class_loss),
                        self.route_loss(route_loss),
                        eigenvalue,
                        trace,
                        self.uniq_loss(uniq_loss),
                        spoof_counts[0],
                        spoof_counts[1],
                        spoof_counts[2],
                        spoof_counts[3],
                        spoof_counts[4],
                        spoof_counts[5],
                        spoof_counts[6],
                        spoof_counts[7],
                    ),
                    end='\r')
                # plot the figure
                if (step + 1) % 5 == 0:
                    fname = self.config.LOG_DIR + '/epoch-' + str(
                        epoch + 1) + '-train-' + str(step + 1) + '.png'
                    plotResults(fname, _to_plot)

            # save the model after every epoch
            self.checkpoint_manager.save(checkpoint_number=epoch + 1)
            print('\n', end='\r')
            ''' eval phase'''
            if val is not None:
                for step in range(step_per_epoch_val):
                    depth_map_loss, class_loss, route_loss, uniq_loss, spoof_counts, eigenvalue, trace, _to_plot =\
                        self.train_one_step(next(it_val), global_step, False)
                    # display validation losses
                    print(
                        '    Val-{:d}/{:d}: Map:{:.3g}, Cls:{:.3g}, Route:{:.3g}({:3.3f}, {:3.3f}), Uniq:{:.3g}, '
                        'Counts:[{:d},{:d},{:d},{:d},{:d},{:d},{:d},{:d}]     '
                        .format(
                            step + 1,
                            step_per_epoch_val,
                            self.depth_map_loss(depth_map_loss, val=1),
                            self.class_loss(class_loss, val=1),
                            self.route_loss(route_loss, val=1),
                            eigenvalue,
                            trace,
                            self.uniq_loss(uniq_loss, val=1),
                            spoof_counts[0],
                            spoof_counts[1],
                            spoof_counts[2],
                            spoof_counts[3],
                            spoof_counts[4],
                            spoof_counts[5],
                            spoof_counts[6],
                            spoof_counts[7],
                        ),
                        end='\r')
                    # plot the figure
                    if (step + 1) % 10 == 0:
                        fname = self.config.LOG_DIR + '/epoch-' + str(
                            epoch + 1) + '-val-' + str(step + 1) + '.png'
                        plotResults(fname, _to_plot)

                self.depth_map_loss.reset()
                self.class_loss.reset()
                self.route_loss.reset()
                self.uniq_loss.reset()

            # time of one epoch
            print('\n    Time taken for epoch {} is {:.3g} sec'.format(
                epoch + 1,
                time.time() - start))

        return 0

    def train_one_step(self, data_batch, step, training):

        dtn = self.dtn
        dtn_op = self.dtn_op
        image, dmap, labels = data_batch
        with tf.GradientTape() as tape:
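            # note: the training flag passed to the network is hard-coded True,
            # so the forward pass runs in training mode even during validation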

            dmap_pred, cls_pred, route_value, leaf_node_mask, tru_loss, mu_update, eigenvalue, trace =\
                dtn(image, labels, True)

            # supervised feature loss
            depth_map_loss = leaf_l1_loss(dmap_pred,
                                          tf.image.resize(dmap, [32, 32]),
                                          leaf_node_mask)
            class_loss = leaf_l1_loss(cls_pred, labels, leaf_node_mask)
            supervised_loss = depth_map_loss + 0.001 * class_loss

            # unsupervised tree loss
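            # per-node weights down-weight deeper tree levels: root 1.0,
            # the two level-2 nodes 0.5, the four level-3 nodes 0.25 each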
            route_loss = tf.reduce_mean(
                tf.stack(tru_loss[0], axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            uniq_loss = tf.reduce_mean(
                tf.stack(tru_loss[1], axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            eigenvalue = np.mean(
                np.stack(eigenvalue, axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            trace = np.mean(
                np.stack(trace, axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            unsupervised_loss = 1 * route_loss + 0.001 * uniq_loss

            # total loss
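            # the unsupervised tree loss is gated behind a step threshold; at
            # this value it stays disabled for any realistic run length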
            if step > 10000000:
                loss = supervised_loss + unsupervised_loss
            else:
                loss = supervised_loss

        if training:
            # back-propagate
            gradients = tape.gradient(loss, dtn.variables)
            dtn_op.apply_gradients(zip(gradients, dtn.variables))

            # Update mean values for each tree node
            mu_update_rate = self.config.TRU_PARAMETERS["mu_update_rate"]
            mus = [
                dtn.tru0.project.mu, dtn.tru1.project.mu, dtn.tru2.project.mu,
                dtn.tru3.project.mu, dtn.tru4.project.mu, dtn.tru5.project.mu,
                dtn.tru6.project.mu
            ]
            for mu, mu_of_visit in zip(mus, mu_update):
                if step == 0:
                    update_mu = mu_of_visit
                else:
                    # exponential moving average of the node mean
                    update_mu = mu_of_visit * mu_update_rate + mu * (
                        1 - mu_update_rate)
                K.set_value(mu, update_mu)

        # leaf counts
        spoof_counts = []
        for leaf in leaf_node_mask:
            spoof_count = tf.reduce_sum(leaf[:, 0]).numpy()
            spoof_counts.append(int(spoof_count))

        _to_plot = [
            image[:, :, :, 0:3], image[:, :, :, 3:], dmap, dmap_pred[0],
            dmap_pred[1], dmap_pred[2], dmap_pred[3], dmap_pred[4],
            dmap_pred[5], dmap_pred[6], dmap_pred[7]
        ]

        return depth_map_loss, class_loss, route_loss, uniq_loss, spoof_counts, eigenvalue, trace, _to_plot
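
A minimal usage sketch, assuming project-local Config and Dataset helpers with
the attributes referenced above (LOG_DIR, STEPS_PER_EPOCH, a .feed generator);
the driver names below are illustrative, not part of the snippet:

config = Config()                      # assumed to carry the settings above
train_data = Dataset(config, 'train')  # assumed to expose a generator at .feed
val_data = Dataset(config, 'val')

model = Model(config)
model.compile()                        # restores the latest checkpoint, if any
model.train(train_data, val=val_data)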
Example #2
# Imports assumed by this snippet; Config, Dataset, Error and _step are
# project-local definitions; this example uses the TF1 graph/session API.
import time

import cv2
import tensorflow as tf


def main(argv=None):
    # Configurations
    config = Config(gpu='1',
                    root_dir='./data/train/',
                    root_dir_val='./data/val/',
                    mode='training')

    # Create data feeding pipeline.
    dataset_train = Dataset(config, 'train')
    dataset_val = Dataset(config, 'val')

    # Train Graph
    losses, g_op, d_op, fig = _step(config, dataset_train, training_nn=True)
    losses_val, _, _, fig_val = _step(config, dataset_val, training_nn=False)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver(max_to_keep=50)
    with tf.Session(config=config.GPU_CONFIG) as sess:
        # Restore the model
        ckpt = tf.train.get_checkpoint_state(config.LOG_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            last_epoch = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            print('**********************************************************')
            print('Restore from Epoch ' + str(last_epoch))
            print('**********************************************************')
        else:
            init = tf.initializers.global_variables()
            last_epoch = 0
            sess.run(init)
            print('**********************************************************')
            print('Train from scratch.')
            print('**********************************************************')

        avg_loss = Error()
        print_list = {}
        for epoch in range(int(last_epoch), config.MAX_EPOCH):
            start = time.time()
            # Train one epoch
            for step in range(config.STEPS_PER_EPOCH):
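                # run the discriminator update only on every G_D_RATIO-th step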
                if step % config.G_D_RATIO == 0:
                    _losses = sess.run(losses + [g_op, d_op, fig])
                else:
                    _losses = sess.run(losses + [g_op, fig])

                # Logging
                print_list['g_loss'] = _losses[0]
                print_list['d_loss'] = _losses[1]
                print_list['a_loss'] = _losses[2]
                display_list = ['Epoch '+str(epoch+1)+'-'+str(step+1)+'/'+ str(config.STEPS_PER_EPOCH)+':'] +\
                               [avg_loss(x) for x in print_list.items()]
                print(*display_list + ['          '], end='\r')
                # Visualization
                if step % config.LOG_FR_TRAIN == 0:
                    fname = config.LOG_DIR + '/Epoch-' + str(
                        epoch + 1) + '-' + str(step + 1) + '.png'
                    cv2.imwrite(fname, _losses[-1])

            # Model saving
            saver.save(sess, config.LOG_DIR + '/ckpt', global_step=epoch + 1)
            print('\n', end='\r')

            # Validate one epoch
            for step in range(config.STEPS_PER_EPOCH_VAL):
                _losses = sess.run(losses_val + [fig_val])

                # Logging
                print_list['g_loss'] = _losses[0]
                print_list['d_loss'] = _losses[1]
                print_list['a_loss'] = _losses[2]
                display_list = ['Epoch '+str(epoch+1)+'-Val-'+str(step+1)+'/'+ str(config.STEPS_PER_EPOCH_VAL)+':'] +\
                               [avg_loss(x, val=1) for x in print_list.items()]
                print(*display_list + ['          '], end='\r')
                # Visualization
                if step % config.LOG_FR_TEST == 0:
                    fname = config.LOG_DIR + '/Epoch-' + str(
                        epoch + 1) + '-Val-' + str(step + 1) + '.png'
                    cv2.imwrite(fname, _losses[-1])

            # time of one epoch
            print('\n    Time taken for epoch {} is {:.3g} sec'.format(
                epoch + 1,
                time.time() - start))
            avg_loss.reset()
Example #3

# Imports assumed by this snippet; DTN, Error, leaf_l1_loss and Dataset are
# project-local definitions.
import logging
import time

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


class Model:
    def __init__(self, config):
        self.config = config
        # model
        self.dtn = DTN(32, config)
        # model optimizer
        self.dtn_op = tf.compat.v1.train.AdamOptimizer(config.LEARNING_RATE,
                                                       beta1=0.5)
        # model losses
        self.depth_map_loss = Error()
        self.class_loss = Error()
        self.route_loss = Error()
        self.uniq_loss = Error()
        # model saving setting
        self.last_epoch = 0
        self.checkpoint_manager = None  # created in compile()

    def compile(self):
        checkpoint_dir = self.config.LOG_DIR
        checkpoint = tf.train.Checkpoint(dtn=self.dtn,
                                         dtn_optimizer=self.dtn_op)
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint,
                                                             checkpoint_dir,
                                                             max_to_keep=30)
        last_checkpoint = self.checkpoint_manager.latest_checkpoint
        checkpoint.restore(last_checkpoint)
        if last_checkpoint:
            self.last_epoch = int(last_checkpoint.split('-')[-1])
            logging.info("Restored from {}".format(last_checkpoint))
        else:
            logging.info("Initializing from scratch.")

    def test(self):
        dirs = self.config.DATA_DIR_TEST
        dataset = Dataset(self.config, 'test', dirs)
        for image, dmap, labels in dataset.feed:
            dmap_pred, cls_pred, route_value, leaf_node_mask = self.dtn(
                image, labels, False)
            # leaf counts
            spoof_counts = []
            for leaf in leaf_node_mask:
                spoof_count = tf.reduce_sum(leaf[:, 0]).numpy()
                spoof_counts.append(int(spoof_count))
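            # average the classification scores over all leaf nodes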
            cls_total = tf.math.add_n(cls_pred) / len(cls_pred)
            for index, label in enumerate(tf.unstack(labels)):
                cls = cls_total[index].numpy()
                if cls < 0.8 or cls > 1.2:
                    logging.info("label: {}, cls: {}".format(
                        label.numpy(), cls))
            # logging.info("spoof_counts:{}".format(spoof_counts))

    def train(self):
        dirs = self.config.DATA_DIR
        live_dir = self.config.DATA_DIR_LIVE[0]
        while True:
            for holdout_dir in dirs:
                # hold one spoof directory out; train on the rest plus live data
                train_dirs = [d for d in dirs if d != holdout_dir]
                train_dirs.append(live_dir)
                train = Dataset(self.config, 'train', train_dirs, holdout_dir)
                epochs = int((self.config.MAX_EPOCH % len(dirs)) /
                             len(dirs)) + self.config.MAX_EPOCH
                self._train(train, self.last_epoch + epochs)
                self.last_epoch += epochs

    def _train(self, train, epochs):
        config = self.config
        step_per_epoch = config.STEPS_PER_EPOCH
        step_per_epoch_val = config.STEPS_PER_EPOCH_VAL
        logging.info(
            "Training for {} epochs with {} steps per, and {} steps per validation"
            .format(epochs - self.last_epoch, step_per_epoch,
                    step_per_epoch_val))

        # data stream
        global_step = self.last_epoch * step_per_epoch
        for epoch in range(self.last_epoch, epochs):
            start = time.time()
            # re-create the optimizer at the start of each epoch (resets Adam state)
            self.dtn_op = tf.compat.v1.train.AdamOptimizer(
                config.LEARNING_RATE, beta1=0.5)
            ''' train phase'''
            for step in range(step_per_epoch):
                depth_map_loss, class_loss, route_loss, uniq_loss, spoof_counts, eigenvalue, trace, _to_plot = \
                    self.train_one_step(next(train.feed), global_step, True)
                # display loss
                global_step += 1
                logging.info(
                    'Epoch {:d}-{:d}/{:d}: Map:{:.3g}, Cls:{:.3g}, Route:{:.3g}({:3.3f}, {:3.3f}), Uniq:{:.3g}, '
                    'Counts:[{:d},{:d},{:d},{:d},{:d},{:d},{:d},{:d}]     '.
                    format(epoch + 1, step + 1, step_per_epoch,
                           self.depth_map_loss(depth_map_loss),
                           self.class_loss(class_loss),
                           self.route_loss(route_loss), eigenvalue, trace,
                           self.uniq_loss(uniq_loss), spoof_counts[0],
                           spoof_counts[1], spoof_counts[2], spoof_counts[3],
                           spoof_counts[4], spoof_counts[5], spoof_counts[6],
                           spoof_counts[7]))
                # plot the figure
                # if (step + 1) % 400 == 0:
                #     fname = self.config.LOG_DIR + '/epoch-' + str(epoch + 1) + '-train-' + str(step + 1) + '.png'
                #     plotResults(fname, _to_plot)

            # save the model
            self.checkpoint_manager.save(checkpoint_number=epoch + 1)
            ''' eval phase'''
            if train.feed_val is not None:
                for step in range(step_per_epoch_val):
                    depth_map_loss, class_loss, route_loss, uniq_loss, spoof_counts, eigenvalue, trace, _to_plot = \
                        self.train_one_step(next(train.feed_val), global_step, False)
                    # display validation losses
                    logging.info(
                        '    Val-{:d}/{:d}: Map:{:.3g}, Cls:{:.3g}, Route:{:.3g}({:3.3f}, {:3.3f}), Uniq:{:.3g}, '
                        'Counts:[{:d},{:d},{:d},{:d},{:d},{:d},{:d},{:d}]     '
                        .format(step + 1, step_per_epoch_val,
                                self.depth_map_loss(depth_map_loss, val=1),
                                self.class_loss(class_loss, val=1),
                                self.route_loss(route_loss,
                                                val=1), eigenvalue, trace,
                                self.uniq_loss(uniq_loss, val=1),
                                spoof_counts[0], spoof_counts[1],
                                spoof_counts[2], spoof_counts[3],
                                spoof_counts[4], spoof_counts[5],
                                spoof_counts[6], spoof_counts[7]))
                    # plot the figure
                    # if (step + 1) % 100 == 0:
                    #     fname = self.config.LOG_DIR + '/epoch-' + str(epoch + 1) + '-val-' + str(step+1) + '.png'
                    #     plotResults(fname, _to_plot)
                self.depth_map_loss.reset()
                self.class_loss.reset()
                self.route_loss.reset()
                self.uniq_loss.reset()

            # time of one epoch
            logging.info('Time taken for epoch {} is {:.3g} sec'.format(
                epoch + 1,
                time.time() - start))
        return 0

    def train_one_step(self, data_batch, step, training):
        dtn = self.dtn
        dtn_op = self.dtn_op
        image, dmap, labels = data_batch
        with tf.GradientTape() as tape:
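            # note: the network's training flag is hard-coded True below, even
            # when this method is called with training=False for validation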
            dmap_pred, cls_pred, route_value, leaf_node_mask, tru_loss, mu_update, eigenvalue, trace = \
                dtn(image, labels, True)

            # supervised feature loss
            depth_map_loss = leaf_l1_loss(dmap_pred,
                                          tf.image.resize(dmap, [32, 32]),
                                          leaf_node_mask)
            class_loss = leaf_l1_loss(cls_pred, labels, leaf_node_mask)
            supervised_loss = depth_map_loss + 0.001 * class_loss

            # unsupervised tree loss
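            # per-level weights again: root 1.0, level-2 nodes 0.5,
            # level-3 nodes 0.25 each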
            route_loss = tf.reduce_mean(
                tf.stack(tru_loss[0], axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            uniq_loss = tf.reduce_mean(
                tf.stack(tru_loss[1], axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            eigenvalue = np.mean(
                np.stack(eigenvalue, axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            trace = np.mean(
                np.stack(trace, axis=0) *
                [1., 0.5, 0.5, 0.25, 0.25, 0.25, 0.25])
            unsupervised_loss = 2 * route_loss + 0.001 * uniq_loss

            # total loss
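            # phase in the unsupervised tree loss after 10k warm-up steps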
            if step > 10000:
                loss = supervised_loss + unsupervised_loss
            else:
                loss = supervised_loss

        if training:
            # back-propagate
            gradients = tape.gradient(loss, dtn.variables)
            dtn_op.apply_gradients(zip(gradients, dtn.variables))

            # Update mean values for each tree node
            mu_update_rate = self.config.TRU_PARAMETERS["mu_update_rate"]
            mus = [
                dtn.tru0.project.mu, dtn.tru1.project.mu, dtn.tru2.project.mu,
                dtn.tru3.project.mu, dtn.tru4.project.mu, dtn.tru5.project.mu,
                dtn.tru6.project.mu
            ]
            for mu, mu_of_visit in zip(mus, mu_update):
                if step == 0:
                    update_mu = mu_of_visit
                else:
                    # exponential moving average of the node mean
                    update_mu = mu_of_visit * mu_update_rate + mu * (
                        1 - mu_update_rate)
                K.set_value(mu, update_mu)

        # leaf counts
        spoof_counts = []
        for leaf in leaf_node_mask:
            spoof_count = tf.reduce_sum(leaf[:, 0]).numpy()
            spoof_counts.append(int(spoof_count))

        _to_plot = [image, dmap, dmap_pred[0]]

        return depth_map_loss, class_loss, route_loss, uniq_loss, spoof_counts, eigenvalue, trace, _to_plot
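
A corresponding driver sketch for this variant, again with an assumed Config;
note that train() loops indefinitely over the data directories:

config = Config()   # assumed to define DATA_DIR, DATA_DIR_LIVE, LOG_DIR, ...
model = Model(config)
model.compile()     # restore the latest checkpoint if one exists
model.train()       # rotates over config.DATA_DIR until interrupted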