Example #1
    def __init__(self):
        img_width_range = cfg.img_width_range
        word_len = cfg.word_len
        self.batch_size = cfg.batch_size
        self.visualize = cfg.visualize
        gpu_device_id = '/gpu:' + str(cfg.gpu_id)
        if cfg.gpu_id == -1:
            gpu_device_id = '/cpu:0'
            print("Using CPU model!")
        with tf.device(gpu_device_id):
            self.img_data = tf.placeholder(tf.float32,
                                           shape=(None, 1, 32, None),
                                           name='img_data')
            self.zero_paddings = tf.placeholder(tf.float32,
                                                shape=(None, None, 512),
                                                name='zero_paddings')

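        # Each bucket pairs an encoder length (image width // 4, matching the
        # CNN's 4x downsampling along the width axis) with a decoder length of
        # word_len + 2, presumably leaving room for the GO and EOS tokens.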
        self.bucket_specs = [(int(math.floor(64 / 4)), int(word_len + 2)),
                             (int(math.floor(108 / 4)), int(word_len + 2)),
                             (int(math.floor(140 / 4)), int(word_len + 2)),
                             (int(math.floor(256 / 4)), int(word_len + 2)),
                             (int(math.floor(img_width_range[1] / 4)),
                              int(word_len + 2))]
        buckets = self.buckets = self.bucket_specs

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        with tf.device(gpu_device_id):
            for i in range(int(buckets[-1][0] + 1)):
                self.encoder_masks.append(
                    tf.placeholder(tf.float32,
                                   shape=[None, 1],
                                   name="encoder_mask{0}".format(i)))
            for i in range(buckets[-1][1] + 1):
                self.decoder_inputs.append(
                    tf.placeholder(tf.int32,
                                   shape=[None],
                                   name="decoder{0}".format(i)))
                self.target_weights.append(
                    tf.placeholder(tf.float32,
                                   shape=[None],
                                   name="weight{0}".format(i)))
        self.bucket_min_width, self.bucket_max_width = img_width_range
        self.image_height = cfg.img_height
        self.valid_target_len = cfg.valid_target_len
        self.forward_only = True

        self.bucket_data = {
            i: BucketData()
            for i in range(self.bucket_max_width + 1)
        }

        with tf.device(gpu_device_id):
            cnn_model = CNN(self.img_data, True)  # always inference here; was: not self.forward_only
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(
                axis=1, values=[self.conv_output, self.zero_paddings])

            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=cfg.target_vocab_size,
                buckets=self.buckets,
                target_embedding_size=cfg.target_embedding_size,
                attn_num_layers=cfg.attn_num_layers,
                attn_num_hidden=cfg.attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=cfg.use_gru)
        #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1.0)
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))
        self.saver_all = tf.train.Saver(tf.global_variables())
        self.saver_all.restore(self.sess, cfg.ocr_model_path)
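A minimal driver for this first variant, as a sketch: it assumes the method above belongs to a class named AttentionOCR (the name is an assumption, as are the concrete array shapes) and that cfg, CNN, Seq2SeqModel, and BucketData are importable as in the snippet.

    import numpy as np

    # Hypothetical wrapper class; constructing it builds the graph and
    # restores the checkpoint at cfg.ocr_model_path.
    model = AttentionOCR()

    # One grayscale image, matching the img_data placeholder (None, 1, 32, None).
    img = np.zeros((1, 1, 32, 100), dtype=np.float32)
    pad = np.zeros((1, 0, 512), dtype=np.float32)  # empty zero_paddings slice

    # Run just the CNN encoder to inspect its output shape.
    conv = model.sess.run(model.conv_output,
                          feed_dict={model.img_data: img,
                                     model.zero_paddings: pad})
    print(conv.shape)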
Example #2
    def __init__(self,
                 phase,
                 visualize,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 steps_per_checkpoint,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 clip_gradients,
                 max_gradient_norm,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 use_distance=True,
                 max_image_width=160,
                 max_image_height=60,
                 max_prediction_length=8,
                 channels=1,
                 reg_val=0):

        self.use_distance = use_distance

        # Images are resized to DataGen.IMAGE_HEIGHT with aspect ratio
        # preserved, so the buckets must fit the resized width, not the
        # original one.
        max_resized_width = 1. * max_image_width / max_image_height * DataGen.IMAGE_HEIGHT

        self.max_original_width = max_image_width
        self.max_width = int(math.ceil(max_resized_width))

        self.encoder_size = int(math.ceil(1. * self.max_width / 4))
        self.decoder_size = max_prediction_length + 2
        self.buckets = [(self.encoder_size, self.decoder_size)]
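        # A single bucket: ceil(max_width / 4) encoder steps (the CNN
        # downsamples the width by 4) and max_prediction_length + 2 decoder
        # steps, the extra two presumably for the GO and EOS tokens.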

        if gpu_id >= 0:
            device_id = '/gpu:' + str(gpu_id)
        else:
            device_id = '/cpu:0'
        self.device_id = device_id

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        if phase == 'test':
            batch_size = 1

        logging.info('phase: %s', phase)
        logging.info('model_dir: %s', model_dir)
        logging.info('load_model: %s', load_model)
        logging.info('output_dir: %s', output_dir)
        logging.info('steps_per_checkpoint: %d', steps_per_checkpoint)
        logging.info('batch_size: %d', batch_size)
        logging.info('learning_rate: %f', initial_learning_rate)
        logging.info('reg_val: %d', reg_val)
        logging.info('max_gradient_norm: %f', max_gradient_norm)
        logging.info('clip_gradients: %s', clip_gradients)
        logging.info('max_image_width: %d', max_image_width)
        logging.info('max_prediction_length: %d', max_prediction_length)
        logging.info('channels: %d', channels)
        logging.info('target_embedding_size: %f', target_embedding_size)
        logging.info('attn_num_hidden: %d', attn_num_hidden)
        logging.info('attn_num_layers: %d', attn_num_layers)
        logging.info('visualize: %s', visualize)

        if use_gru:
            logging.info('using GRU in the decoder.')

        self.reg_val = reg_val
        self.sess = session
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.batch_size = batch_size
        self.global_step = tf.Variable(0, trainable=False)
        self.phase = phase
        self.visualize = visualize
        self.learning_rate = initial_learning_rate
        self.clip_gradients = clip_gradients
        self.channels = channels

        if phase == 'train':
            self.forward_only = False
        else:
            self.forward_only = True

        with tf.device(device_id):

            self.height = tf.constant(DataGen.IMAGE_HEIGHT, dtype=tf.int32)
            self.height_float = tf.constant(DataGen.IMAGE_HEIGHT,
                                            dtype=tf.float64)

            self.img_pl = tf.placeholder(tf.string,
                                         name='input_image_as_bytes')
            self.img_data = tf.cond(tf.less(tf.rank(self.img_pl), 1),
                                    lambda: tf.expand_dims(self.img_pl, 0),
                                    lambda: self.img_pl)
            self.img_data = tf.map_fn(self._prepare_image,
                                      self.img_data,
                                      dtype=tf.float32)
            num_images = tf.shape(self.img_data)[0]

            # TODO: create a mask depending on the image/batch size
            self.encoder_masks = []
            for i in range(self.encoder_size + 1):
                self.encoder_masks.append(tf.tile([[1.]], [num_images, 1]))

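            # The decoder is fed constants rather than placeholders: token id 1
            # (presumably the GO symbol) at every step, with target weights of
            # 1.0 everywhere except the final step.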
            self.decoder_inputs = []
            self.target_weights = []
            for i in range(self.decoder_size + 1):
                self.decoder_inputs.append(tf.tile([1], [num_images]))
                if i < self.decoder_size:
                    self.target_weights.append(tf.tile([1.], [num_images]))
                else:
                    self.target_weights.append(tf.tile([0.], [num_images]))

            cnn_model = CNN(self.img_data, not self.forward_only)
            self.conv_output = cnn_model.tf_output()
            self.perm_conv_output = tf.transpose(self.conv_output,
                                                 perm=[1, 0, 2])
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=len(DataGen.CHARMAP),
                buckets=self.buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru)

            table = tf.contrib.lookup.MutableHashTable(
                key_dtype=tf.int64,
                value_dtype=tf.string,
                default_value="",
                checkpoint=True,
            )

            insert = table.insert(
                tf.constant(list(range(len(DataGen.CHARMAP))), dtype=tf.int64),
                tf.constant(DataGen.CHARMAP),
            )

            with tf.control_dependencies([insert]):
                num_feed = []
                prb_feed = []

                for line in range(len(self.attention_decoder_model.output)):
                    guess = tf.argmax(
                        self.attention_decoder_model.output[line], axis=1)
                    proba = tf.reduce_max(
                        tf.nn.softmax(self.attention_decoder_model.output[line]),
                        axis=1)
                    num_feed.append(guess)
                    prb_feed.append(proba)

                # Join the per-step argmax ids into a single output string; the
                # foldr keeps only the characters before the first EOS.
                trans_output = tf.transpose(num_feed)
                trans_output = tf.map_fn(
                    lambda m: tf.foldr(
                        lambda a, x: tf.cond(
                            tf.equal(x, DataGen.EOS_ID),
                            lambda: '',
                            lambda: table.lookup(x) + a  # pylint: disable=undefined-variable
                        ),
                        m,
                        initializer=''),
                    trans_output,
                    dtype=tf.string)

                # Calculate the total probability of the output string.
                trans_outprb = tf.transpose(prb_feed)
                trans_outprb = tf.gather(trans_outprb,
                                         tf.range(tf.size(trans_output)))
                trans_outprb = tf.map_fn(
                    lambda m: tf.foldr(
                        lambda a, x: tf.multiply(tf.cast(x, tf.float64), a),
                        m,
                        initializer=tf.cast(1, tf.float64)),
                    trans_outprb,
                    dtype=tf.float64)

                self.prediction = tf.cond(
                    tf.equal(tf.shape(trans_output)[0], 1),
                    lambda: trans_output[0],
                    lambda: trans_output,
                )
                self.probability = tf.cond(
                    tf.equal(tf.shape(trans_outprb)[0], 1),
                    lambda: trans_outprb[0],
                    lambda: trans_outprb,
                )

                self.prediction = tf.identity(self.prediction,
                                              name='prediction')
                self.probability = tf.identity(self.probability,
                                               name='probability')

            if not self.forward_only:  # train
                self.updates = []
                self.summaries_by_bucket = []

                params = tf.trainable_variables()
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate)
                loss_op = self.attention_decoder_model.loss

                if self.reg_val > 0:
                    reg_losses = tf.get_collection(
                        tf.GraphKeys.REGULARIZATION_LOSSES)
                    logging.info('Adding %d regularization losses',
                                 len(reg_losses))
                    logging.debug('REGULARIZATION_LOSSES: %s', reg_losses)
                    loss_op = self.reg_val * tf.reduce_sum(
                        reg_losses) + loss_op

                gradients, params = list(
                    zip(*opt.compute_gradients(loss_op, params)))
                if self.clip_gradients:
                    gradients, _ = tf.clip_by_global_norm(
                        gradients, max_gradient_norm)

                # Summaries for the loss and the total gradient norm.
                summaries = [
                    tf.summary.scalar("loss", loss_op),
                    tf.summary.scalar("total_gradient_norm",
                                      tf.global_norm(gradients))
                ]
                all_summaries = tf.summary.merge(summaries)
                self.summaries_by_bucket.append(all_summaries)

                # Update op: apply gradients after any pending UPDATE_OPS
                # (e.g. batch norm statistics).
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    self.updates.append(
                        opt.apply_gradients(list(zip(gradients, params)),
                                            global_step=self.global_step))

        self.saver_all = tf.train.Saver(tf.global_variables())
        self.checkpoint_path = os.path.join(self.model_dir, "model.ckpt")

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and load_model:
            # pylint: disable=no-member
            logging.info("Reading model parameters from %s",
                         ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.global_variables_initializer())
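A sketch of how this second variant might be driven at inference time. The class name Model, the checkpoint directory, the file name, and the hyperparameter values are all assumptions for illustration; prediction and probability are the named output tensors built above, and img_pl takes the raw encoded image bytes, which the graph's _prepare_image step decodes.

    import tensorflow as tf

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    model = Model(phase='test', visualize=False, output_dir='results',
                  batch_size=1, initial_learning_rate=1.0,
                  steps_per_checkpoint=100, model_dir='checkpoints',
                  target_embedding_size=10, attn_num_hidden=128,
                  attn_num_layers=2, clip_gradients=True,
                  max_gradient_norm=5.0, session=sess, load_model=True,
                  gpu_id=-1, use_gru=False)

    # Feed one image as encoded bytes; phase='test' forces batch_size to 1,
    # so prediction and probability come back as scalars.
    with open('word.png', 'rb') as f:
        image_bytes = f.read()

    text, prob = sess.run([model.prediction, model.probability],
                          feed_dict={model.img_pl: image_bytes})
    print(text, prob)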