Example #1
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_path_valid,
                 data_path_test,
                 data_base_dir,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 clip_gradients,
                 max_gradient_norm,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 reg_val=0):

        gpu_device_id = '/gpu:' + str(gpu_id)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        #print (data_base_dir)
        # load data
        if phase == 'train':
            self.s_gen = DataGen(data_base_dir,
                                 data_path,
                                 valid_target_len=valid_target_length,
                                 evaluate=False)
            self.s_gen_valid = DataGen(data_base_dir,
                                       data_path_valid,
                                       evaluate=True)
            self.s_gen_test = DataGen(data_base_dir,
                                      data_path_test,
                                      evaluate=True)
        else:
            batch_size = 1
            self.s_gen = DataGen(data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('phase: %s' % phase)
        logging.info('model_dir: %s' % (model_dir))
        logging.info('load_model: %s' % (load_model))
        logging.info('output_dir: %s' % (output_dir))
        logging.info('steps_per_checkpoint: %d' % (steps_per_checkpoint))
        logging.info('batch_size: %d' % (batch_size))
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('learning_rate: %f' % initial_learning_rate)
        logging.info('reg_val: %d' % (reg_val))
        logging.info('max_gradient_norm: %f' % max_gradient_norm)
        logging.info('clip_gradients: %s' % clip_gradients)
        logging.info('valid_target_length %f' % valid_target_length)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('target_embedding_size: %f' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)
        logging.info('visualize: %s' % visualize)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('use GRU in the decoder.')

        # variables
        self.img_data = tf.placeholder(tf.float32,
                                       shape=(None, 1, 32, None),
                                       name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32,
                                            shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(int(buckets[-1][0] + 1)):
            self.encoder_masks.append(
                tf.placeholder(tf.float32,
                               shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        self.reg_val = reg_val
        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
        self.learning_rate = initial_learning_rate
        self.clip_gradients = clip_gradients

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase

        with tf.device(gpu_device_id):
            cnn_model = CNN(self.img_data, True)  #(not self.forward_only))
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(
                axis=1, values=[self.conv_output, self.zero_paddings])

            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru)

        if not self.forward_only:

            self.updates = []
            self.summaries_by_bucket = []
            with tf.device(gpu_device_id):
                params = tf.trainable_variables()
                # Gradients and SGD update operation for training the model.
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate)
                for b in xrange(len(buckets)):
                    if self.reg_val > 0:
                        reg_losses = tf.get_collection(
                            tf.GraphKeys.REGULARIZATION_LOSSES)
                        logging.info('Adding %s regularization losses',
                                     len(reg_losses))
                        logging.debug('REGULARIZATION_LOSSES: %s', reg_losses)
                        loss_op = self.reg_val * tf.reduce_sum(
                            reg_losses) + self.attention_decoder_model.losses[b]
                    else:
                        loss_op = self.attention_decoder_model.losses[b]

                    gradients, params = zip(
                        *opt.compute_gradients(loss_op, params))
                    if self.clip_gradients:
                        gradients, _ = tf.clip_by_global_norm(
                            gradients, max_gradient_norm)
                    # Add summaries for loss, variables, gradients, gradient norms and total gradient norm.
                    summaries = []
                    '''
                    for gradient, variable in gradients:
                        if isinstance(gradient, tf.IndexedSlices):
                            grad_values = gradient.values
                        else:
                            grad_values = gradient
                        summaries.append(tf.summary.histogram(variable.name, variable))
                        summaries.append(tf.summary.histogram(variable.name + "/gradients", grad_values))
                        summaries.append(tf.summary.scalar(variable.name + "/gradient_norm",
                                             tf.global_norm([grad_values])))
                    '''
                    summaries.append(tf.summary.scalar("loss", loss_op))
                    summaries.append(
                        tf.summary.scalar("total_gradient_norm",
                                          tf.global_norm(gradients)))
                    all_summaries = tf.summary.merge(summaries)
                    self.summaries_by_bucket.append(all_summaries)
                    # update op - apply gradients
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    with tf.control_dependencies(update_ops):
                        self.updates.append(
                            opt.apply_gradients(zip(gradients, params),
                                                global_step=self.global_step))

        self.saver_all = tf.train.Saver(tf.global_variables())

        ckpt = tf.train.get_checkpoint_state(model_dir)
        print(ckpt, load_model)
        if ckpt and load_model:
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.global_variables_initializer())
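
Example #1 only shows the constructor; the class name and the launch() entry point are not visible here, but the later examples name the class Model and drive training through launch(). The sketch below is a minimal, hypothetical way to build and run it under TensorFlow 1.x; every path and hyperparameter value is a placeholder, not taken from the original repository.

# Minimal usage sketch (hypothetical values throughout): builds the graph
# above inside a session and starts training. Assumes the constructor belongs
# to a class named Model, as in the later examples.
import tensorflow as tf

config = tf.ConfigProto(allow_soft_placement=True)  # let /gpu: pins fall back to CPU
with tf.Session(config=config) as sess:
    model = Model(phase='train',
                  visualize=False,
                  data_path='data/train.txt',          # placeholder paths
                  data_path_valid='data/valid.txt',
                  data_path_test='data/test.txt',
                  data_base_dir='data/images',
                  output_dir='results',
                  batch_size=64,
                  initial_learning_rate=1.0,
                  num_epoch=100,
                  steps_per_checkpoint=500,
                  target_vocab_size=39,
                  model_dir='checkpoints',
                  target_embedding_size=10,
                  attn_num_hidden=256,
                  attn_num_layers=2,
                  clip_gradients=True,
                  max_gradient_norm=5.0,
                  session=sess,
                  load_model=False,
                  gpu_id=0,
                  use_gru=False)
    model.launch()  # assumes a launch() method as in Examples #2 and #3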
Example #2
class Model(object):
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_base_dir,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 clip_gradients,
                 max_gradient_norm,
                 opt_attn,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 reg_val,
                 augmentation,
                 evaluate=False,
                 valid_target_length=float('inf')):

        # Support two GPUs
        gpu_device_id_1 = '/gpu:' + str(gpu_id)
        gpu_device_id_2 = '/gpu:' + str(gpu_id)
        if gpu_id == 2:
            gpu_device_id_1 = '/gpu:' + str(gpu_id - 1)
            gpu_device_id_2 = '/gpu:' + str(gpu_id - 2)

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        # load data
        if phase == 'train':
            self.s_gen = DataGen(data_base_dir,
                                 data_path,
                                 valid_target_len=valid_target_length,
                                 evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('phase: %s' % phase)
        logging.info('model_dir: %s' % (model_dir))
        logging.info('load_model: %s' % (load_model))
        logging.info('output_dir: %s' % (output_dir))
        logging.info('steps_per_checkpoint: %d' % (steps_per_checkpoint))
        logging.info('batch_size: %d' % (batch_size))
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('learning_rate: %f' % initial_learning_rate)
        logging.info('reg_val: %d' % reg_val)
        logging.info('max_gradient_norm: %f' % max_gradient_norm)
        logging.info('opt_attn: %s' % opt_attn)
        logging.info('clip_gradients: %s' % clip_gradients)
        logging.info('valid_target_length %f' % valid_target_length)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('target_embedding_size: %f' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)
        logging.info('visualize: %s' % visualize)
        logging.info('P(data augmentation): %s' % augmentation)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('use GRU in the decoder.')

        # variables
        self.img_data = tf.placeholder(tf.float32,
                                       shape=(None, 1, 64, None),
                                       name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32,
                                            shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(int(buckets[-1][0] + 1)):
            self.encoder_masks.append(
                tf.placeholder(tf.float32,
                               shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        self.reg_val = reg_val
        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
        self.learning_rate = initial_learning_rate
        self.clip_gradients = clip_gradients
        self.augmentation = augmentation

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase

        with tf.device(gpu_device_id_1):
            cnn_model = CNN(self.img_data, True)  #(not self.forward_only))
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(
                axis=1, values=[self.conv_output, self.zero_paddings])

            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id_2):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru,
                opt_attn=opt_attn)

        if not self.forward_only:

            self.updates = []
            self.summaries_by_bucket = []
            with tf.device(gpu_device_id_2):
                params = tf.trainable_variables()
                # Gradients and SGD update operation for training the model.
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate)
                for b in xrange(len(buckets)):
                    if self.reg_val > 0:
                        reg_losses = tf.get_collection(
                            tf.GraphKeys.REGULARIZATION_LOSSES)
                        logging.info('Adding %s regularization losses',
                                     len(reg_losses))
                        logging.debug('REGULARIZATION_LOSSES: %s', reg_losses)
                        loss_op = self.reg_val * tf.reduce_sum(
                            reg_losses) + self.attention_decoder_model.losses[b]
                    else:
                        loss_op = self.attention_decoder_model.losses[b]

                    gradients, params = zip(
                        *opt.compute_gradients(loss_op, params))
                    if self.clip_gradients:
                        gradients, _ = tf.clip_by_global_norm(
                            gradients, max_gradient_norm)
                    # Add summaries for loss, variables, gradients, gradient norms and total gradient norm.
                    summaries = []
                    '''
                    for gradient, variable in gradients:
                        if isinstance(gradient, tf.IndexedSlices):
                            grad_values = gradient.values
                        else:
                            grad_values = gradient
                        summaries.append(tf.summary.histogram(variable.name, variable))
                        summaries.append(tf.summary.histogram(variable.name + "/gradients", grad_values))
                        summaries.append(tf.summary.scalar(variable.name + "/gradient_norm",
                                             tf.global_norm([grad_values])))
                    '''
                    summaries.append(tf.summary.scalar("loss", loss_op))
                    summaries.append(
                        tf.summary.scalar("total_gradient_norm",
                                          tf.global_norm(gradients)))
                    all_summaries = tf.summary.merge(summaries)
                    self.summaries_by_bucket.append(all_summaries)
                    # update op - apply gradients
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    with tf.control_dependencies(update_ops):
                        self.updates.append(
                            opt.apply_gradients(zip(gradients, params),
                                                global_step=self.global_step))

        self.saver_all = tf.train.Saver(tf.global_variables())

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and load_model:
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.global_variables_initializer())
        #self.sess.run(init_new_vars_op)

    # train or test as specified by phase
    def launch(self):
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)
        if self.phase == 'test':
            if not distance_loaded:
                logging.info(
                    'Warning: distance module not installed. Do whole sequence comparison instead.'
                )
            else:
                logging.info('Compare word based on edit distance.')
            num_correct = 0
            num_total = 0
            numerator = 0
            denominator = 0
            for batch in self.s_gen.gen(self.batch_size):
                # Get a batch and make a step.
                start_time = time.time()
                bucket_id = batch['bucket_id']
                img_data = batch['data']
                zero_paddings = batch['zero_paddings']
                decoder_inputs = batch['decoder_inputs']
                target_weights = batch['target_weights']
                encoder_masks = batch['encoder_mask']
                file_list = batch['filenames']
                real_len = batch['real_len']

                grounds = [
                    a for a in np.array([
                        decoder_input.tolist()
                        for decoder_input in decoder_inputs
                    ]).transpose()
                ]
                _, step_loss, step_logits, step_attns = self.step(
                    encoder_masks, img_data, zero_paddings, decoder_inputs,
                    target_weights, bucket_id, self.forward_only)
                curr_step_time = (time.time() - start_time)
                step_time += curr_step_time / self.steps_per_checkpoint
                logging.info(
                    'step_time: %f, loss: %f, step perplexity: %f' %
                    (curr_step_time, step_loss,
                     math.exp(step_loss) if step_loss < 300 else float('inf')))
                loss += step_loss / self.steps_per_checkpoint
                current_step += 1
                step_outputs = [
                    b for b in np.array([
                        np.argmax(logit, axis=1).tolist()
                        for logit in step_logits
                    ]).transpose()
                ]
                if self.visualize:
                    step_attns = np.array([[a.tolist() for a in step_attn]
                                           for step_attn in step_attns
                                           ]).transpose([1, 0, 2])

                for idx, output, ground in zip(range(len(grounds)),
                                               step_outputs, grounds):
                    flag_ground, flag_out = True, True
                    num_total += 1
                    output_valid = []
                    ground_valid = []
                    for j in range(1, len(ground)):
                        s1 = output[j - 1]
                        s2 = ground[j]
                        if s2 != 2 and flag_ground:
                            ground_valid.append(s2)
                        else:
                            flag_ground = False
                        if s1 != 2 and flag_out:
                            output_valid.append(s1)
                        else:
                            flag_out = False
                    if distance_loaded:
                        lev = distance.levenshtein(output_valid, ground_valid)
                        num_incorrect = float(lev) / len(ground_valid)
                        num_incorrect = min(1.0, num_incorrect)
                        nchar = len(ground_valid)
                        if self.visualize:
                            self.visualize_attention(
                                file_list[idx], step_attns[idx], output_valid,
                                ground_valid, num_incorrect > 0, real_len)
                    else:
                        # Without the distance module, compare whole sequences;
                        # count one unit each so lev/nchar stay defined below.
                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1
                        lev = num_incorrect
                        nchar = 1
                        if self.visualize:
                            self.visualize_attention(
                                file_list[idx], step_attns[idx], output_valid,
                                ground_valid, num_incorrect > 0, real_len)
                    num_correct += 1. - num_incorrect
                    numerator += lev
                    denominator += nchar
                logging.info('%f out of %d correct' % (num_correct, num_total))
                logging.info('Lev: %f, length of test set: %f' %
                             (numerator, denominator))
                logging.info(
                    'Global loss: %f, Global perplexity: %f' %
                    (loss, math.exp(loss) if loss < 300 else float('inf')))
        elif self.phase == 'train':
            total = (self.s_gen.get_size() // self.batch_size)
            with tqdm(desc='Train: ', total=total) as pbar:
                st = lambda aug: iaa.Sometimes(self.augmentation, aug)
                seq = iaa.Sequential([
                    st(
                        iaa.Affine(
                            scale={
                                "x": (0.8, 1.2),
                                "y": (0.8, 1.2)
                            },  # scale images to 80-120% of their size, individually per axis
                            translate_px={
                                "x": (-16, 16),
                                "y": (-16, 16)
                            },  # translate by -16 to +16 pixels (per axis)
                            rotate=(-45, 45),  # rotate by -45 to +45 degrees
                            shear=(-16, 16),  # shear by -16 to +16 degrees
                        ))
                ])
                for epoch in range(self.num_epoch):

                    logging.info('Generating first batch')
                    for i, batch in enumerate(self.s_gen.gen(self.batch_size)):
                        # Get a batch and make a step.
                        num_total = 0
                        num_correct = 0
                        numerator = 0
                        denominator = 0
                        start_time = time.time()
                        batch_len = batch['real_len']
                        bucket_id = batch['bucket_id']
                        img_data = batch['data']
                        img_data = seq.augment_images(
                            img_data.transpose(0, 2, 3, 1))
                        img_data = img_data.transpose(0, 3, 1, 2)
                        zero_paddings = batch['zero_paddings']
                        decoder_inputs = batch['decoder_inputs']
                        target_weights = batch['target_weights']
                        encoder_masks = batch['encoder_mask']
                        #logging.info('current_step: %d'%current_step)
                        #logging.info(np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()[0])
                        #print (np.array([target_weight.tolist() for target_weight in target_weights]).transpose()[0])
                        summaries, step_loss, step_logits, _ = self.step(
                            encoder_masks, img_data, zero_paddings,
                            decoder_inputs, target_weights, bucket_id,
                            self.forward_only)

                        grounds = [
                            a for a in np.array([
                                decoder_input.tolist()
                                for decoder_input in decoder_inputs
                            ]).transpose()
                        ]
                        step_outputs = [
                            b for b in np.array([
                                np.argmax(logit, axis=1).tolist()
                                for logit in step_logits
                            ]).transpose()
                        ]

                        for idx, output, ground in zip(range(len(grounds)),
                                                       step_outputs, grounds):
                            flag_ground, flag_out = True, True
                            num_total += 1
                            output_valid = []
                            ground_valid = []
                            for j in range(1, len(ground)):
                                s1 = output[j - 1]
                                s2 = ground[j]
                                if s2 != 2 and flag_ground:
                                    ground_valid.append(s2)
                                else:
                                    flag_ground = False
                                if s1 != 2 and flag_out:
                                    output_valid.append(s1)
                                else:
                                    flag_out = False
                            if distance_loaded:
                                lev = distance.levenshtein(
                                    output_valid, ground_valid)
                                num_incorrect = float(lev) / len(ground_valid)
                                num_incorrect = min(1.0, num_incorrect)
                                nchar = len(ground_valid)
                            else:
                                # Without the distance module, compare whole
                                # sequences; one unit each so lev/nchar are defined.
                                if output_valid == ground_valid:
                                    num_incorrect = 0
                                else:
                                    num_incorrect = 1
                                lev = num_incorrect
                                nchar = 1
                            num_correct += 1. - num_incorrect
                            numerator += lev
                            denominator += nchar

                        writer.add_summary(summaries, current_step)
                        curr_step_time = (time.time() - start_time)
                        step_time += curr_step_time / self.steps_per_checkpoint
                        precision = num_correct / num_total
                        logging.info(
                            'step %f - time: %f, loss: %f, perplexity: %f, precision: %f, CER: %f, batch_len: %f'
                            % (current_step, curr_step_time, step_loss,
                               math.exp(step_loss)
                               if step_loss < 300 else float('inf'), precision,
                               numerator / denominator, batch_len))
                        loss += step_loss / self.steps_per_checkpoint
                        pbar.set_description(
                            'Train, loss={:.8f}'.format(step_loss))
                        pbar.update()
                        current_step += 1
                        # If there is an EOS symbol in outputs, cut them at that point.
                        #if data_utils.EOS_ID in step_outputs:
                        #    step_outputs = step_outputs[:step_outputs.index(data_utils.EOS_ID)]
                        #if data_utils.PAD_ID in decoder_inputs:
                        #decoder_inputs = decoder_inputs[:decoder_inputs.index(data_utils.PAD_ID)]
                        #    print (step_outputs[0])

                        # Once in a while, we save checkpoint, print statistics, and run evals.
                        if current_step % self.steps_per_checkpoint == 0:
                            # Print statistics for the previous epoch.
                            logging.info(
                                "global step %d step-time %.2f loss %f  perplexity "
                                "%.2f" % (self.global_step.eval(), step_time,
                                          loss, math.exp(loss)
                                          if loss < 300 else float('inf')))
                            previous_losses.append(loss)
                            # Save checkpoint and zero timer and loss.
                            if not self.forward_only:
                                checkpoint_path = os.path.join(
                                    self.model_dir, "translate.ckpt")
                                logging.info("Saving model, current_step: %d" %
                                             current_step)
                                self.saver_all.save(
                                    self.sess,
                                    checkpoint_path,
                                    global_step=self.global_step)
                            step_time, loss = 0.0, 0.0
                            #sys.stdout.flush()

    # step, read one batch, generate gradients
    def step(self, encoder_masks, img_data, zero_paddings, decoder_inputs,
             target_weights, bucket_id, forward_only):
        # Check if the sizes match.
        encoder_size, decoder_size = self.buckets[bucket_id]
        if len(decoder_inputs) != decoder_size:
            raise ValueError(
                "Decoder length must be equal to the one in bucket,"
                " %d != %d." % (len(decoder_inputs), decoder_size))
        if len(target_weights) != decoder_size:
            raise ValueError(
                "Weights length must be equal to the one in bucket,"
                " %d != %d." % (len(target_weights), decoder_size))

        # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
        input_feed = {}
        input_feed[self.img_data.name] = img_data
        input_feed[self.zero_paddings.name] = zero_paddings
        for l in xrange(decoder_size):
            input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
            input_feed[self.target_weights[l].name] = target_weights[l]
        for l in xrange(int(encoder_size)):
            try:
                input_feed[self.encoder_masks[l].name] = encoder_masks[l]
            except Exception as e:
                pass

        # Since our targets are decoder inputs shifted by one, we need one more.
        last_target = self.decoder_inputs[decoder_size].name
        input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)

        # Output feed: depends on whether we do a backward step or not.
        if not forward_only:
            output_feed = [
                self.updates[bucket_id],  # Update Op that does SGD.
                #self.gradient_norms[bucket_id],  # Gradient norm.
                self.attention_decoder_model.losses[bucket_id],
                self.summaries_by_bucket[bucket_id]
            ]
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
        else:
            output_feed = [self.attention_decoder_model.losses[bucket_id]
                           ]  # Loss for this batch.
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
            if self.visualize:
                output_feed += self.attention_decoder_model.attention_weights_histories[
                    bucket_id]

        outputs = self.sess.run(output_feed, input_feed)
        if not forward_only:
            return outputs[2], outputs[1], outputs[3:(
                3 + self.buckets[bucket_id][1]
            )], None  # Gradient norm summary, loss, no outputs, no attentions.
        else:
            return None, outputs[0], outputs[1:(
                1 + self.buckets[bucket_id][1])], outputs[(
                    1 + self.buckets[bucket_id][1]
                ):]  # No gradient norm, loss, outputs, attentions.

    def visualize_attention(self, filename, attentions, output_valid,
                            ground_valid, flag_incorrect, real_len):
        if flag_incorrect:
            output_dir = os.path.join(self.output_dir, 'incorrect')
        else:
            output_dir = os.path.join(self.output_dir, 'correct')
        output_dir = os.path.join(output_dir, filename.replace('/', '_'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, 'word.txt'), 'w') as fword:
            gt = ''.join([chr(c - 3 + 33) for c in ground_valid])
            ot = ''.join([chr(c - 3 + 33) for c in output_valid])
            fword.write(gt + '\n')
            fword.write(ot)
            with open(filename, 'rb') as img_file:
                img = Image.open(img_file)
                w, h = img.size
                h = 32
                img = img.resize((real_len, h), Image.ANTIALIAS)
                if img.mode == '1':
                    img_data = np.asarray(img, dtype=bool) * np.iinfo(
                        np.uint8).max
                    img_data = img_data.astype(np.uint8)
                else:
                    img_data = np.asarray(img, dtype=np.uint8)
                data = []
                for idx in range(len(output_valid)):
                    output_filename = os.path.join(output_dir,
                                                   'image_%d.jpg' % (idx))
                    # get the first fourth
                    attention = attentions[idx][:(int(real_len / 4) - 1)]
                    # repeat each values four times
                    attention_orig = np.zeros(real_len)
                    for i in range(real_len):
                        if 0 < i / 4 - 1 and i / 4 - 1 < len(attention):
                            attention_orig[i] = attention[int(i / 4) - 1]
                    # do some scaling
                    attention_orig = np.convolve(
                        attention_orig,
                        [0.199547, 0.200226, 0.200454, 0.200226, 0.199547],
                        mode='same')
                    attention_orig = np.maximum(attention_orig, 0.3)
                    # copy values to other heights
                    attention_out = np.zeros((h, real_len))
                    for i in range(real_len):
                        attention_out[:, i] = attention_orig[i]
                    if len(img_data.shape) == 3:
                        attention_out = attention_out[:, :, np.newaxis]
                    data.append(attention_out[0, :])
                    img_out_data = img_data * attention_out
                    img_out = Image.fromarray(img_out_data.astype(np.uint8))
                    img_out.save(output_filename)

                # plot ~ 5% of the time
                if np.random.random() < 0.05:
                    fig = plt.figure(figsize=(2, 6))
                    gs = matplotlib.gridspec.GridSpec(
                        2,
                        1,
                        height_ratios=[len(ot) * 2, 1],
                        wspace=0.0,
                        hspace=0.0,
                        top=0.95,
                        bottom=0.05,
                        left=0.17,
                        right=0.845)
                    ax = plt.subplot(gs[0])
                    ax.imshow(data,
                              aspect='auto',
                              interpolation='nearest',
                              cmap='gray')
                    ax.set_xticklabels([])
                    ax.set_yticks(np.arange(len(ot)))
                    ax.set_yticklabels(ot.replace('|', ' '))
                    ax.tick_params(axis=u'both', which=u'both', length=0)
                    ax = plt.subplot(gs[1])
                    ax.imshow(img,
                              aspect='auto',
                              interpolation='nearest',
                              cmap='gray')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.tick_params(axis=u'both', which=u'both', length=0)
                    fig.savefig(os.path.join(output_dir, 'att_mat.png'))
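
Example #2 adds on-the-fly data augmentation in the training loop: the NCHW image batch is transposed to NHWC, run through an imgaug pipeline, and transposed back before being fed to the graph. Below is a standalone sketch of that pattern with an arbitrary batch shape and probability (assuming the imgaug package is installed); the values are placeholders, not the original configuration.

# Standalone sketch of the augmentation step in launch() above; the batch
# shape and probability below are placeholders.
import numpy as np
import imgaug.augmenters as iaa

p_aug = 0.5  # stand-in for the `augmentation` probability passed to Model
st = lambda aug: iaa.Sometimes(p_aug, aug)
seq = iaa.Sequential([
    st(iaa.Affine(
        scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},      # per-axis rescaling
        translate_px={"x": (-16, 16), "y": (-16, 16)},  # pixel shifts
        rotate=(-45, 45),
        shear=(-16, 16),
    ))
])

# Fake grayscale batch in the model's NCHW layout: (batch, 1, height, width).
img_data = np.random.randint(0, 256, size=(8, 1, 64, 256), dtype=np.uint8)
# imgaug works on NHWC, so transpose around the call, as in the training loop.
img_data = seq.augment_images(img_data.transpose(0, 2, 3, 1))
img_data = img_data.transpose(0, 3, 1, 2)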
Example #3
class Model(object):
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_path_valid,
                 data_path_test,
                 data_base_dir,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 clip_gradients,
                 max_gradient_norm,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 reg_val=0):

        gpu_device_id = '/gpu:' + str(gpu_id)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        #print (data_base_dir)
        # load data
        if phase == 'train':
            self.s_gen = DataGen(data_base_dir,
                                 data_path,
                                 valid_target_len=valid_target_length,
                                 evaluate=False)
            self.s_gen_valid = DataGen(data_base_dir,
                                       data_path_valid,
                                       evaluate=True)
            self.s_gen_test = DataGen(data_base_dir,
                                      data_path_test,
                                      evaluate=True)
        else:
            batch_size = 1
            self.s_gen = DataGen(data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('phase: %s' % phase)
        logging.info('model_dir: %s' % (model_dir))
        logging.info('load_model: %s' % (load_model))
        logging.info('output_dir: %s' % (output_dir))
        logging.info('steps_per_checkpoint: %d' % (steps_per_checkpoint))
        logging.info('batch_size: %d' % (batch_size))
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('learning_rate: %f' % initial_learning_rate)
        logging.info('reg_val: %d' % (reg_val))
        logging.info('max_gradient_norm: %f' % max_gradient_norm)
        logging.info('clip_gradients: %s' % clip_gradients)
        logging.info('valid_target_length %f' % valid_target_length)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('target_embedding_size: %f' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)
        logging.info('visualize: %s' % visualize)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('use GRU in the decoder.')

        # variables
        self.img_data = tf.placeholder(tf.float32,
                                       shape=(None, 1, 32, None),
                                       name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32,
                                            shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(int(buckets[-1][0] + 1)):
            self.encoder_masks.append(
                tf.placeholder(tf.float32,
                               shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        self.reg_val = reg_val
        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
        self.learning_rate = initial_learning_rate
        self.clip_gradients = clip_gradients

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase

        with tf.device(gpu_device_id):
            cnn_model = CNN(self.img_data, True)  #(not self.forward_only))
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(
                axis=1, values=[self.conv_output, self.zero_paddings])

            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru)

        if not self.forward_only:

            self.updates = []
            self.summaries_by_bucket = []
            with tf.device(gpu_device_id):
                params = tf.trainable_variables()
                # Gradients and SGD update operation for training the model.
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate)
                for b in xrange(len(buckets)):
                    if self.reg_val > 0:
                        reg_losses = tf.get_collection(
                            tf.GraphKeys.REGULARIZATION_LOSSES)
                        logging.info('Adding %s regularization losses',
                                     len(reg_losses))
                        logging.debug('REGULARIZATION_LOSSES: %s', reg_losses)
                        loss_op = self.reg_val * tf.reduce_sum(
                            reg_losses) + self.attention_decoder_model.losses[b]
                    else:
                        loss_op = self.attention_decoder_model.losses[b]

                    gradients, params = zip(
                        *opt.compute_gradients(loss_op, params))
                    if self.clip_gradients:
                        gradients, _ = tf.clip_by_global_norm(
                            gradients, max_gradient_norm)
                    # Add summaries for loss, variables, gradients, gradient norms and total gradient norm.
                    summaries = []
                    '''
                    for gradient, variable in gradients:
                        if isinstance(gradient, tf.IndexedSlices):
                            grad_values = gradient.values
                        else:
                            grad_values = gradient
                        summaries.append(tf.summary.histogram(variable.name, variable))
                        summaries.append(tf.summary.histogram(variable.name + "/gradients", grad_values))
                        summaries.append(tf.summary.scalar(variable.name + "/gradient_norm",
                                             tf.global_norm([grad_values])))
                    '''
                    summaries.append(tf.summary.scalar("loss", loss_op))
                    summaries.append(
                        tf.summary.scalar("total_gradient_norm",
                                          tf.global_norm(gradients)))
                    all_summaries = tf.summary.merge(summaries)
                    self.summaries_by_bucket.append(all_summaries)
                    # update op - apply gradients
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    with tf.control_dependencies(update_ops):
                        self.updates.append(
                            opt.apply_gradients(zip(gradients, params),
                                                global_step=self.global_step))

        self.saver_all = tf.train.Saver(tf.global_variables())

        ckpt = tf.train.get_checkpoint_state(model_dir)
        print(ckpt, load_model)
        if ckpt and load_model:
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.global_variables_initializer())
        #self.sess.run(init_new_vars_op)

    # train or test as specified by phase
    def launch(self):
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)
        if self.phase == 'test':
            if not distance_loaded:
                logging.info(
                    'Warning: distance module not installed. Do whole sequence comparison instead.'
                )
            else:
                logging.info('Compare word based on edit distance.')
            num_correct = 0
            num_total = 0
            for batch in self.s_gen.gen(self.batch_size):
                # Get a batch and make a step.
                start_time = time.time()
                bucket_id = batch['bucket_id']
                img_data = batch['data']
                zero_paddings = batch['zero_paddings']
                decoder_inputs = batch['decoder_inputs']
                target_weights = batch['target_weights']
                encoder_masks = batch['encoder_mask']
                file_list = batch['filenames']
                real_len = batch['real_len']

                grounds = [
                    a for a in np.array([
                        decoder_input.tolist()
                        for decoder_input in decoder_inputs
                    ]).transpose()
                ]
                _, step_loss, step_logits, step_attns = self.step(
                    encoder_masks, img_data, zero_paddings, decoder_inputs,
                    target_weights, bucket_id, self.forward_only)
                curr_step_time = (time.time() - start_time)
                step_time += curr_step_time / self.steps_per_checkpoint
                logging.info(
                    'step_time: %f, loss: %f, step perplexity: %f' %
                    (curr_step_time, step_loss,
                     math.exp(step_loss) if step_loss < 300 else float('inf')))
                loss += step_loss / self.steps_per_checkpoint
                current_step += 1
                step_outputs = [
                    b for b in np.array([
                        np.argmax(logit, axis=1).tolist()
                        for logit in step_logits
                    ]).transpose()
                ]
                if self.visualize:
                    step_attns = np.array([[a.tolist() for a in step_attn]
                                           for step_attn in step_attns
                                           ]).transpose([1, 0, 2])
                    #print (step_attns)

                for idx, output, ground in zip(range(len(grounds)),
                                               step_outputs, grounds):
                    flag_ground, flag_out = True, True
                    num_total += 1
                    output_valid = []
                    ground_valid = []
                    for j in range(1, len(ground)):
                        s1 = output[j - 1]
                        s2 = ground[j]
                        if s2 != 2 and flag_ground:
                            ground_valid.append(s2)
                        else:
                            flag_ground = False
                        if s1 != 2 and flag_out:
                            output_valid.append(s1)
                        else:
                            flag_out = False
                    if distance_loaded:
                        num_incorrect = distance.levenshtein(
                            output_valid, ground_valid)
                        if self.visualize:
                            self.visualize_attention(
                                file_list[idx], step_attns[idx], output_valid,
                                ground_valid, num_incorrect > 0, real_len)
                        num_incorrect = float(num_incorrect) / len(
                            ground_valid)
                        num_incorrect = min(1.0, num_incorrect)
                    else:
                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1
                        if self.visualize:
                            self.visualize_attention(
                                file_list[idx], step_attns[idx], output_valid,
                                ground_valid, num_incorrect > 0, real_len)
                    num_correct += 1. - num_incorrect
                logging.info('%f out of %d correct' % (num_correct, num_total))
        elif self.phase == 'train':
            total = (self.s_gen.get_size() // self.batch_size)
            WER = 1.0
            swriter = open("result_log.txt", "a")
            with tqdm(desc='Train: ', total=total) as pbar:
                for epoch in range(self.num_epoch):
                    logging.info('Generating first batch')
                    n_correct = 0
                    n_total = 0
                    for i, batch in enumerate(self.s_gen.gen(self.batch_size)):
                        # Get a batch and make a step.
                        num_total = 0
                        num_correct = 0
                        start_time = time.time()
                        batch_len = batch['real_len']
                        bucket_id = batch['bucket_id']
                        img_data = batch['data']
                        zero_paddings = batch['zero_paddings']
                        decoder_inputs = batch['decoder_inputs']
                        target_weights = batch['target_weights']
                        encoder_masks = batch['encoder_mask']
                        #logging.info('current_step: %d'%current_step)
                        #logging.info(np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()[0])
                        #print (np.array([target_weight.tolist() for target_weight in target_weights]).transpose()[0])
                        summaries, step_loss, step_logits, _ = self.step(
                            encoder_masks, img_data, zero_paddings,
                            decoder_inputs, target_weights, bucket_id,
                            self.forward_only)

                        grounds = [
                            a for a in np.array([
                                decoder_input.tolist()
                                for decoder_input in decoder_inputs
                            ]).transpose()
                        ]
                        step_outputs = [
                            b for b in np.array([
                                np.argmax(logit, axis=1).tolist()
                                for logit in step_logits
                            ]).transpose()
                        ]

                        for idx, output, ground in zip(range(len(grounds)),
                                                       step_outputs, grounds):
                            flag_ground, flag_out = True, True
                            num_total += 1
                            output_valid = []
                            ground_valid = []
                            for j in range(1, len(ground)):
                                s1 = output[j - 1]
                                s2 = ground[j]
                                if s2 != 2 and flag_ground:
                                    ground_valid.append(s2)
                                else:
                                    flag_ground = False
                                if s1 != 2 and flag_out:
                                    output_valid.append(s1)
                                else:
                                    flag_out = False
                            if distance_loaded:
                                num_incorrect = distance.levenshtein(
                                    output_valid, ground_valid)
                                num_incorrect = float(num_incorrect) / len(
                                    ground_valid)
                                num_incorrect = min(1.0, num_incorrect)
                            else:
                                if output_valid == ground_valid:
                                    num_incorrect = 0
                                else:
                                    num_incorrect = 1
                            num_correct += 1. - num_incorrect

                        writer.add_summary(summaries, current_step)
                        curr_step_time = (time.time() - start_time)
                        step_time += curr_step_time / total
                        precision = num_correct / num_total
                        n_total += num_total
                        n_correct += num_correct
                        #logging.info('step %f - time: %f, loss: %f, perplexity: %f, precision: %f, batch_len: %f'%(current_step, curr_step_time, step_loss, math.exp(step_loss) if step_loss < 300 else float('inf'), precision, batch_len))
                        loss += step_loss / self.steps_per_checkpoint
                        pbar.set_description(
                            'Train, loss={:.8f}'.format(step_loss))
                        pbar.update()
                        current_step += 1
                        # If there is an EOS symbol in outputs, cut them at that point.
                        #if data_utils.EOS_ID in step_outputs:
                        #    step_outputs = step_outputs[:step_outputs.index(data_utils.EOS_ID)]
                        #if data_utils.PAD_ID in decoder_inputs:
                        #decoder_inputs = decoder_inputs[:decoder_inputs.index(data_utils.PAD_ID)]
                        #    print (step_outputs[0])

                        # Once in a while, we save checkpoint, print statistics, and run evals.
                        '''if current_step % self.steps_per_checkpoint == 0:
                            # Print statistics for the previous epoch.
                            perplexity = math.exp(loss) if loss < 300 else float('inf')
                            logging.info("global step %d step-time %.2f loss %f  perplexity "
                                    "%.2f" % (self.global_step.eval(), step_time, loss, perplexity))
                            previous_losses.append(loss)
                            # Save checkpoint and zero timer and loss.
                            if not self.forward_only:
                                checkpoint_path = os.path.join(self.model_dir, "translate.ckpt")
                                logging.info("Saving model, current_step: %d"%current_step)
                                self.saver_all.save(self.sess, checkpoint_path, global_step=self.global_step)
                            step_time, loss = 0.0, 0.0'''
                        #sys.stdout.flush()
                    print("Epoch " + str(epoch) + "WER = " +
                          str(1 - n_correct * 1.0 / n_total))
                    swriter.write("Epoch " + str(epoch) + " WER = " +
                                  str(1 - n_correct * 1.0 / n_total) + "\n")
                    print('Run validation...')
                    valid_error_rate = self.eval(self.s_gen_valid)
                    print('Finished validation...')
                    print("WER on validation: %f" % valid_error_rate)
                    swriter.write("WER on validation: " +
                                  str(valid_error_rate) + "\n")
                    if WER > valid_error_rate:
                        WER = valid_error_rate
                        checkpoint_path = os.path.join(self.model_dir,
                                                       "translate.ckpt")
                        logging.info("Saving model, current_step: %d" %
                                     current_step)
                        logging.info("best WER on validation: %f" % WER)
                        self.saver_all.save(self.sess,
                                            checkpoint_path,
                                            global_step=self.global_step)
                        print('Run testing...')
                        test_error_rate = self.eval(self.s_gen_test)
                        print('Finished testing...')
                        logging.info("best WER on test: %f" % test_error_rate)
                        swriter.write("best WER on test: %f" +
                                      str(test_error_rate) + "\n")
            swriter.close()

    # step, read one batch, generate gradients
    def step(self, encoder_masks, img_data, zero_paddings, decoder_inputs,
             target_weights, bucket_id, forward_only):
        # Check if the sizes match.
        encoder_size, decoder_size = self.buckets[bucket_id]
        if len(decoder_inputs) != decoder_size:
            raise ValueError(
                "Decoder length must be equal to the one in bucket,"
                " %d != %d." % (len(decoder_inputs), decoder_size))
        if len(target_weights) != decoder_size:
            raise ValueError(
                "Weights length must be equal to the one in bucket,"
                " %d != %d." % (len(target_weights), decoder_size))

        # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
        input_feed = {}
        input_feed[self.img_data.name] = img_data
        input_feed[self.zero_paddings.name] = zero_paddings
        for l in xrange(decoder_size):
            input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
            input_feed[self.target_weights[l].name] = target_weights[l]
        for l in xrange(int(encoder_size)):
            try:
                input_feed[self.encoder_masks[l].name] = encoder_masks[l]
            except Exception as e:
                pass
                #ipdb.set_trace()

        # Since our targets are decoder inputs shifted by one, we need one more.
        last_target = self.decoder_inputs[decoder_size].name
        input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)

        # Output feed: depends on whether we do a backward step or not.
        if not forward_only:
            output_feed = [
                self.updates[bucket_id],  # Update Op that does SGD.
                #self.gradient_norms[bucket_id],  # Gradient norm.
                self.attention_decoder_model.losses[bucket_id],
                self.summaries_by_bucket[bucket_id]
            ]
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
        else:
            output_feed = [self.attention_decoder_model.losses[bucket_id]
                           ]  # Loss for this batch.
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
            if self.visualize:
                output_feed += self.attention_decoder_model.attention_weights_histories[
                    bucket_id]

        outputs = self.sess.run(output_feed, input_feed)
        if not forward_only:
            return outputs[2], outputs[1], outputs[3:(
                3 + self.buckets[bucket_id][1]
            )], None  # Summaries, loss, output logits, no attentions.
        else:
            return None, outputs[0], outputs[1:(
                1 + self.buckets[bucket_id][1])], outputs[(
                    1 + self.buckets[bucket_id][1]
                ):]  # No gradient norm, loss, outputs, attentions.

    def visualize_attention(self, filename, attentions, output_valid,
                            ground_valid, flag_incorrect, real_len):
        if flag_incorrect:
            output_dir = os.path.join(self.output_dir, 'incorrect')
        else:
            output_dir = os.path.join(self.output_dir, 'correct')
        output_dir = os.path.join(output_dir, filename.replace('/', '_'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, 'word.txt'), 'w') as fword:
            fword.write(
                ' '.join([self.s_gen.get_char(c)
                          for c in ground_valid]) + '\n')
            fword.write(' '.join(
                [self.s_gen.get_char(c) for c in output_valid]))
            with open(filename, 'rb') as img_file:
                img = Image.open(img_file)
                w, h = img.size
                h = 32
                img = img.resize((real_len, h), Image.ANTIALIAS)
                img_data = np.asarray(img, dtype=np.uint8)
                for idx in range(len(output_valid)):
                    output_filename = os.path.join(output_dir,
                                                   'image_%d.png' % (idx))
                    attention = attentions[idx][:(int(real_len / 4) - 1)]

                    # Upsample the attention weights to the full image width,
                    # smooth them, and overlay them on the input image before
                    # saving to output_filename.
                    attention_orig = np.zeros(real_len)
                    for i in range(real_len):
                        if 0 < i / 4 - 1 and i / 4 - 1 < len(attention):
                            attention_orig[i] = attention[int(i / 4) - 1]
                    attention_orig = np.convolve(
                        attention_orig,
                        [0.199547, 0.200226, 0.200454, 0.200226, 0.199547],
                        mode='same')
                    attention_orig = np.maximum(attention_orig, 0.3)
                    attention_out = np.zeros((h, real_len))
                    for i in range(real_len):
                        attention_out[:, i] = attention_orig[i]
                    if len(img_data.shape) == 3:
                        attention_out = attention_out[:, :, np.newaxis]
                    img_out_data = img_data * attention_out
                    img_out = Image.fromarray(img_out_data.astype(np.uint8))
                    img_out.save(output_filename)
                    #print (output_filename)
                #assert False

    def eval(self, data_gen):
        num_correct = 0
        num_total = 0
        for batch in data_gen.gen(self.batch_size):
            start_time = time.time()
            bucket_id = batch['bucket_id']
            img_data = batch['data']
            zero_paddings = batch['zero_paddings']
            decoder_inputs = batch['decoder_inputs']
            target_weights = batch['target_weights']
            encoder_masks = batch['encoder_mask']
            file_list = batch['filenames']
            real_len = batch['real_len']

            grounds = [
                a for a in np.array([
                    decoder_input.tolist() for decoder_input in decoder_inputs
                ]).transpose()
            ]
            _, step_loss, step_logits, step_attns = self.step(
                encoder_masks,
                img_data,
                zero_paddings,
                decoder_inputs,
                target_weights,
                bucket_id,
                forward_only=True)
            step_outputs = [
                b for b in np.array([
                    np.argmax(logit, axis=1).tolist() for logit in step_logits
                ]).transpose()
            ]

            for idx, output, ground in zip(range(len(grounds)), step_outputs,
                                           grounds):
                flag_ground, flag_out = True, True
                num_total += 1
                output_valid = []
                ground_valid = []
                for j in range(1, len(ground)):
                    s1 = output[j - 1]
                    s2 = ground[j]
                    if s2 != 2 and flag_ground:
                        ground_valid.append(s2)
                    else:
                        flag_ground = False
                    if s1 != 2 and flag_out:
                        output_valid.append(s1)
                    else:
                        flag_out = False
                if distance_loaded:
                    num_incorrect = distance.levenshtein(
                        output_valid, ground_valid)
                    if self.visualize:
                        self.visualize_attention(file_list[idx],
                                                 step_attns[idx], output_valid,
                                                 ground_valid,
                                                 num_incorrect > 0, real_len)
                    num_incorrect = float(num_incorrect) / len(ground_valid)
                    num_incorrect = min(1.0, num_incorrect)
                else:
                    if output_valid == ground_valid:
                        num_incorrect = 0
                    else:
                        num_incorrect = 1
                num_correct += 1. - num_incorrect
        return (num_total - num_correct) * 1.0 / num_total
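
A minimal, self-contained sketch of the per-sample scoring these examples share: prediction and ground truth are each truncated at the end-of-sequence marker, their edit distance is normalized by the ground-truth length and clipped to 1.0, and WER is one minus the mean per-sample accuracy. The levenshtein helper below is only a stand-in for the optional distance package used above, and the token ids in the toy usage are made up.

def levenshtein(a, b):
    # Plain dynamic-programming edit distance (stand-in for distance.levenshtein).
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]

def sample_error(output_valid, ground_valid):
    # Normalized, clipped edit distance: 0.0 is an exact match, 1.0 is a complete miss.
    return min(1.0, float(levenshtein(output_valid, ground_valid)) / len(ground_valid))

# Toy usage with made-up token id sequences:
pairs = [([3, 4, 5], [3, 4, 5]), ([3, 4], [3, 4, 5]), ([9], [3, 4, 5])]
num_correct = sum(1.0 - sample_error(o, g) for o, g in pairs)
wer = 1.0 - num_correct / len(pairs)  # same quantity as (num_total - num_correct) / num_total
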
Example #4
class Model(object):
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_base_dir,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 session,
                 load_model,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 use_lstm=True):

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        # load data
        if phase == 'train':
            self.s_gen = DataGen(data_base_dir,
                                 data_path,
                                 valid_target_len=valid_target_length,
                                 evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('data_path: %s' % (data_path))
        logging.info('phase: %s' % phase)
        logging.info('batch_size: %d' % batch_size)
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('steps_per_checkpoint %d' % steps_per_checkpoint)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('model_dir: %s' % model_dir)
        logging.info('target_embedding_size: %d' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)

        # variables
        self.img_data = tf.placeholder(tf.float32,
                                       shape=(None, 1, 32, None),
                                       name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32,
                                            shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(buckets[-1][0] + 1):
            self.encoder_masks.append(
                tf.placeholder(tf.float32,
                               shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase
        #with tf.device('/gpu:1'):
        with tf.device('/gpu:0'):
            cnn_model = CNN(self.img_data)
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(
                concat_dim=1, values=[self.conv_output, self.zero_paddings])

        #with tf.device('/cpu:0'):
        with tf.device('/gpu:0'):
            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device('/gpu:0'):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_lstm=use_lstm)

        # Gradients and SGD update operation for training the model.
        params_raw = tf.trainable_variables()
        params = []
        params_add = []
        for param in params_raw:
            #if 'running' in param.name or 'conv' in param.name or 'batch' in param.name:
            if 'running' in param.name:
                logging.info('parameter {0} NOT trainable'.format(param.name))
            else:
                logging.info('parameter {0} trainable'.format(param.name))
                params.append(param)
            logging.info(param.get_shape())

        if not self.forward_only:
            #self.gradient_norms = []
            self.updates = []
            with tf.device('/gpu:0'):
                #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate,
                    rho=0.95,
                    epsilon=1e-08,
                    use_locking=False,
                    name='Adadelta')
                for b in xrange(len(buckets)):
                    gradients = tf.gradients(
                        self.attention_decoder_model.losses[b], params)
                    #self.gradient_norms.append(norm)
                    self.updates.append(
                        opt.apply_gradients(zip(gradients, params),
                                            global_step=self.global_step))

            #with tf.device('/gpu:1'):
            with tf.device('/gpu:0'):
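                # cnn_model.model.updates holds the CNN's internal state updates
                # (Keras batch-norm moving averages); they are converted to explicit
                # assign ops and run together with the training step (see step()).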
                self.keras_updates = []
                for old_value, new_value in cnn_model.model.updates:
                    self.keras_updates.append(tf.assign(old_value, new_value))

        params_raw = tf.all_variables()
        params_init = []
        params_load = []
        for param in params_raw:
            #if 'Adadelta' in param.name and ('batch' in param.name or 'conv' in param.name):
            #    params_add.append(param)
            #if not 'BiRNN' in param.name:
            #    params_load.append(param)
            #else:
            #    params_init.append(param)
            if 'running' in param.name or (
                ('conv' in param.name or 'batch' in param.name) and
                ('Ada' not in param.name)
            ) or 'BiRNN' in param.name or 'attention' in param.name:
                params_load.append(param)
            else:
                params_init.append(param)

        self.saver_all = tf.train.Saver(tf.all_variables())
        #self.saver = tf.train.Saver(list(set(tf.all_variables())-set(params_add)))
        self.saver = tf.train.Saver(params_load)
        init_new_vars_op = tf.initialize_variables(params_init)

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path) and load_model:
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())
        #self.sess.run(init_new_vars_op)

    # train or test as specified by phase
    def launch(self):
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        if self.phase == 'test':
            if not distance_loaded:
                logging.info(
                    'Warning: distance module not installed. Falling back to whole-sequence comparison.'
                )
            else:
                logging.info('Comparing words based on edit distance.')
            num_correct = 0
            num_total = 0
            for batch in self.s_gen.gen(self.batch_size):
                # Get a batch and make a step.
                start_time = time.time()
                bucket_id = batch['bucket_id']
                img_data = batch['data']
                zero_paddings = batch['zero_paddings']
                decoder_inputs = batch['decoder_inputs']
                target_weights = batch['target_weights']
                encoder_masks = batch['encoder_mask']
                file_list = batch['filenames']
                real_len = batch['real_len']
                #print (decoder_inputs)
                #print (encoder_masks)
                grounds = [
                    a for a in np.array([
                        decoder_input.tolist()
                        for decoder_input in decoder_inputs
                    ]).transpose()
                ]
                _, step_loss, step_logits, step_attns = self.step(
                    encoder_masks, img_data, zero_paddings, decoder_inputs,
                    target_weights, bucket_id, self.forward_only)
                curr_step_time = (time.time() - start_time)
                step_time += curr_step_time / self.steps_per_checkpoint
                logging.info(
                    'step_time: %f, step perplexity: %f' %
                    (curr_step_time,
                     math.exp(step_loss) if step_loss < 300 else float('inf')))
                loss += step_loss / self.steps_per_checkpoint
                current_step += 1
                step_outputs = [
                    b for b in np.array([
                        np.argmax(logit, axis=1).tolist()
                        for logit in step_logits
                    ]).transpose()
                ]
                if self.visualize:
                    step_attns = np.array([[a.tolist() for a in step_attn]
                                           for step_attn in step_attns
                                           ]).transpose([1, 0, 2])
                    #print (step_attns)
                for idx, output, ground in zip(range(len(grounds)),
                                               step_outputs, grounds):
                    flag_ground, flag_out = True, True
                    num_total += 1
                    output_valid = []
                    ground_valid = []
                    for j in range(1, len(ground)):
                        s1 = output[j - 1]
                        s2 = ground[j]
                        if s2 != 2 and flag_ground:
                            ground_valid.append(s2)
                        else:
                            flag_ground = False
                        if s1 != 2 and flag_out:
                            output_valid.append(s1)
                        else:
                            flag_out = False
                    if distance_loaded:
                        num_incorrect = distance.levenshtein(
                            output_valid, ground_valid)
                        if self.visualize:
                            self.visualize_attention(
                                file_list[idx], step_attns[idx], output_valid,
                                ground_valid, num_incorrect > 0, real_len)
                        num_incorrect = float(num_incorrect) / len(
                            ground_valid)
                        num_incorrect = min(1.0, num_incorrect)
                    else:
                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1
                        if self.visualize:
                            self.visualize_attention(
                                file_list[idx], step_attns[idx], output_valid,
                                ground_valid, num_incorrect > 0, real_len)
                    num_correct += 1. - num_incorrect
                logging.info('%f out of %d correct' % (num_correct, num_total))
        elif self.phase == 'train':
            for epoch in range(self.num_epoch):
                for batch in self.s_gen.gen(self.batch_size):
                    # Get a batch and make a step.
                    start_time = time.time()
                    bucket_id = batch['bucket_id']
                    img_data = batch['data']
                    zero_paddings = batch['zero_paddings']
                    decoder_inputs = batch['decoder_inputs']
                    target_weights = batch['target_weights']
                    encoder_masks = batch['encoder_mask']
                    logging.info('current_step: %d' % current_step)
                    #logging.info(np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()[0])
                    #print (np.array([target_weight.tolist() for target_weight in target_weights]).transpose()[0])
                    _, step_loss, step_logits, _ = self.step(
                        encoder_masks, img_data, zero_paddings, decoder_inputs,
                        target_weights, bucket_id, self.forward_only)
                    curr_step_time = (time.time() - start_time)
                    step_time += curr_step_time / self.steps_per_checkpoint
                    logging.info('step_time: %f, step perplexity: %f' %
                                 (curr_step_time, math.exp(step_loss)
                                  if step_loss < 300 else float('inf')))
                    loss += step_loss / self.steps_per_checkpoint
                    current_step += 1
                    # If there is an EOS symbol in outputs, cut them at that point.
                    #if data_utils.EOS_ID in step_outputs:
                    #    step_outputs = step_outputs[:step_outputs.index(data_utils.EOS_ID)]
                    #if data_utils.PAD_ID in decoder_inputs:
                    #decoder_inputs = decoder_inputs[:decoder_inputs.index(data_utils.PAD_ID)]
                    #    print (step_outputs[0])

                    # Once in a while, we save checkpoint, print statistics, and run evals.
                    if current_step % self.steps_per_checkpoint == 0:
                        # Print statistics for the previous epoch.
                        perplexity = math.exp(loss) if loss < 300 else float(
                            'inf')
                        logging.info(
                            "global step %d step-time %.2f perplexity "
                            "%.2f" %
                            (self.global_step.eval(), step_time, perplexity))
                        previous_losses.append(loss)
                        # Save checkpoint and zero timer and loss.
                        if not self.forward_only:
                            checkpoint_path = os.path.join(
                                self.model_dir, "translate.ckpt")
                            logging.info("Saving model, current_step: %d" %
                                         current_step)
                            self.saver_all.save(self.sess,
                                                checkpoint_path,
                                                global_step=self.global_step)
                        step_time, loss = 0.0, 0.0
                        #sys.stdout.flush()

    # step, read one batch, generate gradients
    def step(self, encoder_masks, img_data, zero_paddings, decoder_inputs,
             target_weights, bucket_id, forward_only):
        # Check if the sizes match.
        encoder_size, decoder_size = self.buckets[bucket_id]
        if len(decoder_inputs) != decoder_size:
            raise ValueError(
                "Decoder length must be equal to the one in bucket,"
                " %d != %d." % (len(decoder_inputs), decoder_size))
        if len(target_weights) != decoder_size:
            raise ValueError(
                "Weights length must be equal to the one in bucket,"
                " %d != %d." % (len(target_weights), decoder_size))

        # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
        input_feed = {}
        # The Keras learning phase is kept at 0 for both training and inference;
        # batch-norm statistics are instead updated explicitly via
        # self.keras_updates during the training step.
        input_feed[K.learning_phase()] = 0
        input_feed[self.img_data.name] = img_data
        input_feed[self.zero_paddings.name] = zero_paddings
        for l in xrange(decoder_size):
            input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
            input_feed[self.target_weights[l].name] = target_weights[l]
        for l in xrange(encoder_size):
            input_feed[self.encoder_masks[l].name] = encoder_masks[l]

        # Since our targets are decoder inputs shifted by one, we need one more.
        last_target = self.decoder_inputs[decoder_size].name
        input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)

        # Output feed: depends on whether we do a backward step or not.
        if not forward_only:
            output_feed = [
                self.updates[bucket_id],  # Update Op that does SGD.
                #self.gradient_norms[bucket_id],  # Gradient norm.
                self.attention_decoder_model.losses[bucket_id]
            ]
            output_feed += self.keras_updates
        else:
            output_feed = [self.attention_decoder_model.losses[bucket_id]
                           ]  # Loss for this batch.
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
            if self.visualize:
                output_feed += self.attention_decoder_model.attention_weights_histories[
                    bucket_id]

        outputs = self.sess.run(output_feed, input_feed)
        if not forward_only:
            return None, outputs[
                1], None, None  # Gradient norm, loss, no outputs, no attentions.
        else:
            return None, outputs[0], outputs[1:(
                1 + self.buckets[bucket_id][1])], outputs[(
                    1 + self.buckets[bucket_id][1]
                ):]  # No gradient norm, loss, outputs, attentions.

    def visualize_attention(self, filename, attentions, output_valid,
                            ground_valid, flag_incorrect, real_len):
        if flag_incorrect:
            output_dir = os.path.join(self.output_dir, 'incorrect')
        else:
            output_dir = os.path.join(self.output_dir, 'correct')
        output_dir = os.path.join(output_dir, filename.replace('/', '_'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, 'word.txt'), 'w') as fword:
            fword.write(' '.join([
                chr(c - 13 + 97) if c - 13 + 97 > 96 else chr(c - 3 + 48)
                for c in ground_valid
            ]) + '\n')
            fword.write(' '.join([
                chr(c - 13 + 97) if c - 13 + 97 > 96 else chr(c - 3 + 48)
                for c in output_valid
            ]))
            with open(filename, 'rb') as img_file:
                img = Image.open(img_file)
                w, h = img.size
                h = 32
                img = img.resize((real_len, h), Image.ANTIALIAS)
                img_data = np.asarray(img, dtype=np.uint8)
                for idx in range(len(output_valid)):
                    output_filename = os.path.join(output_dir,
                                                   'image_%d.jpg' % (idx))
                    attention = attentions[idx][:(int(real_len / 4) - 1)]

                    # Upsample the attention weights to the full image width,
                    # smooth them, and overlay them on the input image before
                    # saving to output_filename.
                    attention_orig = np.zeros(real_len)
                    for i in range(real_len):
                        if 0 < i / 4 - 1 and i / 4 - 1 < len(attention):
                            attention_orig[i] = attention[int(i / 4) - 1]
                    attention_orig = np.convolve(
                        attention_orig,
                        [0.199547, 0.200226, 0.200454, 0.200226, 0.199547],
                        mode='same')
                    attention_orig = np.maximum(attention_orig, 0.3)
                    attention_out = np.zeros((h, real_len))
                    for i in range(real_len):
                        attention_out[:, i] = attention_orig[i]
                    if len(img_data.shape) == 3:
                        attention_out = attention_out[:, :, np.newaxis]
                    img_out_data = img_data * attention_out
                    img_out = Image.fromarray(img_out_data.astype(np.uint8))
                    img_out.save(output_filename)
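
The hard-coded decoding above (chr(c - 13 + 97) vs. chr(c - 3 + 48)) replaces the DataGen.get_char lookup used in the earlier example: ids 3-12 decode to the digits '0'-'9' and ids 13 and above decode to lowercase letters starting at 'a'. Ids 0-2 are assumed here to be the special GO/EOS/padding symbols, which are stripped before decoding. A small illustrative sketch of the mapping and its inverse:

def id_to_char(c):
    # ids >= 13 are letters ('a' is 13); ids 3..12 are digits ('0' is 3).
    return chr(c - 13 + 97) if c >= 13 else chr(c - 3 + 48)

def char_to_id(ch):
    return ord(ch) - 97 + 13 if ch.isalpha() else ord(ch) - 48 + 3

word_ids = [13, 3, 14]  # made-up example: 'a', '0', 'b'
assert ''.join(id_to_char(c) for c in word_ids) == 'a0b'
assert [char_to_id(ch) for ch in 'a0b'] == word_ids
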
Example #5
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_base_dir,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 session,
                 load_model,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 use_lstm=True):

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        # load data
        if phase == 'train':
            self.s_gen = DataGen(data_base_dir,
                                 data_path,
                                 valid_target_len=valid_target_length,
                                 evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('data_path: %s' % (data_path))
        logging.info('phase: %s' % phase)
        logging.info('batch_size: %d' % batch_size)
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('steps_per_checkpoint %d' % steps_per_checkpoint)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('model_dir: %s' % model_dir)
        logging.info('target_embedding_size: %d' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)

        # variables
        self.img_data = tf.placeholder(tf.float32,
                                       shape=(None, 1, 32, None),
                                       name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32,
                                            shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(buckets[-1][0] + 1):
            self.encoder_masks.append(
                tf.placeholder(tf.float32,
                               shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase
        #with tf.device('/gpu:1'):
        with tf.device('/gpu:0'):
            cnn_model = CNN(self.img_data)
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(
                concat_dim=1, values=[self.conv_output, self.zero_paddings])

        #with tf.device('/cpu:0'):
        with tf.device('/gpu:0'):
            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device('/gpu:0'):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_lstm=use_lstm)

        # Gradients and SGD update operation for training the model.
        params_raw = tf.trainable_variables()
        params = []
        params_add = []
        for param in params_raw:
            #if 'running' in param.name or 'conv' in param.name or 'batch' in param.name:
            if 'running' in param.name:
                logging.info('parameter {0} NOT trainable'.format(param.name))
            else:
                logging.info('parameter {0} trainable'.format(param.name))
                params.append(param)
            logging.info(param.get_shape())

        if not self.forward_only:
            #self.gradient_norms = []
            self.updates = []
            with tf.device('/gpu:0'):
                #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate,
                    rho=0.95,
                    epsilon=1e-08,
                    use_locking=False,
                    name='Adadelta')
                for b in xrange(len(buckets)):
                    gradients = tf.gradients(
                        self.attention_decoder_model.losses[b], params)
                    #self.gradient_norms.append(norm)
                    self.updates.append(
                        opt.apply_gradients(zip(gradients, params),
                                            global_step=self.global_step))

            #with tf.device('/gpu:1'):
            with tf.device('/gpu:0'):
                self.keras_updates = []
                for old_value, new_value in cnn_model.model.updates:
                    self.keras_updates.append(tf.assign(old_value, new_value))

        params_raw = tf.all_variables()
        params_init = []
        params_load = []
        for param in params_raw:
            #if 'Adadelta' in param.name and ('batch' in param.name or 'conv' in param.name):
            #    params_add.append(param)
            #if not 'BiRNN' in param.name:
            #    params_load.append(param)
            #else:
            #    params_init.append(param)
            if 'running' in param.name or (
                ('conv' in param.name or 'batch' in param.name) and
                ('Ada' not in param.name)
            ) or 'BiRNN' in param.name or 'attention' in param.name:
                params_load.append(param)
            else:
                params_init.append(param)

        self.saver_all = tf.train.Saver(tf.all_variables())
        #self.saver = tf.train.Saver(list(set(tf.all_variables())-set(params_add)))
        self.saver = tf.train.Saver(params_load)
        init_new_vars_op = tf.initialize_variables(params_init)

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path) and load_model:
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())
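
Examples #4 and #5 split the saved graph into two groups by variable name: batch-norm statistics ('running'), conv/batch-norm weights (but not their Adadelta optimizer slots), BiRNN weights, and attention weights go into params_load, which backs the selective self.saver; everything else is collected in params_init and covered by init_new_vars_op. The predicate is plain substring matching; a stand-alone sketch of it over made-up variable names:

def should_load(name):
    # Mirrors the name test above: keep running statistics, conv/batch-norm
    # weights without optimizer slots, and BiRNN/attention parameters.
    return ('running' in name
            or (('conv' in name or 'batch' in name) and 'Ada' not in name)
            or 'BiRNN' in name
            or 'attention' in name)

# Made-up names for illustration only:
names = ['conv1/W', 'conv1/W/Adadelta', 'batchnorm_1/running_mean',
         'BiRNN/fw/lstm_cell/weights', 'attention_decoder/AttnW', 'global_step']
params_load = [n for n in names if should_load(n)]      # everything except the two below
params_init = [n for n in names if not should_load(n)]  # ['conv1/W/Adadelta', 'global_step']
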
Example #6
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_base_dir,
                 output_dir,
                 tb_logs,
                 tb_log_every,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 old_model_version=False):

        gpu_device_id = '/gpu:' + str(gpu_id)

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')

        if phase == 'train':
            self.s_gen = DataGen(
                data_base_dir, data_path, valid_target_len=valid_target_length,
                evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(
                data_base_dir, data_path, evaluate=True)

        logging.info('data_path: %s' % (data_path))
        logging.info('tensorboard_logging_path: %s' % tb_logs)
        logging.info('phase: %s' % phase)
        logging.info('batch_size: %d' % batch_size)
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('steps_per_checkpoint %d' % steps_per_checkpoint)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('model_dir: %s' % model_dir)
        logging.info('target_embedding_size: %d' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('using GRU in the decoder.')

        # variables
        self.seq_data = tf.placeholder(tf.float32, shape=(None, 1, 10, None),
                                       name='seq_data')
        self.zero_paddings = tf.placeholder(tf.float32, shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(buckets[-1][0] + 1):
            self.encoder_masks.append(
                tf.placeholder(tf.float32, shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                      name="decoder{0}".format(
                                                          i)))
            self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                      name="weight{0}".format(
                                                          i)))

        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
        self.tb_logs = tb_logs
        self.tb_log_every = tb_log_every

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase

        with tf.device(gpu_device_id):
            cnn_model = CNN(self.seq_data)
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(concat_dim=1,
                                                values=[self.conv_output,
                                                        self.zero_paddings])
            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru)


        # Gradients and SGD update operation for training the model.
        params_raw = tf.trainable_variables()
        params = []
        params_add = []
        params_run = []
        for param in params_raw:
            # if 'running' in param.name or 'conv' in param.name or 'batch'
            # in param.name:
            if 'running' in param.name:
                logging.info('parameter {0} NOT trainable'.format(param.name))
                # for old keras conversion
                if 'running_std' in param.name:
                    params_run.append(param)
            else:
                logging.info('parameter {0} trainable'.format(param.name))
                params.append(param)
            logging.info(param.get_shape())

        if not self.forward_only:
            # self.gradient_norms = []
            self.updates = []
            with tf.device(gpu_device_id):
                #opt = tf.train.AdadeltaOptimizer(
                #    learning_rate=initial_learning_rate, rho=0.95,
                #    epsilon=1e-08, use_locking=False, name='Adadelta')
                opt = tf.train.AdamOptimizer(learning_rate=initial_learning_rate)
                for b in xrange(len(buckets)):
                    gradients = tf.gradients(
                        self.attention_decoder_model.losses[b], params)
                    # self.gradient_norms.append(norm)
                    self.updates.append(opt.apply_gradients(
                        zip(gradients, params), global_step=self.global_step))

            with tf.device(gpu_device_id):
                self.keras_updates = []
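                # In this variant model.updates is a flat, interleaved
                # [old, new, old, new, ...] list, so each pair is rebuilt by index.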
                for i in xrange(int(len(cnn_model.model.updates) / 2)):
                    old_value = cnn_model.model.updates[i * 2]
                    new_value = cnn_model.model.updates[i * 2 + 1]
                    self.keras_updates.append(tf.assign(old_value, new_value))

        params_dict = {}
        params_raw = tf.all_variables()
        # params_init = []
        # params_load = []
        for param in params_raw:
            name = param.name
            # to be compatible with old version saved model
            if 'BiRNN' in name:
                name = name.replace('BiRNN/', 'BiRNN_')
            if ':0' in name:
                name = name.replace(':0', '')
            params_dict[name] = param
            self._activation_summary(param)
            # if 'Adadelta' in param.name and ('batch' in param.name or
            # 'conv' in param.name):
            #    params_add.append(param)
            # if not 'BiRNN' in param.name:
            #    params_load.append(param)
            # else:
            #    params_init.append(param)
            # if 'BiRNN/' in param.name:
            #    param.name = param.name.replace('BiRNN/', 'BiRNN_')
            # if 'running' in param.name or (('conv' in param.name or 'batch'
            # in param.name) and ('Ada' not in param.name)) or 'BiRNN' in
            # param.name or 'attention' in param.name:
            #    params_load.append(param)
            # else:
            #    params_init.append(param)
        #for b_id, loss in enumerate(self.attention_decoder_model.losses):
        #    print(loss)
        #    tf.scalar_summary("Bucket loss/" + str(buckets[b_id]), loss)
        #    tf.scalar_summary("Bucket perplexity/" + str(buckets[b_id]),
        #                      tf.exp(loss))

        self.summary_op = tf.merge_all_summaries()
        self.saver_all = tf.train.Saver(params_dict)
        # self.saver = tf.train.Saver(list(set(tf.all_variables())-set(
        # params_add)))
        # self.saver = tf.train.Saver(params_load)
        # init_new_vars_op = tf.initialize_variables(params_init)

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path) and load_model:
            self.model_loaded = True
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            # self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
            if old_model_version:
                for param in params_run:
                    self.sess.run([param.assign(tf.square(param))])
        else:
            self.model_loaded = False
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())
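
Example #6 keys the Saver by a name dictionary instead of a variable list so that checkpoint keys can be rewritten for compatibility with older saved models: the current 'BiRNN/' scope prefix becomes the 'BiRNN_' form used by those checkpoints, and the ':0' tensor suffix is dropped. A stand-alone sketch of the key rewrite, with made-up variable names:

def checkpoint_key(variable_name):
    # Same rewrite as above: 'BiRNN/' scopes become 'BiRNN_' and the ':0'
    # output suffix is stripped so the key matches the old checkpoint layout.
    name = variable_name
    if 'BiRNN' in name:
        name = name.replace('BiRNN/', 'BiRNN_')
    if ':0' in name:
        name = name.replace(':0', '')
    return name

assert checkpoint_key('BiRNN/fw/lstm_cell/weights:0') == 'BiRNN_fw/lstm_cell/weights'
assert checkpoint_key('conv1/W:0') == 'conv1/W'
# params_dict = {checkpoint_key(v.name): v for v in tf.all_variables()}
# saver_all = tf.train.Saver(params_dict)
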
Example #7
class Model(object):
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_base_dir,
                 output_dir,
                 tb_logs,
                 tb_log_every,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 old_model_version=False):

        gpu_device_id = '/gpu:' + str(gpu_id)

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')

        if phase == 'train':
            self.s_gen = DataGen(
                data_base_dir, data_path, valid_target_len=valid_target_length,
                evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(
                data_base_dir, data_path, evaluate=True)

        logging.info('data_path: %s' % (data_path))
        logging.info('tensorboard_logging_path: %s' % tb_logs)
        logging.info('phase: %s' % phase)
        logging.info('batch_size: %d' % batch_size)
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('steps_per_checkpoint %d' % steps_per_checkpoint)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('model_dir: %s' % model_dir)
        logging.info('target_embedding_size: %d' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('using GRU in the decoder.')

        # variables
        self.seq_data = tf.placeholder(tf.float32, shape=(None, 1, 10, None),
                                       name='seq_data')
        self.zero_paddings = tf.placeholder(tf.float32, shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(buckets[-1][0] + 1):
            self.encoder_masks.append(
                tf.placeholder(tf.float32, shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                      name="decoder{0}".format(
                                                          i)))
            self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                      name="weight{0}".format(
                                                          i)))

        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
        self.tb_logs = tb_logs
        self.tb_log_every = tb_log_every

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase

        with tf.device(gpu_device_id):
            cnn_model = CNN(self.seq_data)
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(concat_dim=1,
                                                values=[self.conv_output,
                                                        self.zero_paddings])
            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru)


        # Gradients and SGD update operation for training the model.
        params_raw = tf.trainable_variables()
        params = []
        params_add = []
        params_run = []
        for param in params_raw:
            # if 'running' in param.name or 'conv' in param.name or 'batch'
            # in param.name:
            if 'running' in param.name:
                logging.info('parameter {0} NOT trainable'.format(param.name))
                # for old keras conversion
                if 'running_std' in param.name:
                    params_run.append(param)
            else:
                logging.info('parameter {0} trainable'.format(param.name))
                params.append(param)
            logging.info(param.get_shape())

        if not self.forward_only:
            # self.gradient_norms = []
            self.updates = []
            with tf.device(gpu_device_id):
                #opt = tf.train.AdadeltaOptimizer(
                #    learning_rate=initial_learning_rate, rho=0.95,
                #    epsilon=1e-08, use_locking=False, name='Adadelta')
                opt = tf.train.AdamOptimizer(learning_rate=initial_learning_rate)
                for b in xrange(len(buckets)):
                    gradients = tf.gradients(
                        self.attention_decoder_model.losses[b], params)
                    # self.gradient_norms.append(norm)
                    self.updates.append(opt.apply_gradients(
                        zip(gradients, params), global_step=self.global_step))

            with tf.device(gpu_device_id):
                self.keras_updates = []
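                # cnn_model.model.updates is treated as a flat list alternating
                # (variable, new value) pairs (typically the Keras batch-norm
                # moving averages); each pair becomes an explicit assign op that
                # step() runs alongside every training update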
                for i in xrange(int(len(cnn_model.model.updates) / 2)):
                    old_value = cnn_model.model.updates[i * 2]
                    new_value = cnn_model.model.updates[i * 2 + 1]
                    self.keras_updates.append(tf.assign(old_value, new_value))

        params_dict = {}
        params_raw = tf.all_variables()
        # params_init = []
        # params_load = []
        for param in params_raw:
            name = param.name
            # to be compatible with old version saved model
            if 'BiRNN' in name:
                name = name.replace('BiRNN/', 'BiRNN_')
            if ':0' in name:
                name = name.replace(':0', '')
            params_dict[name] = param
            self._activation_summary(param)
            # if 'Adadelta' in param.name and ('batch' in param.name or
            # 'conv' in param.name):
            #    params_add.append(param)
            # if not 'BiRNN' in param.name:
            #    params_load.append(param)
            # else:
            #    params_init.append(param)
            # if 'BiRNN/' in param.name:
            #    param.name = param.name.replace('BiRNN/', 'BiRNN_')
            # if 'running' in param.name or (('conv' in param.name or 'batch'
            # in param.name) and ('Ada' not in param.name)) or 'BiRNN' in
            # param.name or 'attention' in param.name:
            #    params_load.append(param)
            # else:
            #    params_init.append(param)
        #for b_id, loss in enumerate(self.attention_decoder_model.losses):
        #    print(loss)
        #    tf.scalar_summary("Bucket loss/" + str(buckets[b_id]), loss)
        #    tf.scalar_summary("Bucket perplexity/" + str(buckets[b_id]),
        #                      tf.exp(loss))

        self.summary_op = tf.merge_all_summaries()
        self.saver_all = tf.train.Saver(params_dict)
        # self.saver = tf.train.Saver(list(set(tf.all_variables())-set(
        # params_add)))
        # self.saver = tf.train.Saver(params_load)
        # init_new_vars_op = tf.initialize_variables(params_init)

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path) and load_model:
            self.model_loaded = True
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            # self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
            if old_model_version:
                for param in params_run:
                    self.sess.run([param.assign(tf.square(param))])
        else:
            self.model_loaded = False
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())

    # train or test as specified by phase
    def launch(self):
        step_time, loss = 0.0, 0.0
        current_step = 0
        if not os.path.exists(self.tb_logs):
            os.makedirs(self.tb_logs)
        summary_writer = tf.train.SummaryWriter(self.tb_logs, self.sess.graph)
        previous_losses = []
        if self.phase == 'test':
            if not distance_loaded:
                logging.info(
                    'Warning: distance module not installed. Do whole '
                    'sequence comparison instead.')
            else:
                logging.info('Compare word based on edit distance.')
            num_correct = 0
            num_total = 0
            for batch in self.s_gen.gen(self.batch_size):
                # Get a batch and make a step.
                start_time = time.time()
                bucket_id = batch['bucket_id']
                seq_data = batch['data']
                zero_paddings = batch['zero_paddings']
                decoder_inputs = batch['decoder_inputs']
                target_weights = batch['target_weights']
                encoder_masks = batch['encoder_mask']
                file_list = batch['filenames']
                real_len = batch['real_len']

                grounds = [a for a in np.array(
                    [decoder_input.tolist() for decoder_input in
                     decoder_inputs]).transpose()]
                _, step_loss, step_logits, step_attns = self.step(encoder_masks,
                                                                  seq_data,
                                                                  zero_paddings,
                                                                  decoder_inputs,
                                                                  target_weights,
                                                                  bucket_id,
                                                                  self.forward_only)
                curr_step_time = (time.time() - start_time)
                step_time += curr_step_time / self.steps_per_checkpoint
                logging.info('step_time: %f, loss: %f, step perplexity: %f' % (
                curr_step_time, step_loss,
                math.exp(step_loss) if step_loss < 300 else float('inf')))
                loss += step_loss / self.steps_per_checkpoint
                current_step += 1
                step_outputs = [b for b in np.array(
                    [np.argmax(logit, axis=1).tolist() for logit in
                     step_logits]).transpose()]
                if self.visualize:
                    step_attns = np.array([[a.tolist() for a in step_attn]
                                           for step_attn in
                                           step_attns]).transpose([1, 0, 2])
                    # print (step_attns)

                for idx, output, ground in zip(range(len(grounds)),
                                               step_outputs, grounds):
                    flag_ground, flag_out = True, True
                    num_total += 1
                    output_valid = []
                    ground_valid = []
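                    # token id 2 corresponds to the EOS symbol; both the
                    # prediction and the ground truth are cut at its first
                    # occurrence before being compared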
                    for j in range(1, len(ground)):
                        s1 = output[j - 1]
                        s2 = ground[j]
                        if s2 != 2 and flag_ground:
                            ground_valid.append(s2)
                        else:
                            flag_ground = False
                        if s1 != 2 and flag_out:
                            output_valid.append(s1)
                        else:
                            flag_out = False
                    if distance_loaded:
                        num_incorrect = distance.levenshtein(output_valid,
                                                             ground_valid)
                        if self.visualize:
                            self.visualize_attention(file_list[idx],
                                                     step_attns[idx],
                                                     output_valid, ground_valid,
                                                     num_incorrect > 0,
                                                     real_len)
                        num_incorrect = float(num_incorrect) / len(ground_valid)
                        num_incorrect = min(1.0, num_incorrect)
                    else:
                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1
                        if self.visualize:
                            self.visualize_attention(file_list[idx],
                                                     step_attns[idx],
                                                     output_valid, ground_valid,
                                                     num_incorrect > 0,
                                                     real_len)
                    num_correct += 1. - num_incorrect
                logging.info('%f out of %d correct' % (num_correct, num_total))
        elif self.phase == 'train':
            for epoch in range(self.num_epoch):
                for batch in self.s_gen.gen(self.batch_size):

                    # Get a batch and make a step.
                    start_time = time.time()
                    bucket_id = batch['bucket_id']
                    seq_data = batch['data']
                    zero_paddings = batch['zero_paddings']
                    decoder_inputs = batch['decoder_inputs']
                    target_weights = batch['target_weights']
                    encoder_masks = batch['encoder_mask']
                    logging.info('current_step: %d' % current_step)

                    # logging.info(np.array([decoder_input.tolist() for
                    # decoder_input in decoder_inputs]).transpose()[0])
                    # print (np.array([target_weight.tolist() for
                    # target_weight in target_weights]).transpose()[0])

                    _, step_loss, step_logits, _ = self.step(encoder_masks,
                                                             seq_data,
                                                             zero_paddings,
                                                             decoder_inputs,
                                                             target_weights,
                                                             bucket_id,
                                                             self.forward_only)
                    curr_step_time = (time.time() - start_time)
                    step_time += curr_step_time / self.steps_per_checkpoint
                    logging.info(
                        'step_time: %f, step_loss: %f, step perplexity: %f' % (
                        curr_step_time, step_loss,
                        math.exp(step_loss) if step_loss < 300 else float(
                            'inf')))
                    loss += step_loss / self.steps_per_checkpoint
                    current_step += 1

                    # If there is an EOS symbol in outputs, cut them at that
                    # point.
                    # if data_utils.EOS_ID in step_outputs:
                    #    step_outputs = step_outputs[:step_outputs.index(
                    # data_utils.EOS_ID)]
                    # if data_utils.PAD_ID in decoder_inputs:
                    # decoder_inputs = decoder_inputs[:decoder_inputs.index(
                    # data_utils.PAD_ID)]
                    #    print (step_outputs[0])

                    # Once in a while, we save checkpoint, print statistics,
                    # and run evals.
                    loss_dumps_path = os.path.join(self.tb_logs,
                                                   'loss_perp_log.tsv')
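                    # loss_perp_log.tsv columns: wall-clock time, epoch, step,
                    # step loss, step perplexity, step time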
                    if current_step % self.tb_log_every == 0:
                        summary_str = self.sess.run(self.summary_op)
                        summary_writer.add_summary(summary_str, current_step)
                        perplexity = math.exp(step_loss) if step_loss < 300 else float('inf')
                        if self.model_loaded:
                            with open(loss_dumps_path, "a") as myfile :
                                myfile.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(
                                    time.time(),
                                    epoch,
                                    current_step,
                                    step_loss,
                                    perplexity,
                                    curr_step_time))
                        else:
                            with open(loss_dumps_path, "w") as myfile :
                                myfile.write("{}\t{}\t{}\t{}\t{}\t{}\n".format(
                                    time.time(),
                                    epoch,
                                    current_step,
                                    step_loss,
                                    perplexity,
                                    curr_step_time))
                                self.model_loaded = True

                    if current_step % self.steps_per_checkpoint == 0:
                        # Print statistics for the previous epoch.
                        perplexity = math.exp(loss) if loss < 300 else float(
                            'inf')
                        logging.info(
                            "global step %d step-time %.2f loss %f  perplexity "
                            "%.2f" % (self.global_step.eval(), step_time, loss,
                                      perplexity))
                        previous_losses.append(loss)
                        print(
                            "epoch: {}, step: {}, loss: {}, perplexity: {}, "
                            "step_time: {}".format(
                                epoch, current_step, loss, perplexity,
                                curr_step_time))



                        # Save checkpoint and zero timer and loss.
                        if not self.forward_only:
                            checkpoint_path = os.path.join(self.model_dir,
                                                           "attn_seq.ckpt")
                            logging.info(
                                "Saving model, current_step: %d" % current_step)
                            self.saver_all.save(self.sess, checkpoint_path,
                                                global_step=self.global_step)
                        step_time, loss = 0.0, 0.0
                        # sys.stdout.flush()

    # step, read one batch, generate gradients
    def step(self, encoder_masks, seq_data, zero_paddings, decoder_inputs,
             target_weights,
             bucket_id, forward_only):
        # Check if the sizes match.
        encoder_size, decoder_size = self.buckets[bucket_id]
        if len(decoder_inputs) != decoder_size:
            raise ValueError(
                "Decoder length must be equal to the one in bucket,"
                " %d != %d." % (len(decoder_inputs), decoder_size))
        if len(target_weights) != decoder_size:
            raise ValueError(
                "Weights length must be equal to the one in bucket,"
                " %d != %d." % (len(target_weights), decoder_size))

        # Input feed: encoder inputs, decoder inputs, target_weights,
        # as provided.
        input_feed = {}
        if not forward_only:
            input_feed[K.learning_phase()] = 1
        else:
            input_feed[K.learning_phase()] = 0
        input_feed[self.seq_data.name] = seq_data
        input_feed[self.zero_paddings.name] = zero_paddings
        for l in xrange(decoder_size):
            input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
            input_feed[self.target_weights[l].name] = target_weights[l]
        for l in xrange(encoder_size):
            input_feed[self.encoder_masks[l].name] = encoder_masks[l]

        # Since our targets are decoder inputs shifted by one, we need one more.
        last_target = self.decoder_inputs[decoder_size].name
        input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)

        # Output feed: depends on whether we do a backward step or not.
        if not forward_only:
            output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
                           # self.gradient_norms[bucket_id],  # Gradient norm.
                           self.attention_decoder_model.losses[bucket_id]]
            output_feed += self.keras_updates
        else:
            output_feed = [self.attention_decoder_model.losses[
                               bucket_id]]  # Loss for this batch.
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
            if self.visualize:
                output_feed += \
                self.attention_decoder_model.attention_weights_histories[
                    bucket_id]

        outputs = self.sess.run(output_feed, input_feed)
        if not forward_only:
            return None, outputs[
                1], None, None  # Gradient norm, loss, no outputs,
            # no attentions.
        else:
            return None, outputs[0], outputs[1:(
            1 + self.buckets[bucket_id][1])], outputs[(
            1 + self.buckets[bucket_id][
                1]):]  # No gradient norm, loss, outputs, attentions.

    def _activation_summary(self, x) :
        """Helper to create summaries for activations.
        Creates a summary that provides a histogram of activations.
        Args:
          x: Tensor
        Returns:
          nothing
        """

        tensor_name = x.name
        tf.histogram_summary(tensor_name + '/activations', x)

    def visualize_attention(self, filename, attentions, output_valid,
                            ground_valid, flag_incorrect, real_len):
        if flag_incorrect:
            output_dir = os.path.join(self.output_dir, 'incorrect')
        else:
            output_dir = os.path.join(self.output_dir, 'correct')
        output_dir = os.path.join(output_dir, filename.split("/")[-1].split(
            ".")[0] + filename.split("/")[-1].split(".")[1])
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, 'word.txt'), 'w') as fword:
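            # vocabulary ids 3-12 decode to the digits '0'-'9' and ids 13 and
            # above to lowercase letters starting at 'a'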
            gt_str = ' '.join([chr(c - 13 + 97) if c - 13 + 97 > 96 else chr(c - 3 + 48)
                 for c in ground_valid])
            fword.write(gt_str + '\n')

            output_str = ' '.join(
                [chr(c - 13 + 97) if c - 13 + 97 > 96 else chr(c - 3 + 48) for c
                 in output_valid])
            fword.write(output_str)

            # with open(filename, 'rb') as seq_file:
            seq = np.load(filename)
            w, h = seq.shape
            h = 10

            seq = signal.resample(seq, real_len)
            #seq = seq.transpose()
            seq_data = np.asarray(seq, dtype=np.float32)
            out_attns = []
            output_filename = os.path.join(output_dir,'image.png')

            for idx in range(len(output_valid)):

                attention = attentions[idx][:(int(real_len / 16) - 1)]

                # I have got the attention_orig here, which is of size 32*len(ground_truth),
                #  the only thing left is to visualize it and save it to output_filename
                attention_orig = np.zeros(real_len)
                for i in range(real_len):
                    if 0 < i / 16 - 1 and i / 16 - 1 < len(attention):
                        attention_orig[i] = attention[int(i / 16) - 1]
                attention_orig = np.convolve(attention_orig,
                                             [0.199547, 0.200226, 0.200454,
                                              0.200226, 0.199547], mode='same')
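                # each encoder position covers 16 original time steps, so its
                # attention weight is spread over those steps and smoothed with
                # a 5-tap averaging kernel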
                #attention_orig = np.maximum(attention_orig, 0.3)
                attention_out = np.zeros((h, real_len))
                for i in range(real_len):
                    attention_out[:, i] = attention_orig[i]

                out_attns.append(attention_orig)

            out_attns = np.vstack(out_attns)
            out_attns = out_attns.transpose()
            seq_np = np.array(seq_data)
            rows, cols = seq_np.shape[0], seq_np.shape[1]

            f1 = plt.figure()
            #f2 = plt.figure()
            ax1 = f1.add_subplot(121)

            ax2 = f1.add_subplot(122)

            y_axis_ticks = np.arange(0, rows, 1)
            #x_axis_ticks = np.arange(1, rows, 1)
            for i in range(cols):
                dat = seq_np[:, i]
                ax1.plot(dat, y_axis_ticks)

            #ax1.set_title('Sharing Y axis')
            ax2.imshow(out_attns, interpolation='nearest', aspect='auto', cmap=cm.jet)
            #ax2.set_xticks(output_str.split(' '))
            #plt.show()
            plt.savefig(output_filename, bbox_inches='tight', dpi=750)
예제 #8
0
    def __init__(self,
            phase,
            visualize,
            data_path,
            data_base_dir,
            output_dir,
            batch_size,
            initial_learning_rate,
            num_epoch,
            steps_per_checkpoint,
            target_vocab_size, 
            model_dir, 
            target_embedding_size,
            attn_num_hidden, 
            attn_num_layers, 
            session,
            load_model,
            gpu_id,
            use_gru,
            evaluate=False,
            valid_target_length=float('inf')):

        gpu_device_id = '/gpu:' + str(gpu_id)

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        # load data
        if phase == 'train':
            self.s_gen = DataGen(
                data_base_dir, data_path, valid_target_len=valid_target_length, evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(
                data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('data_path: %s' %(data_path))
        logging.info('phase: %s' %phase)    
        logging.info('batch_size: %d' %batch_size)
        logging.info('num_epoch: %d' %num_epoch)
        logging.info('steps_per_checkpoint %d' %steps_per_checkpoint)
        logging.info('target_vocab_size: %d' %target_vocab_size)
        logging.info('model_dir: %s' %model_dir)
        logging.info('target_embedding_size: %d' %target_embedding_size)
        logging.info('attn_num_hidden: %d' %attn_num_hidden)
        logging.info('attn_num_layers: %d' %attn_num_layers)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('use GRU in the decoder.')

        # variables
        self.img_data = tf.placeholder(tf.float32, shape=(None, 1, 32, None), name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32, shape=(None, None, 512), name='zero_paddings')
        
        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(buckets[-1][0] + 1):
            self.encoder_masks.append(tf.placeholder(tf.float32, shape=[None, 1],
                                                    name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                    name="decoder{0}".format(i)))
            self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                    name="weight{0}".format(i)))
      
        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint 
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
       
        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase
        #with tf.device('/gpu:1'):
        with tf.device(gpu_device_id):
            cnn_model = CNN(self.img_data)
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(concat_dim=1, values=[self.conv_output, self.zero_paddings])
        
        #with tf.device('/cpu:0'): 
        with tf.device(gpu_device_id): 
            self.perm_conv_output = tf.transpose(self.concat_conv_output, perm=[1, 0, 2])
        
        #with tf.device('/gpu:1'):
        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks = self.encoder_masks,
                encoder_inputs_tensor = self.perm_conv_output, 
                decoder_inputs = self.decoder_inputs,
                target_weights = self.target_weights,
                target_vocab_size = target_vocab_size, 
                buckets = buckets,
                target_embedding_size = target_embedding_size,
                attn_num_layers = attn_num_layers,
                attn_num_hidden = attn_num_hidden,
                forward_only = self.forward_only,
                use_gru = use_gru)
        
        # Gradients and SGD update operation for training the model.
        params_raw = tf.trainable_variables()
        params = []
        params_add = []
        for param in params_raw:
            #if 'running' in param.name or 'conv' in param.name or 'batch' in param.name:
            if 'running' in param.name:
                logging.info('parameter {0} NOT trainable'.format(param.name))
            else:
                logging.info('parameter {0} trainable'.format(param.name))
                params.append(param)
            logging.info(param.get_shape())

        if not self.forward_only:
            #self.gradient_norms = []
            self.updates = []
            with tf.device(gpu_device_id):
                #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
                opt = tf.train.AdadeltaOptimizer(learning_rate=initial_learning_rate, rho=0.95, epsilon=1e-08, use_locking=False, name='Adadelta')
                for b in xrange(len(buckets)):
                    gradients = tf.gradients(self.attention_decoder_model.losses[b], params)
                    #self.gradient_norms.append(norm)
                    self.updates.append(opt.apply_gradients(
                        zip(gradients, params), global_step=self.global_step))
       
            #with tf.device('/gpu:1'):
            with tf.device(gpu_device_id):
                self.keras_updates = []
                for old_value, new_value in cnn_model.model.updates:
                    self.keras_updates.append(tf.assign(old_value, new_value))

        params_raw = tf.all_variables()
        params_init = []
        params_load = []
        for param in params_raw:
            #if 'Adadelta' in param.name and ('batch' in param.name or 'conv' in param.name):
            #    params_add.append(param)
            #if not 'BiRNN' in param.name:
            #    params_load.append(param)
            #else:
            #    params_init.append(param)
            if 'running' in param.name or (('conv' in param.name or 'batch' in param.name) and ('Ada' not in param.name)) or 'BiRNN' in param.name or 'attention' in param.name:
                params_load.append(param)
            else:
                params_init.append(param)

        self.saver_all = tf.train.Saver(tf.all_variables())
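        # saver_all checkpoints every variable; self.saver and init_new_vars_op
        # are set up for a partial restore (CNN, BiRNN and attention weights
        # only) but are not used in the restore path below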
        #self.saver = tf.train.Saver(list(set(tf.all_variables())-set(params_add)))
        self.saver = tf.train.Saver(params_load)
        init_new_vars_op = tf.initialize_variables(params_init)

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path) and load_model:
            logging.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())
예제 #9
0
class Model(object):

    def __init__(self,
            phase,
            visualize,
            data_path,
            data_base_dir,
            output_dir,
            batch_size,
            initial_learning_rate,
            num_epoch,
            steps_per_checkpoint,
            target_vocab_size, 
            model_dir, 
            target_embedding_size,
            attn_num_hidden, 
            attn_num_layers, 
            session,
            load_model,
            gpu_id,
            use_gru,
            evaluate=False,
            valid_target_length=float('inf')):

        gpu_device_id = '/gpu:' + str(gpu_id)

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        # load data
        if phase == 'train':
            self.s_gen = DataGen(
                data_base_dir, data_path, valid_target_len=valid_target_length, evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(
                data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('data_path: %s' %(data_path))
        logging.info('phase: %s' %phase)    
        logging.info('batch_size: %d' %batch_size)
        logging.info('num_epoch: %d' %num_epoch)
        logging.info('steps_per_checkpoint %d' %steps_per_checkpoint)
        logging.info('target_vocab_size: %d' %target_vocab_size)
        logging.info('model_dir: %s' %model_dir)
        logging.info('target_embedding_size: %d' %target_embedding_size)
        logging.info('attn_num_hidden: %d' %attn_num_hidden)
        logging.info('attn_num_layers: %d' %attn_num_layers)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('use GRU in the decoder.')

        # variables
        self.img_data = tf.placeholder(tf.float32, shape=(None, 1, 32, None), name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32, shape=(None, None, 512), name='zero_paddings')
        
        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(buckets[-1][0] + 1):
            self.encoder_masks.append(tf.placeholder(tf.float32, shape=[None, 1],
                                                    name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                    name="decoder{0}".format(i)))
            self.target_weights.append(tf.placeholder(tf.float32, shape=[None],
                                                    name="weight{0}".format(i)))
      
        self.sess = session
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint 
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
       
        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase
        #with tf.device('/gpu:1'):
        with tf.device(gpu_device_id):
            cnn_model = CNN(self.img_data)
            self.conv_output = cnn_model.tf_output()
            self.concat_conv_output = tf.concat(concat_dim=1, values=[self.conv_output, self.zero_paddings])
        
        #with tf.device('/cpu:0'): 
        with tf.device(gpu_device_id): 
            self.perm_conv_output = tf.transpose(self.concat_conv_output, perm=[1, 0, 2])
        
        #with tf.device('/gpu:1'):
        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks = self.encoder_masks,
                encoder_inputs_tensor = self.perm_conv_output, 
                decoder_inputs = self.decoder_inputs,
                target_weights = self.target_weights,
                target_vocab_size = target_vocab_size, 
                buckets = buckets,
                target_embedding_size = target_embedding_size,
                attn_num_layers = attn_num_layers,
                attn_num_hidden = attn_num_hidden,
                forward_only = self.forward_only,
                use_gru = use_gru)
        
        # Gradients and SGD update operation for training the model.
        params_raw = tf.trainable_variables()
        params = []
        params_add = []
        for param in params_raw:
            #if 'running' in param.name or 'conv' in param.name or 'batch' in param.name:
            if 'running' in param.name:
                logging.info('parameter {0} NOT trainable'.format(param.name))
            else:
                logging.info('parameter {0} trainable'.format(param.name))
                params.append(param)
            logging.info(param.get_shape())

        if not self.forward_only:
            #self.gradient_norms = []
            self.updates = []
            with tf.device(gpu_device_id):
                #opt = tf.train.GradientDescentOptimizer(self.learning_rate)
                opt = tf.train.AdadeltaOptimizer(learning_rate=initial_learning_rate, rho=0.95, epsilon=1e-08, use_locking=False, name='Adadelta')
                for b in xrange(len(buckets)):
                    gradients = tf.gradients(self.attention_decoder_model.losses[b], params)
                    #self.gradient_norms.append(norm)
                    self.updates.append(opt.apply_gradients(
                        zip(gradients, params), global_step=self.global_step))
       
            #with tf.device('/gpu:1'):
            with tf.device(gpu_device_id):
                self.keras_updates = []
                for old_value, new_value in cnn_model.model.updates:
                    self.keras_updates.append(tf.assign(old_value, new_value))

        params_raw = tf.all_variables()
        params_init = []
        params_load = []
        for param in params_raw:
            #if 'Adadelta' in param.name and ('batch' in param.name or 'conv' in param.name):
            #    params_add.append(param)
            #if not 'BiRNN' in param.name:
            #    params_load.append(param)
            #else:
            #    params_init.append(param)
            if 'running' in param.name or (('conv' in param.name or 'batch' in param.name) and ('Ada' not in param.name)) or 'BiRNN' in param.name or 'attention' in param.name:
                params_load.append(param)
            else:
                params_init.append(param)

        self.saver_all = tf.train.Saver(tf.all_variables())
        #self.saver = tf.train.Saver(list(set(tf.all_variables())-set(params_add)))
        self.saver = tf.train.Saver(params_load)
        init_new_vars_op = tf.initialize_variables(params_init)

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path) and load_model:
            logging.info("Reading model parameters from %s" % ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())
        #self.sess.run(init_new_vars_op)

    # train or test as specified by phase
    def launch(self):
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        if self.phase == 'test':
            if not distance_loaded:
                logging.info('Warning: distance module not installed. Do whole sequence comparison instead.')
            else:
                logging.info('Compare word based on edit distance.')
            num_correct = 0
            num_total = 0
            for batch in self.s_gen.gen(self.batch_size):
                # Get a batch and make a step.
                start_time = time.time()
                bucket_id = batch['bucket_id']
                img_data = batch['data']
                zero_paddings = batch['zero_paddings']
                decoder_inputs = batch['decoder_inputs']
                target_weights = batch['target_weights']
                encoder_masks = batch['encoder_mask']
                file_list = batch['filenames']
                real_len = batch['real_len']
                #print (decoder_inputs)
                #print (encoder_masks)
                grounds = [a for a in np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()]
                _, step_loss, step_logits, step_attns = self.step(encoder_masks, img_data, zero_paddings, decoder_inputs, target_weights, bucket_id, self.forward_only)
                curr_step_time = (time.time() - start_time)
                step_time += curr_step_time / self.steps_per_checkpoint
                logging.info('step_time: %f, step perplexity: %f'%(curr_step_time, math.exp(step_loss) if step_loss < 300 else float('inf')))
                loss += step_loss / self.steps_per_checkpoint
                current_step += 1
                step_outputs = [b for b in np.array([np.argmax(logit, axis=1).tolist() for logit in step_logits]).transpose()]
                if self.visualize:
                    step_attns = np.array([[a.tolist() for a in step_attn] for step_attn in step_attns]).transpose([1, 0, 2])
                    #print (step_attns)
                for idx, output, ground in zip(range(len(grounds)), step_outputs, grounds):
                    flag_ground,flag_out = True,True
                    num_total += 1
                    output_valid = []
                    ground_valid = []
                    for j in range(1,len(ground)):
                        s1 = output[j-1]
                        s2 = ground[j]
                        if s2 != 2 and flag_ground:
                            ground_valid.append(s2)
                        else:
                            flag_ground = False
                        if s1 != 2 and flag_out:
                            output_valid.append(s1)
                        else:
                            flag_out = False
                    if distance_loaded:
                        num_incorrect = distance.levenshtein(output_valid, ground_valid)
                        if self.visualize:
                            self.visualize_attention(file_list[idx], step_attns[idx], output_valid, ground_valid, num_incorrect>0, real_len)
                        num_incorrect = float(num_incorrect) / len(ground_valid)
                        num_incorrect = min(1.0, num_incorrect)
                    else:
                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1
                        if self.visualize:
                            self.visualize_attention(file_list[idx], step_attns[idx], output_valid, ground_valid, num_incorrect>0, real_len)
                    num_correct += 1. - num_incorrect
                logging.info('%f out of %d correct' %(num_correct, num_total))
        elif self.phase == 'train':
            for epoch in range(self.num_epoch):
                for batch in self.s_gen.gen(self.batch_size):
                    # Get a batch and make a step.
                    start_time = time.time()
                    bucket_id = batch['bucket_id']
                    img_data = batch['data']
                    zero_paddings = batch['zero_paddings']
                    decoder_inputs = batch['decoder_inputs']
                    target_weights = batch['target_weights']
                    encoder_masks = batch['encoder_mask']
                    logging.info('current_step: %d'%current_step)
                    #logging.info(np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()[0])
                    #print (np.array([target_weight.tolist() for target_weight in target_weights]).transpose()[0])
                    _, step_loss, step_logits, _ = self.step(encoder_masks, img_data, zero_paddings, decoder_inputs, target_weights, bucket_id, self.forward_only)
                    curr_step_time = (time.time() - start_time)
                    step_time += curr_step_time / self.steps_per_checkpoint
                    logging.info('step_time: %f, step perplexity: %f'%(curr_step_time, math.exp(step_loss) if step_loss < 300 else float('inf')))
                    loss += step_loss / self.steps_per_checkpoint
                    current_step += 1
                    # If there is an EOS symbol in outputs, cut them at that point.
                    #if data_utils.EOS_ID in step_outputs:
                    #    step_outputs = step_outputs[:step_outputs.index(data_utils.EOS_ID)]
                    #if data_utils.PAD_ID in decoder_inputs:
                    #decoder_inputs = decoder_inputs[:decoder_inputs.index(data_utils.PAD_ID)]
                    #    print (step_outputs[0])

                    # Once in a while, we save checkpoint, print statistics, and run evals.
                    if current_step % self.steps_per_checkpoint == 0:
                        # Print statistics for the previous epoch.
                        perplexity = math.exp(loss) if loss < 300 else float('inf')
                        logging.info("global step %d step-time %.2f perplexity "
                                    "%.2f" % (self.global_step.eval(),
                                    step_time, perplexity))
                        previous_losses.append(loss)
                        print("epoch: {}, step: {}, loss: {}, perplexity: {}, step_time: {}".format(epoch, current_step, loss, perplexity, curr_step_time))
                        # Save checkpoint and zero timer and loss.
                        if not self.forward_only:
                            checkpoint_path = os.path.join(self.model_dir, "translate.ckpt")
                            logging.info("Saving model, current_step: %d"%current_step)
                            self.saver_all.save(self.sess, checkpoint_path, global_step=self.global_step)
                        step_time, loss = 0.0, 0.0
                        #sys.stdout.flush()

    # step, read one batch, generate gradients
    def step(self, encoder_masks, img_data, zero_paddings, decoder_inputs, target_weights,
               bucket_id, forward_only):
        # Check if the sizes match.
        encoder_size, decoder_size = self.buckets[bucket_id]
        if len(decoder_inputs) != decoder_size:
            raise ValueError("Decoder length must be equal to the one in bucket,"
                    " %d != %d." % (len(decoder_inputs), decoder_size))
        if len(target_weights) != decoder_size:
            raise ValueError("Weights length must be equal to the one in bucket,"
                    " %d != %d." % (len(target_weights), decoder_size))
        
        # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
        input_feed = {}
        if not forward_only:
            input_feed[K.learning_phase()] = 0
        else:
            input_feed[K.learning_phase()] = 0
        input_feed[self.img_data.name] = img_data
        input_feed[self.zero_paddings.name] = zero_paddings
        for l in xrange(decoder_size):
            input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
            input_feed[self.target_weights[l].name] = target_weights[l]
        for l in xrange(encoder_size):
            input_feed[self.encoder_masks[l].name] = encoder_masks[l]
    
        # Since our targets are decoder inputs shifted by one, we need one more.
        last_target = self.decoder_inputs[decoder_size].name
        input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)
    
        # Output feed: depends on whether we do a backward step or not.
        if not forward_only:
            output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
                    #self.gradient_norms[bucket_id],  # Gradient norm.
                    self.attention_decoder_model.losses[bucket_id]]
            output_feed += self.keras_updates
        else:
            output_feed = [self.attention_decoder_model.losses[bucket_id]]  # Loss for this batch.
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(self.attention_decoder_model.outputs[bucket_id][l])
            if self.visualize:
                output_feed += self.attention_decoder_model.attention_weights_histories[bucket_id]
    
        outputs = self.sess.run(output_feed, input_feed)
        if not forward_only:
            return None, outputs[1], None, None  # Gradient norm, loss, no outputs, no attentions.
        else:
            return None, outputs[0], outputs[1:(1+self.buckets[bucket_id][1])], outputs[(1+self.buckets[bucket_id][1]):]  # No gradient norm, loss, outputs, attentions.
    def visualize_attention(self, filename, attentions, output_valid, ground_valid, flag_incorrect, real_len):
        if flag_incorrect:
            output_dir = os.path.join(self.output_dir, 'incorrect')
        else:
            output_dir = os.path.join(self.output_dir, 'correct')
        output_dir = os.path.join(output_dir, filename.replace('/', '_'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        with open(os.path.join(output_dir, 'word.txt'), 'w') as fword:
            fword.write(' '.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in ground_valid])+'\n')
            fword.write(' '.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in output_valid]))
            with open(filename, 'rb') as img_file:
                img = Image.open(img_file)
                w, h = img.size
                h = 32
                img = img.resize(
                        (real_len, h),
                        Image.ANTIALIAS)
                img_data = np.asarray(img, dtype=np.uint8)
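                # the CNN downsamples the image width by a factor of 4, so each
                # attention weight is mapped back onto 4 pixel columns and then
                # used to scale the image brightness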
                for idx in range(len(output_valid)):
                    output_filename = os.path.join(output_dir, 'image_%d.jpg'%(idx))
                    attention = attentions[idx][:(real_len/4-1)]

                    # I have got the attention_orig here, which is of size 32*len(ground_truth), the only thing left is to visualize it and save it to output_filename
                    # TODO here
                    attention_orig = np.zeros(real_len)
                    for i in range(real_len):
                        if 0 < i/4-1 and i/4-1 < len(attention):
                            attention_orig[i] = attention[i/4-1]
                    attention_orig = np.convolve(attention_orig, [0.199547,0.200226,0.200454,0.200226,0.199547], mode='same')
                    attention_orig = np.maximum(attention_orig, 0.3)
                    attention_out = np.zeros((h, real_len))
                    for i in range(real_len):
                        attention_out[:,i] = attention_orig[i]
                    if len(img_data.shape) == 3:
                        attention_out = attention_out[:,:,np.newaxis]
                    img_out_data = img_data * attention_out
                    img_out = Image.fromarray(img_out_data.astype(np.uint8))
                    img_out.save(output_filename)
예제 #10
0
class Model(object):
    def __init__(self,
                 phase,
                 visualize,
                 data_path,
                 data_base_dir,
                 output_dir,
                 batch_size,
                 initial_learning_rate,
                 num_epoch,
                 steps_per_checkpoint,
                 target_vocab_size,
                 model_dir,
                 target_embedding_size,
                 attn_num_hidden,
                 attn_num_layers,
                 clip_gradients,
                 max_gradient_norm,
                 session,
                 load_model,
                 gpu_id,
                 use_gru,
                 evaluate=False,
                 valid_target_length=float('inf'),
                 reg_val=0):

        gpu_device_id = '/gpu:' + str(gpu_id)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        logging.info('loading data')
        # load data
        if phase == 'train':
            self.s_gen = DataGen(data_base_dir,
                                 data_path,
                                 valid_target_len=valid_target_length,
                                 evaluate=False)
        else:
            batch_size = 1
            self.s_gen = DataGen(data_base_dir, data_path, evaluate=True)

        #logging.info('valid_target_length: %s' %(str(valid_target_length)))
        logging.info('phase: %s' % phase)
        logging.info('model_dir: %s' % (model_dir))
        logging.info('load_model: %s' % (load_model))
        logging.info('output_dir: %s' % (output_dir))
        logging.info('steps_per_checkpoint: %d' % (steps_per_checkpoint))
        logging.info('batch_size: %d' % (batch_size))
        logging.info('num_epoch: %d' % num_epoch)
        logging.info('learning_rate: %d' % initial_learning_rate)
        logging.info('reg_val: %d' % (reg_val))
        logging.info('max_gradient_norm: %f' % max_gradient_norm)
        logging.info('clip_gradients: %s' % clip_gradients)
        logging.info('valid_target_length %f' % valid_target_length)
        logging.info('target_vocab_size: %d' % target_vocab_size)
        logging.info('target_embedding_size: %f' % target_embedding_size)
        logging.info('attn_num_hidden: %d' % attn_num_hidden)
        logging.info('attn_num_layers: %d' % attn_num_layers)
        logging.info('visualize: %s' % visualize)

        buckets = self.s_gen.bucket_specs
        logging.info('buckets')
        logging.info(buckets)
        if use_gru:
            logging.info('use GRU in the decoder.')

        # variables
        self.img_data = tf.placeholder(tf.float32,
                                       shape=(None, 1, 32, None),
                                       name='img_data')
        self.zero_paddings = tf.placeholder(tf.float32,
                                            shape=(None, None, 512),
                                            name='zero_paddings')

        self.decoder_inputs = []
        self.encoder_masks = []
        self.target_weights = []
        for i in xrange(int(buckets[-1][0] + 1)):
            self.encoder_masks.append(
                tf.placeholder(tf.float32,
                               shape=[None, 1],
                               name="encoder_mask{0}".format(i)))
        for i in xrange(buckets[-1][1] + 1):
            self.decoder_inputs.append(
                tf.placeholder(tf.int32,
                               shape=[None],
                               name="decoder{0}".format(i)))
            self.target_weights.append(
                tf.placeholder(tf.float32,
                               shape=[None],
                               name="weight{0}".format(i)))

        self.reg_val = reg_val
        self.sess = session  # session: the client talks to the TF runtime through a session; the usual pattern is to create the session (which starts with an empty graph), add nodes and edges to build the graph, then execute it.
        self.evaluate = evaluate
        self.steps_per_checkpoint = steps_per_checkpoint
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.buckets = buckets
        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.global_step = tf.Variable(0, trainable=False)
        self.valid_target_length = valid_target_length
        self.phase = phase
        self.visualize = visualize
        self.learning_rate = initial_learning_rate
        self.clip_gradients = clip_gradients
        self.target_vocab_size = target_vocab_size

        if phase == 'train':
            self.forward_only = False
        elif phase == 'test':
            self.forward_only = True
        else:
            assert False, phase

        with tf.device(gpu_device_id):
            cnn_model = CNN(self.img_data, not self.forward_only)  #(True))
            self.conv_output = cnn_model.tf_output()
            #print ("self.conv_output1:{}".format(self.conv_output.get_shape()))
            self.concat_conv_output = tf.concat(
                axis=1, values=[self.conv_output, self.zero_paddings])
            #print ("self.conv_output2:{}".format(self.concat_conv_output.get_shape()))

            self.perm_conv_output = tf.transpose(self.concat_conv_output,
                                                 perm=[1, 0, 2])

        with tf.device(gpu_device_id):
            self.attention_decoder_model = Seq2SeqModel(
                encoder_masks=self.encoder_masks,
                encoder_inputs_tensor=self.perm_conv_output,
                decoder_inputs=self.decoder_inputs,
                target_weights=self.target_weights,
                target_vocab_size=target_vocab_size,
                buckets=buckets,
                target_embedding_size=target_embedding_size,
                attn_num_layers=attn_num_layers,
                attn_num_hidden=attn_num_hidden,
                forward_only=self.forward_only,
                use_gru=use_gru)

        if not self.forward_only:

            self.updates = []
            self.summaries_by_bucket = []
            with tf.device(gpu_device_id):
                params = tf.trainable_variables()  # returns the list of variables that need training
                # Gradients and SGD update operation for training the model.
                opt = tf.train.AdadeltaOptimizer(
                    learning_rate=initial_learning_rate
                )  #learning_rate=0.001, rho=0.95, epsilon=1e-08
                for b in xrange(len(buckets)):
                    if self.reg_val > 0:
                        reg_losses = tf.get_collection(
                            tf.GraphKeys.REGULARIZATION_LOSSES
                        )  # pull every entry of the collection (a list); adding the regularization term to the loss helps prevent overfitting
                        logging.info('Adding %s regularization losses',
                                     len(reg_losses))
                        logging.debug('REGULARIZATION_LOSSES: %s', reg_losses)
                        loss_op = self.reg_val * tf.reduce_sum(
                            reg_losses) + self.attention_decoder_model.losses[
                                b]  # tf.reduce_sum sums along the given axis
                    else:
                        loss_op = self.attention_decoder_model.losses[b]

                    gradients, params = zip(
                        *opt.compute_gradients(loss_op, params))
                    if self.clip_gradients:
                        gradients, _ = tf.clip_by_global_norm(
                            gradients, max_gradient_norm)  # bound the global gradient norm to avoid exploding gradients
                    # Add summaries for loss, variables, gradients, gradient norms and total gradient norm.
                    summaries = []
                    '''
                    for gradient, variable in gradients:
                        if isinstance(gradient, tf.IndexedSlices):
                            grad_values = gradient.values
                        else:
                            grad_values = gradient
                        summaries.append(tf.summary.histogram(variable.name, variable))
                        summaries.append(tf.summary.histogram(variable.name + "/gradients", grad_values))
                        summaries.append(tf.summary.scalar(variable.name + "/gradient_norm",
                                             tf.global_norm([grad_values])))
                    '''
                    summaries.append(tf.summary.scalar("loss", loss_op))
                    summaries.append(
                        tf.summary.scalar("total_gradient_norm",
                                          tf.global_norm(gradients)))
                    all_summaries = tf.summary.merge(summaries)
                    self.summaries_by_bucket.append(all_summaries)
                    # update op - apply gradients
                    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                    with tf.control_dependencies(update_ops):
                        self.updates.append(
                            opt.apply_gradients(zip(gradients, params),
                                                global_step=self.global_step))
        #
        if self.phase == 'train':
            self.sess.run(tf.initialize_all_variables())

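            # Restore only a subset of variables in the training phase: the decoder
            # embedding and output-projection weights are excluded, presumably so that a
            # checkpoint trained with a different target vocabulary can still be reused.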
            variables_to_restore = []
            for v in tf.global_variables():
                if not (v.name.startswith(
                        "embedding_attention_decoder/attention_decoder/AttnOutputProjection"
                ) or (v.name.startswith(
                        "embedding_attention_decoder/embedding"))):
                    variables_to_restore.append(v)
            self.saver_all = tf.train.Saver(variables_to_restore)
        else:
            self.saver_all = tf.train.Saver(tf.all_variables())

        ckpt = tf.train.get_checkpoint_state(model_dir)
        if ckpt and load_model:
            logging.info("Reading model parameters from %s" %
                         ckpt.model_checkpoint_path)
            #self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            self.saver_all.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            logging.info("Created model with fresh parameters.")
            self.sess.run(tf.initialize_all_variables())
        #self.sess.run(init_new_vars_op)
        self.saver_all = tf.train.Saver(tf.all_variables())

    def chinese_display(self, filename, output, fw, flag):
        lst = []
        #print (len(output))
        if flag == 1:  # ground truth: write the filename first
            fw.write(filename[0] + " ")

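        # Ids 0-2 appear to be reserved special tokens (padding / GO / EOS); only ids in
        # (2, target_vocab_size] are mapped back to characters via s_gen.label_word.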
        for c in output:
            if c > 2 and c < self.target_vocab_size + 1:  #4450:#24:#951:
                #print ("label type:{}".format(type(c))) #<type 'numpy.int64'>
                c = self.s_gen.label_word(c)
                #print ("word type:{}".format(type(c))) #<type 'unicode'>
                #print ("c word:{}".format(c.encode("utf-8")))
                lst.append(c.encode(
                    "utf-8"))  #c.decode("raw_unicode_escape").encode("utf-8")
                fw.write(c)
                #print ("11111111111")
                #print (lst)
        s = "".join(lst)
        if flag == 1:  # ground truth: separate from the prediction with a space
            fw.write(" ")
        else:
            fw.write("\n")

        return s

    def label_to_word(self, labels):
        lst = []
        for c in labels:
            if c > 2 and c < self.target_vocab_size + 1:  #4450:#24:#951:
                #print ("label type:{}".format(type(c))) #<type 'numpy.int64'>
                c = self.s_gen.label_word(c)
                #print ("word type:{}".format(type(c))) #<type 'unicode'>
                #print ("c word:{}".format(c.encode("utf-8")))
                lst.append(c.encode(
                    "utf-8"))  #c.decode("raw_unicode_escape").encode("utf-8")
        #words = "".join(lst)

        return lst

    def write_result(self, words, fw, flag):
        for c in words:
            fw.write(c)
        if flag == 1:  # ground truth
            fw.write(" ")
        else:
            fw.write('\n')

    # train or test as specified by phase
    def launch(self):

        # Load the dictionary / language model for optional trie-based correction (disabled).
        #trieTree.trie = trieTree.construction_trietree("dict.txt", "lm.binary") #("userdic.txt")

        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        writer = tf.summary.FileWriter(self.model_dir, self.sess.graph)
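        # TensorBoard writer; the per-bucket loss / gradient-norm summaries are added
        # to it during training.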
        if self.phase == 'test':
            if not distance_loaded:
                logging.info(
                    'Warning: distance module not installed. Falling back to whole-sequence comparison.'
                )
            else:
                logging.info('Comparing words based on edit distance.')
            num_correct = 0
            num_total = 0
            num_correct1 = 0
            for batch in self.s_gen.gen(self.batch_size):
                # Get a batch and make a step.
                start_time = time.time()
                bucket_id = batch['bucket_id']
                img_data = batch['data']
                zero_paddings = batch['zero_paddings']
                decoder_inputs = batch['decoder_inputs']
                target_weights = batch['target_weights']
                encoder_masks = batch['encoder_mask']
                file_list = batch['filenames']
                real_len = batch['real_len']

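                # decoder_inputs is a time-major list of arrays; transposing it yields
                # one ground-truth label sequence per example in the batch.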
                grounds = [
                    a for a in np.array([
                        decoder_input.tolist()
                        for decoder_input in decoder_inputs
                    ]).transpose()
                ]
                _, step_loss, step_logits, step_attns = self.step(
                    encoder_masks, img_data, zero_paddings, decoder_inputs,
                    target_weights, bucket_id, self.forward_only)
                curr_step_time = (time.time() - start_time)
                step_time += curr_step_time / self.steps_per_checkpoint
                logging.info(
                    'step_time: %f, loss: %f, step perplexity: %f' %
                    (curr_step_time, step_loss,
                     math.exp(step_loss) if step_loss < 300 else float('inf')))
                loss += step_loss / self.steps_per_checkpoint
                current_step += 1
                step_outputs = [
                    b for b in np.array([
                        np.argmax(logit, axis=1).tolist()
                        for logit in step_logits
                    ]).transpose()
                ]
                '''if self.visualize:
                    step_attns = np.array([[a.tolist() for a in step_attn] for step_attn in step_attns]).transpose([1, 0, 2])
                    #print (step_attns)'''

                for idx, output, ground in zip(range(len(grounds)),
                                               step_outputs, grounds):
                    flag_ground, flag_out = True, True
                    num_total += 1
                    output_valid = []
                    ground_valid = []
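                    # Targets are the decoder inputs shifted by one, so prediction j-1 is
                    # compared with ground-truth j; both sequences are cut at the first
                    # id 2, which appears to mark end-of-sequence / padding.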
                    for j in range(1, len(ground)):
                        s1 = output[j - 1]
                        s2 = ground[j]
                        if s2 != 2 and flag_ground:
                            ground_valid.append(s2)
                        else:
                            flag_ground = False
                        if s1 != 2 and flag_out:
                            output_valid.append(s1)
                        else:
                            flag_out = False
                    '''
                    # Convert labels back to characters
                    output_words = self.label_to_word(output_valid)
                    #ground_words = self.label_to_word(ground_valid)

                    # Dictionary-based correction; can raise English accuracy by about 10%
                    s = ''.join(output_words)#s = ''.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in output_valid])
                    #print ("recognized word:{}".format(s))  # printing this way is problematic for Chinese characters

                    output_valid[:] = []
                    # Emit the corrected word
                    if trieTree.trie is not None:
                        word = trieTree.correct_word(s, 2, trieTree.trie)
                        #print ("correct word:{}".format(word))
                        for c in word:
                            output_valid.append( self.s_gen.word_label(c) )#ord(c) - 97 + 13 if ord(c) > 96 else ord(c) - 48 + 3 )
                        #print ("output_valid:{}".format(output_valid))'''

                    #print (output)  #[14 13 31 31 35 27 27 16  2  2  2  2  2  2  2  12  2  2  2  2  2  2  23  2  2 2  2  2  2  2  2  2]
                    #print (output_valid) #[14, 13, 31, 31, 35, 27, 27, 16]
                    if distance_loaded:
                        num_incorrect = distance.levenshtein(
                            output_valid, ground_valid)
                        #print ("num_incorrect:{}\n".format(num_incorrect))
                        '''if self.visualize:
                            self.visualize_attention(file_list[idx], step_attns[idx], output_valid, ground_valid, num_incorrect>0, real_len)'''
                        num_incorrect = float(num_incorrect) / len(
                            ground_valid)  #
                        num_incorrect1 = min(1.0, num_incorrect)
                        #print (num_incorrect1)

                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1

                    else:
                        if output_valid == ground_valid:
                            num_incorrect = 0
                        else:
                            num_incorrect = 1
                        num_incorrect1 = num_incorrect  # keep the edit-distance counter defined when the distance module is unavailable
                        '''if self.visualize:
                            self.visualize_attention(file_list[idx], step_attns[idx], output_valid, ground_valid, num_incorrect>0, real_len)'''
                        #print ("num_incorrect --- :{}\n".format(num_incorrect))
                    num_correct1 += 1. - num_incorrect1
                    #print (num_incorrect)

                    #display chinese
                    #print ("grounds:{}".format(grounds))
                    #print ("ground_valid:{}".format(ground_valid))
                    #print ("output_valid:{}".format(output_valid))

                    if num_incorrect < 1:
                        num_correct += 1
                        self.chinese_display(file_list, ground_valid,
                                             fw_correct, 1)
                        self.chinese_display(file_list, output_valid,
                                             fw_correct, 0)
                    else:
                        self.chinese_display(file_list, ground_valid,
                                             fw_imcorrect, 1)
                        self.chinese_display(file_list, output_valid,
                                             fw_imcorrect, 0)
                logging.info('%f out of %d correct, precision=%f, %f, %f' %
                             (num_correct, num_total, num_correct /
                              (num_total + 1), num_correct1, num_correct1 /
                              (num_total + 1)))
        elif self.phase == 'train':
            total = (self.s_gen.get_size() // self.batch_size)
            with tqdm(desc='Train: ', total=total) as pbar:
                for epoch in range(self.num_epoch):

                    logging.info('Generating first batch')
                    precision_ave = 0.0
                    for i, batch in enumerate(self.s_gen.gen(self.batch_size)):
                        # Get a batch and make a step.
                        num_total = 0
                        num_correct = 0
                        start_time = time.time()
                        batch_len = batch['real_len']  # width of the images in this batch
                        bucket_id = batch['bucket_id']
                        img_data = batch['data']
                        zero_paddings = batch['zero_paddings']
                        decoder_inputs = batch['decoder_inputs']
                        target_weights = batch['target_weights']
                        encoder_masks = batch['encoder_mask']
                        #logging.info('current_step: %d'%current_step)
                        #logging.info(np.array([decoder_input.tolist() for decoder_input in decoder_inputs]).transpose()[0])
                        #print (np.array([target_weight.tolist() for target_weight in target_weights]).transpose()[0])
                        summaries, step_loss, step_logits, _ = self.step(
                            encoder_masks, img_data, zero_paddings,
                            decoder_inputs, target_weights, bucket_id,
                            self.forward_only)

                        grounds = [
                            a for a in np.array([
                                decoder_input.tolist()
                                for decoder_input in decoder_inputs
                            ]).transpose()
                        ]
                        step_outputs = [
                            b for b in np.array([
                                np.argmax(logit, axis=1).tolist()
                                for logit in step_logits
                            ]).transpose()
                        ]

                        for idx, output, ground in zip(range(len(grounds)),
                                                       step_outputs, grounds):
                            flag_ground, flag_out = True, True
                            num_total += 1
                            output_valid = []
                            ground_valid = []
                            for j in range(1, len(ground)):
                                s1 = output[j - 1]
                                s2 = ground[j]
                                if s2 != 2 and flag_ground:
                                    ground_valid.append(s2)
                                else:
                                    flag_ground = False
                                if s1 != 2 and flag_out:
                                    output_valid.append(s1)
                                else:
                                    flag_out = False

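                            # Per-sample error: edit distance normalized by the ground-truth
                            # length when the optional `distance` module is available,
                            # otherwise exact match (0 or 1).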
                            if distance_loaded:
                                num_incorrect = distance.levenshtein(
                                    output_valid, ground_valid)
                                num_incorrect = float(num_incorrect) / len(
                                    ground_valid)
                                num_incorrect = min(1.0, num_incorrect)
                            else:
                                if output_valid == ground_valid:
                                    num_incorrect = 0
                                else:
                                    num_incorrect = 1
                            num_correct += 1. - num_incorrect

                        writer.add_summary(summaries, current_step)
                        curr_step_time = (time.time() - start_time)
                        step_time += curr_step_time / self.steps_per_checkpoint
                        precision = num_correct / num_total
                        precision_ave = precision_ave + precision  # accumulated to report the average precision at each checkpoint
                        logging.info(
                            'step %f - time: %f, loss: %f, perplexity: %f, precision: %f, batch_len: %f'
                            % (current_step, curr_step_time, step_loss,
                               math.exp(step_loss) if step_loss < 300 else
                               float('inf'), precision, batch_len))
                        loss += step_loss / self.steps_per_checkpoint
                        #pbar.set_description('Train, loss={:.8f}'.format(step_loss))
                        #pbar.update()
                        current_step += 1
                        #print (epoch, current_step)

                        # If there is an EOS symbol in outputs, cut them at that point.
                        #if data_utils.EOS_ID in step_outputs:
                        #    step_outputs = step_outputs[:step_outputs.index(data_utils.EOS_ID)]
                        #if data_utils.PAD_ID in decoder_inputs:
                        #decoder_inputs = decoder_inputs[:decoder_inputs.index(data_utils.PAD_ID)]
                        #    print (step_outputs[0])

                        # Once in a while, we save checkpoint, print statistics, and run evals.
                        if current_step % self.steps_per_checkpoint == 0:
                            # Print statistics for the previous epoch.
                            perplexity = math.exp(
                                loss) if loss < 300 else float('inf')
                            logging.info(
                                "global step %d step-time %.2f loss %f  perplexity %.2f precision_ave %.4f"
                                % (self.global_step.eval(), step_time, loss,
                                   perplexity,
                                   precision_ave / self.steps_per_checkpoint))
                            precision_ave = 0.0
                            previous_losses.append(loss)
                            # Save checkpoint and zero timer and loss.
                            if not self.forward_only:
                                checkpoint_path = os.path.join(
                                    self.model_dir, "translate.ckpt")
                                logging.info("Saving model, current_step: %d" %
                                             current_step)
                                #self.saver_all.save(self.sess, checkpoint_path, global_step=self.global_step)

                                self.saver_all.save(
                                    self.sess,
                                    checkpoint_path,
                                    global_step=self.global_step)
                            step_time, loss = 0.0, 0.0
                            #sys.stdout.flush()

    # step, read one batch, generate gradients
    def step(self, encoder_masks, img_data, zero_paddings, decoder_inputs,
             target_weights, bucket_id, forward_only):
        # Check if the sizes match.
        encoder_size, decoder_size = self.buckets[bucket_id]
        if len(decoder_inputs) != decoder_size:
            raise ValueError(
                "Decoder length must be equal to the one in bucket,"
                " %d != %d." % (len(decoder_inputs), decoder_size))
        if len(target_weights) != decoder_size:
            raise ValueError(
                "Weights length must be equal to the one in bucket,"
                " %d != %d." % (len(target_weights), decoder_size))

        # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
        input_feed = {}
        input_feed[self.img_data.name] = img_data
        input_feed[self.zero_paddings.name] = zero_paddings
        for l in xrange(decoder_size):
            input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
            input_feed[self.target_weights[l].name] = target_weights[l]
        for l in xrange(int(encoder_size)):
            try:
                input_feed[self.encoder_masks[l].name] = encoder_masks[l]
            except Exception:
                # Some buckets may provide fewer masks than encoder_size; skip the missing ones.
                pass
                #ipdb.set_trace()

        # Since our targets are decoder inputs shifted by one, we need one more.
        last_target = self.decoder_inputs[decoder_size].name
        input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)

        # Output feed: depends on whether we do a backward step or not.
        if not forward_only:  #train
            output_feed = [
                self.updates[bucket_id],  # Update Op that does SGD.
                #self.gradient_norms[bucket_id],  # Gradient norm.
                self.attention_decoder_model.losses[bucket_id],
                self.summaries_by_bucket[bucket_id]
            ]

            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
        else:
            output_feed = [self.attention_decoder_model.losses[bucket_id]
                           ]  # Loss for this batch.
            #output_feed.append(self.conv_output1)#perm_conv_output)# img_data)
            for l in xrange(decoder_size):  # Output logits.
                output_feed.append(
                    self.attention_decoder_model.outputs[bucket_id][l])
            if self.visualize:
                output_feed += self.attention_decoder_model.attention_weights_histories[
                    bucket_id]

        outputs = self.sess.run(output_feed, input_feed)  #(dict,list)
        #print ("output_feed:{}\n".format(len(output_feed))) #list, w=96, len(output)=20
        #print ("img_data size:{}\n".format(img_data.shape))  #img_data size:(1, 1, 32, 100)
        #print ("cnn:{}\n".format(outputs[1][0][0][:10])) #list  [27,?,512]
        '''size1 = outputs[1].shape
        print ("cnn:{}\n".format(size1))
        for i in range(size1[3]):
            fw_data.write(str(outputs[1][0][0][0][i])[:5] + " ")
        fw_data.write("\n")'''
        '''
        for i in range(32):
           for j in range(100):
                fw_data.write(str(img_data[0][0][i][j]) + " ")
        fw_data.write("\n")'''

        #for i in range():
        #print ("cnn:{}\n".format(outputs[1][:10][0][0]))

        if not forward_only:
            return outputs[2], outputs[1], outputs[3:(
                3 + self.buckets[bucket_id][1]
            )], None  # Merged summaries, loss, output logits, no attentions.
        else:
            return None, outputs[0], outputs[1:(
                1 + self.buckets[bucket_id][1])], outputs[(
                    1 + self.buckets[bucket_id][1]
                ):]  # No summaries, loss, output logits, attentions.

    def visualize_attention(self, filename, attentions, output_valid,
                            ground_valid, flag_incorrect, real_len):
        if flag_incorrect:
            output_dir = os.path.join(self.output_dir, 'incorrect')
        else:
            output_dir = os.path.join(self.output_dir, 'correct')
        output_dir = os.path.join(output_dir, filename.replace('/', '_'))
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
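        # One directory per input image is created under correct/ or incorrect/; it holds
        # word.txt (the writes below are commented out, so it stays empty) and one
        # image_<idx>.jpg attention overlay per predicted character.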
        with open(os.path.join(output_dir, 'word.txt'), 'w') as fword:
            #fword.write(' '.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in ground_valid])+'\n')
            #fword.write(' '.join([chr(c-13+97) if c-13+97>96 else chr(c-3+48) for c in output_valid]))
            #fword.write(' '.join([chinese(c).decode("raw_unicode_escape").encode("utf-8")  for c in ground_valid])+'\n')
            #fword.write(' '.join([chinese(c).decode("raw_unicode_escape").encode("utf-8")  for c in output_valid]))
            with open(filename, 'rb') as img_file:
                img = Image.open(img_file)
                w, h = img.size
                h = 32
                img = img.resize((real_len, h), Image.ANTIALIAS)
                img_data = np.asarray(img, dtype=np.uint8)
                for idx in range(len(output_valid)):
                    output_filename = os.path.join(output_dir,
                                                   'image_%d.jpg' % (idx))
                    attention = attentions[idx][:(int(real_len / 4) - 1)]

                    # I have got the attention_orig here, which is of size 32*len(ground_truth), the only thing left is to visualize it and save it to output_filename
                    # TODO here
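                    # The attention weights live on conv time steps (roughly one per 4
                    # image columns), so they are spread back over pixel columns, smoothed
                    # with a small 5-tap averaging kernel, floored at 0.3, and used to
                    # modulate the image brightness column by column.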
                    attention_orig = np.zeros(real_len)
                    for i in range(real_len):
                        if 0 < i / 4 - 1 and i / 4 - 1 < len(attention):
                            attention_orig[i] = attention[int(i / 4) - 1]
                    attention_orig = np.convolve(
                        attention_orig,
                        [0.199547, 0.200226, 0.200454, 0.200226, 0.199547],
                        mode='same')
                    attention_orig = np.maximum(attention_orig, 0.3)
                    attention_out = np.zeros((h, real_len))
                    for i in range(real_len):
                        attention_out[:, i] = attention_orig[i]
                    if len(img_data.shape) == 3:
                        attention_out = attention_out[:, :, np.newaxis]
                    img_out_data = img_data * attention_out
                    img_out = Image.fromarray(img_out_data.astype(np.uint8))
                    img_out.save(output_filename)