Пример #1
0
	def build_model_m1(self):
		"""Build the M1 (unsupervised) graph on the ``xu`` input.

		Creates an x encoder (x -> hx) and an hx decoder (hx -> x), wires
		xu -> (mean, log_var) -> sample -> reconstruction, and stores on
		``self`` the weighted KL and reconstruction losses, their sum
		``m1_loss``, and the M1 optimizer ops.
		"""
		self.xu = tf.placeholder(tf.float32, shape=[None] + self.input_shape, name='xu_input')

		cfg = self.config

		# Encoder x -> hx; its output size is forced to hx_dim.
		enc_params = cfg['x encoder params']
		enc_params['name'] = 'EncoderHX_X'
		enc_params["output dims"] = self.hx_dim
		self.x_encoder = get_encoder(cfg['x encoder'], enc_params, self.is_training)

		# Decoder hx -> x. This method does not set its "output dims";
		# that value comes from the config as given.
		dec_params = cfg['hx decoder params']
		dec_params['name'] = 'DecoderX_HX'
		self.hx_decoder = get_decoder(cfg['hx decoder'], dec_params, self.is_training)

		# Unsupervised path:
		#   xu -> (hxu_mean, hxu_log_var)  => KL loss
		#   sample -> xu_recon             => reconstruction (MSE) loss
		hxu_mean, hxu_log_var = self.x_encoder(self.xu)
		hxu_sample = self.draw_sample(hxu_mean, hxu_log_var)
		xu_recon = self.hx_decoder(hxu_sample)

		kl_term = get_loss('kl', 'gaussian', {'mean': hxu_mean, 'log_var': hxu_log_var})
		self.m1_loss_kl_z = kl_term * self.m1_loss_weights.get('kl z loss weight', 1.0)
		recon_term = get_loss('reconstruction', 'mse', {'x': self.xu, 'y': xu_recon})
		self.m1_loss_recon = recon_term * self.m1_loss_weights.get('reconstruction loss weight', 1.0)
		self.m1_loss = self.m1_loss_kl_z + self.m1_loss_recon

		# Optimizer over self.m1_vars, driven by the step returned from
		# get_global_step('m1_step').
		self.m1_global_step, step_update_op = get_global_step('m1_step')
		self.m1_train_op, self.m1_learning_rate, _ = get_optimizer_by_config(
			cfg['m1 optimizer'], cfg['m1 optimizer params'],
			self.m1_loss, self.m1_vars, self.m1_global_step, step_update_op)
def train_single(logger, tb_log_dir):
    """Train the recognition model on a single device.

    Builds a training graph and a weight-sharing validation graph on the first
    GPU listed in ``config.GPUS``. Each loop iteration (called an "epoch" in
    the code) runs one optimizer step. The loop stops at
    ``config.TRAIN.EPOCHS`` or when early stopping triggers.

    :param logger: logger used for progress messages
    :param tb_log_dir: directory the TensorBoard ``FileWriter`` writes to
    :return: numpy array of the training loss of every iteration run by this
        call
    """
    # prepare dataset
    if config.USE_STN:
        tfrecords_dir = config.DATA_DIR + '_stn'
    else:
        tfrecords_dir = config.DATA_DIR

    # Set up summary writer
    global_summaries = set()
    summary_writer = tf.summary.FileWriter(tb_log_dir)

    # Set up data provider
    char_dict_path = config.CHAR_DICT
    ord_map_dict_path = config.ORD_MAP_DICT
    (train_images, train_labels, _train_images_paths,
     val_images, val_labels, _val_images_paths) = data.get_data(
         tfrecords_dir, char_dict_path, ord_map_dict_path)

    # Converter used to turn decoded sparse labels into strings for accuracy.
    convert = tf_io_pipline_fast_tools.FeatureReader(
        char_dict_path=char_dict_path, ord_map_dict_path=ord_map_dict_path)

    # Set up network graph
    train_encoder_net = get_encoder('train')
    train_decoder_net = get_decoder('train')

    val_encoder_net = get_encoder('val')
    val_decoder_net = get_decoder('val')

    gpu_list = config.GPUS.split(',')

    device_name = '/gpu:{}'.format(gpu_list[0])
    with tf.device(device_name):
        with tf.name_scope('Train') as train_scope:
            with tf.variable_scope(tf.get_variable_scope(), reuse=False):
                train_encoder_out = train_encoder_net.forward(
                    train_images, name=config.ENCODER.NETWORKTYPE + 'Encoder')
                train_decoder_net.set_label(train_labels)
                train_decoder_out = train_decoder_net.predict(
                    train_encoder_out)
                train_loss = train_decoder_net.loss(train_decoder_out)
                train_sequence_dist = train_decoder_net.sequence_dist(
                    train_decoder_out)

        # The validation tower reuses the training variables (reuse=True).
        with tf.name_scope('Val'):
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                val_encoder_out = val_encoder_net.forward(
                    val_images, name=config.ENCODER.NETWORKTYPE + 'Encoder')
                val_decoder_net.set_label(val_labels)
                val_decoder_out = val_decoder_net.predict(val_encoder_out)
                val_loss = val_decoder_net.loss(val_decoder_out)
                val_sequence_dist = val_decoder_net.sequence_dist(
                    val_decoder_out)

        # Run UPDATE_OPS (e.g. batch-norm statistics) before each train step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = optimizer_builder.build(global_summaries)
            train_op = optimizer.minimize(
                loss=train_loss,
                global_step=tf.train.get_or_create_global_step())

    # Gather initial summaries.
    global_summaries.add(
        tf.summary.scalar(name='train_loss', tensor=train_loss))
    global_summaries.add(tf.summary.scalar(name='val_loss', tensor=val_loss))
    global_summaries.add(
        tf.summary.scalar(name='train_seq_distance',
                          tensor=train_sequence_dist))
    global_summaries.add(
        tf.summary.scalar(name='val_seq_distance', tensor=val_sequence_dist))

    # Set saver configuration
    saver = tf.train.Saver()
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                     time.localtime(time.time()))
    model_name = '{:s}_{:s}.ckpt'.format(config.MODEL.NAME,
                                         str(train_start_time))
    model_save_path = os.path.join(config.MODEL_SAVE_DIR, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.TRAIN.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)
    cost_history = [np.inf]
    try:
        # Merge all summaries together.
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES, train_scope))
        summaries |= global_summaries
        merge_summary_op = tf.summary.merge(list(summaries), name='summary_op')

        summary_writer.add_graph(sess.graph)
        # Set the training parameters
        train_epochs = config.TRAIN.EPOCHS

        with sess.as_default():
            epoch = 0
            tf.train.write_graph(graph_or_graph_def=sess.graph,
                                 logdir='',
                                 name='{:s}/single_model.pb'.format(
                                     config.MODEL_SAVE_DIR))
            if not config.RESUME_PATH:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                logger.info('Restore model from {:s}'.format(config.RESUME_PATH))
                # Restore from the same path that was checked and logged above.
                saver.restore(sess=sess, save_path=config.RESUME_PATH)
                epoch = sess.run(tf.train.get_global_step())

            patience_counter = 1
            while epoch < train_epochs:
                epoch += 1
                # Early stopping needs at least one recorded cost. Gate on the
                # history length, not on ``epoch``: a resumed run starts at the
                # restored global step with an empty history, and indexing
                # cost_history[-1 - patience_counter] would go out of range.
                if len(cost_history) > 1 and config.TRAIN.EARLY_STOPPING:
                    # We always compare to the first point where cost didn't improve
                    if cost_history[-1 - patience_counter] - cost_history[
                            -1] > config.TRAIN.PATIENCE_DELTA:
                        patience_counter = 1
                    else:
                        patience_counter += 1
                    if patience_counter > config.TRAIN.PATIENCE_EPOCHS:
                        logger.info(
                            "Cost didn't improve beyond {:f} for {:d} epochs, stopping early."
                            .format(config.TRAIN.PATIENCE_DELTA, patience_counter))
                        break

                if config.TRAIN.DECODE and epoch % 500 == 0:
                    # train part (also fetch decoded output to measure accuracy)
                    _, train_loss_value, train_decoder_out_dict, train_seq_dist_value, train_labels_sparse, \
                    merge_summary_value = sess.run(
                        [train_op, train_loss, train_decoder_out, train_sequence_dist, train_labels, merge_summary_op])

                    avg_train_accuracy = compute_avg_accuracy(
                        convert, train_labels_sparse, train_decoder_out_dict)
                    if epoch % config.TRAIN.DISPLAY_STEP == 0:
                        # ``epoch`` is already incremented; log it as-is so it
                        # matches the summary/checkpoint step.
                        logger.info(
                            'Epoch_Train: {:d} cost= {:9f} seq distance= {:9f} train accuracy= {:9f}'
                            .format(epoch, train_loss_value,
                                    train_seq_dist_value, avg_train_accuracy))

                    # validation part
                    val_loss_value, val_decoder_out_dict, val_seq_dist_value, val_labels_sparse = sess.run(
                        [val_loss, val_decoder_out, val_sequence_dist, val_labels])

                    avg_val_accuracy = compute_avg_accuracy(
                        convert, val_labels_sparse, val_decoder_out_dict)
                    if epoch % config.TRAIN.VAL_DISPLAY_STEP == 0:
                        logger.info(
                            'Epoch_Val: {:d} cost= {:9f} seq distance= {:9f} val accuracy= {:9f}'
                            .format(epoch, val_loss_value, val_seq_dist_value,
                                    avg_val_accuracy))

                    summary_fly = tf.Summary(value=[
                        tf.Summary.Value(tag='acc_train',
                                         simple_value=avg_train_accuracy),
                        tf.Summary.Value(tag='acc_val',
                                         simple_value=avg_val_accuracy),
                    ])
                    summary_writer.add_summary(summary=summary_fly,
                                               global_step=epoch)
                else:
                    _, train_loss_value, merge_summary_value = sess.run(
                        [train_op, train_loss, merge_summary_op])

                    if epoch % config.TRAIN.DISPLAY_STEP == 0:
                        logger.info('Epoch_Train: {:d} cost= {:9f}'.format(
                            epoch, train_loss_value))

                # record history train ctc loss
                cost_history.append(train_loss_value)
                # add training summary
                summary_writer.add_summary(summary=merge_summary_value,
                                           global_step=epoch)

                if epoch % 2000 == 0:
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)
    finally:
        # Release the session and flush pending events even on error.
        sess.close()
        summary_writer.close()

    return np.array(cost_history[1:])  # Don't return the first np.inf
def train_multi(logger, tb_log_dir):
    """Train the recognition model with data parallelism over several GPUs.

    One training tower and one weight-sharing validation tower are built per
    GPU listed in ``config.GPUS``. The per-tower gradients are averaged and
    applied once per loop iteration, and trainable variables are also tracked
    with an exponential moving average. Only the first tower's UPDATE_OPS
    (e.g. batch-norm statistics) are run.

    :param logger: logger used for progress messages
    :param tb_log_dir: directory the TensorBoard ``FileWriter`` writes to
    :return: None
    :raises ValueError: if the averaged training loss becomes NaN
    """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # prepare dataset
        if config.USE_STN:
            tfrecords_dir = config.DATA_DIR + '_stn'
        else:
            tfrecords_dir = config.DATA_DIR

        # Set up summary writer
        global_summaries = set()
        summary_writer = tf.summary.FileWriter(tb_log_dir)

        # Set up data provider
        char_dict_path = config.CHAR_DICT
        ord_map_dict_path = config.ORD_MAP_DICT
        (train_images, train_labels, _train_images_paths,
         val_images, val_labels, _val_images_paths) = data.get_data(
             tfrecords_dir, char_dict_path, ord_map_dict_path)

        # Set up network graph
        train_encoder_net = get_encoder('train')
        train_decoder_net = get_decoder('train')

        val_encoder_net = get_encoder('val')
        val_decoder_net = get_decoder('val')

        # set average container
        tower_grads = []
        train_scopes = []
        train_tower_loss = []
        val_tower_loss = []
        batchnorm_updates = None

        # Set up optimizer
        optimizer = optimizer_builder.build(global_summaries)

        gpu_list = config.GPUS.split(',')

        # Build one tower per GPU; variables are created by the first tower
        # and reused by every later one.
        is_network_initialized = False
        for i, gpu_id in enumerate(gpu_list):
            with tf.device('/gpu:{}'.format(gpu_id)):
                with tf.name_scope('Train_{:d}'.format(i)) as train_scope:
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=is_network_initialized):
                        train_loss, grads = compute_net_gradients(
                            train_images, train_labels, train_encoder_net,
                            train_decoder_net, optimizer)

                        is_network_initialized = True
                        train_scopes.append(train_scope)

                        # Only use the mean and var in the first gpu tower to update the parameter
                        if i == 0:
                            batchnorm_updates = tf.get_collection(
                                tf.GraphKeys.UPDATE_OPS)

                        tower_grads.append(grads)
                        train_tower_loss.append(train_loss)
                with tf.name_scope('Val_{:d}'.format(i)):
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=is_network_initialized):
                        val_loss, _ = compute_net_gradients(
                            val_images, val_labels, val_encoder_net,
                            val_decoder_net)
                        val_tower_loss.append(val_loss)

        with tf.name_scope('Average_Grad'):
            grads = average_gradients(tower_grads)
        with tf.name_scope('Average_Loss'):
            avg_train_loss = tf.reduce_mean(train_tower_loss)
            avg_val_loss = tf.reduce_mean(val_tower_loss)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            global_summaries.add(tf.summary.histogram(var.op.name, var))

        # Track the moving averages of all trainable variables
        variable_averages = tf.train.ExponentialMovingAverage(
            config.TRAIN.MOVING_AVERAGE_DECAY,
            num_updates=tf.train.get_or_create_global_step())
        variables_to_average = (tf.trainable_variables() +
                                tf.moving_average_variables())
        variables_averages_op = variable_averages.apply(variables_to_average)

        # Group all the op needed for training
        batchnorm_updates_op = tf.group(*batchnorm_updates)
        apply_gradient_op = optimizer.apply_gradients(
            grads, global_step=tf.train.get_or_create_global_step())
        train_op = tf.group(apply_gradient_op, variables_averages_op,
                            batchnorm_updates_op)

        global_summaries.add(
            tf.summary.scalar(name='average_train_loss',
                              tensor=avg_train_loss))

        # Merge all summaries together.
        summaries = set(
            tf.get_collection(tf.GraphKeys.SUMMARIES, train_scopes[0]))
        summaries |= global_summaries
        train_merge_summary_op = tf.summary.merge(list(summaries),
                                                  name='train_summary_op')
        val_merge_summary_op = tf.summary.merge(
            [tf.summary.scalar(name='average_val_loss', tensor=avg_val_loss)],
            name='val_summary_op')

        # Set saver configuration
        saver = tf.train.Saver()
        train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                                         time.localtime(time.time()))
        model_name = '{:s}_{:s}.ckpt'.format(config.MODEL.NAME,
                                             str(train_start_time))
        model_save_path = os.path.join(config.MODEL_SAVE_DIR, model_name)

        # set sess config
        sess_config = tf.ConfigProto(device_count={'GPU': len(gpu_list)},
                                     allow_soft_placement=True)
        sess_config.gpu_options.per_process_gpu_memory_fraction = config.TRAIN.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = config.TRAIN.TF_ALLOW_GROWTH
        sess_config.gpu_options.allocator_type = 'BFC'

        # Set the training parameters
        train_epochs = config.TRAIN.EPOCHS

        logger.info('Global configuration is as follows:')
        logger.info(config)

        sess = tf.Session(config=sess_config)
        try:
            summary_writer.add_graph(sess.graph)

            epoch = 0
            tf.train.write_graph(graph_or_graph_def=sess.graph,
                                 logdir='',
                                 name='{:s}/multi_model.pb'.format(
                                     config.MODEL_SAVE_DIR))
            if not config.RESUME_PATH:
                logger.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                logger.info('Restore model from {:s}'.format(config.RESUME_PATH))
                # Restore from the same path that was checked and logged above.
                saver.restore(sess=sess, save_path=config.RESUME_PATH)
                epoch = sess.run(tf.train.get_global_step())

            train_cost_time_mean = []
            val_cost_time_mean = []

            while epoch < train_epochs:
                epoch += 1
                # training part
                t_start = time.time()

                _, train_loss_value, train_summary = sess.run(
                    fetches=[train_op, avg_train_loss, train_merge_summary_op])

                if math.isnan(train_loss_value):
                    raise ValueError('Train loss is nan')

                train_cost_time_mean.append(time.time() - t_start)

                summary_writer.add_summary(summary=train_summary,
                                           global_step=epoch)

                # validation part (runs every iteration)
                t_start_val = time.time()

                val_loss_value, val_summary = sess.run(
                    fetches=[avg_val_loss, val_merge_summary_op])
                val_cost_time_mean.append(time.time() - t_start_val)

                summary_writer.add_summary(val_summary, global_step=epoch)

                # ``epoch`` is already incremented; log it as-is so it matches
                # the summary/checkpoint step.
                if epoch % config.TRAIN.DISPLAY_STEP == 0:
                    logger.info(
                        'Epoch_Train: {:d} total_loss= {:6f} mean_cost_time= {:5f}s '
                        .format(epoch, train_loss_value,
                                np.mean(train_cost_time_mean)))
                    train_cost_time_mean.clear()

                if epoch % config.TRAIN.VAL_DISPLAY_STEP == 0:
                    logger.info(
                        'Epoch_Val: {:d} total_loss= {:6f} mean_cost_time= {:5f}s '
                        .format(epoch, val_loss_value,
                                np.mean(val_cost_time_mean)))
                    val_cost_time_mean.clear()

                if epoch % 5000 == 0:
                    saver.save(sess=sess,
                               save_path=model_save_path,
                               global_step=epoch)
        finally:
            # Release the session and flush events even when training aborts
            # (e.g. on the NaN check above).
            sess.close()
            summary_writer.close()
Пример #4
0
    def build_model(self):
        """Build the semi-supervised VAE graph.

        Creates these placeholders:
          - ``xu``: unlabeled inputs.
          - ``xl`` / ``yl``: labeled inputs and their labels.
          - ``xt``: test inputs.

        Builds these networks:
          - x encoder (x -> hx).
          - [hx, y] -> z encoder.
          - [z, y] -> x decoder.
          - x classifier.

        Defines:
          - the supervised loss ``su_loss``;
          - the unsupervised loss ``unsu_loss``, which enumerates every class
            and weights each term by the classifier's probability;
          - the test outputs ``ytprobs`` / ``mean_hzt`` / ``log_var_hzt``;
          - both train ops and ``self.saver``.

        Also stores ``self.global_step`` / ``self.global_step_update``.
        """
        self.xu = tf.placeholder(tf.float32,
                                 shape=[None] + self.input_shape,
                                 name='xu_input')
        self.xl = tf.placeholder(tf.float32,
                                 shape=[None] + self.input_shape,
                                 name='xl_input')
        # Presumably one-hot, like the tf.one_hot fake labels used on the
        # unlabeled path below.
        self.yl = tf.placeholder(tf.float32,
                                 shape=[None, self.nb_classes],
                                 name='yl_input')

        ###########################################################################
        # network define
        #
        # x_encoder : x -> hx
        self.config['x encoder params']['name'] = 'EncoderHX_X'
        self.config['x encoder params']["output dims"] = self.hx_dim
        self.x_encoder = get_encoder(self.config['x encoder'],
                                     self.config['x encoder params'],
                                     self.is_training)
        # hx_y_encoder : [hx, y] -> hz
        self.config['hx y encoder params']['name'] = 'EncoderHZ_HXY'
        self.config['hx y encoder params']["output dims"] = self.hz_dim
        self.hx_y_encoder = get_encoder(self.config['hx y encoder'],
                                        self.config['hx y encoder params'],
                                        self.is_training)
        # hz_y_decoder : [hz, y] -> x_decode (flattened input size).
        # np.prod replaces np.product, which was removed in NumPy 2.0.
        self.config['hz y decoder params']['name'] = 'DecoderX_HZY'
        self.config['hz y decoder params']["output dims"] = int(
            np.prod(self.input_shape))
        self.hz_y_decoder = get_decoder(self.config['hz y decoder'],
                                        self.config['hz y decoder params'],
                                        self.is_training)
        # x_classifier : x -> ylogits
        self.config['x classifier params']['name'] = 'ClassifierX'
        self.config['x classifier params']["output dims"] = self.nb_classes
        self.x_classifier = get_classifier(self.config['x classifier'],
                                           self.config['x classifier params'],
                                           self.is_training)

        ###########################################################################
        # for supervised training:
        #
        #   xl --x_encoder--> hxl
        #   [hxl, yl] --> mean_hzl, log_var_hzl            ==> kl loss
        #   [sample_hzl, yl] --> decode_xl                 ==> reconstruction loss
        #   xl --x_classifier--> yllogits                  ==> classification loss
        #
        # NOTE(review): x_encoder's output is used directly as hx here (no
        # sampling), whereas other models unpack it as (mean, log_var) —
        # confirm the configured encoder returns a single tensor.
        # NOTE(review): the classifier is applied to the raw x, not to hx.
        hxl = self.x_encoder(self.xl)
        mean_hzl, log_var_hzl = self.hx_y_encoder(
            tf.concat([hxl, self.yl], axis=1))
        sample_hzl = self.draw_sample(mean_hzl, log_var_hzl)
        decode_xl = self.hz_y_decoder(tf.concat([sample_hzl, self.yl], axis=1))

        yllogits = self.x_classifier(self.xl)

        self.su_loss_kl_z = (get_loss('kl', 'gaussian', {
            'mean': mean_hzl,
            'log_var': log_var_hzl,
        }) * self.loss_weights.get('kl z loss weight', 1.0))
        self.su_loss_recon = (get_loss('reconstruction', 'mse', {
            'x': self.xl,
            'y': decode_xl
        }) * self.loss_weights.get('reconstruction loss weight', 1.0))
        # Read the correctly spelled key. Fall back to the historical
        # misspelling so existing configs keep their value.
        cls_weight = self.loss_weights.get(
            'classification loss weight',
            self.loss_weights.get('classiciation loss weight', 1.0))
        self.su_loss_cls = (get_loss('classification', 'cross entropy', {
            'logits': yllogits,
            'labels': self.yl
        }) * cls_weight)

        self.su_loss_reg = (
            get_loss('regularization', 'l2',
                     {'var_list': self.x_classifier.vars}) *
            self.loss_weights.get('regularization loss weight', 0.0001))

        self.su_loss = ((self.su_loss_kl_z + self.su_loss_recon +
                         self.su_loss_cls + self.su_loss_reg) *
                        self.loss_weights.get('supervised loss weight', 1.0))

        ###########################################################################
        # for unsupervised training:
        #
        #   xu --x_encoder--> hxu ; xu --x_classifier--> yulogits --> yuprobs
        #   for each class i (fake label y_i = one_hot(i)):
        #       [hxu, y_i] --> mean_hzu, log_var_hzu   ==> kl loss   * yuprobs[:, i]
        #       [sample_hzu, y_i] --> decode_xu        ==> recon loss * yuprobs[:, i]
        #   plus a KL term on yuprobs ('kl', 'bernoulli').
        #
        hxu = self.x_encoder(self.xu)
        yulogits = self.x_classifier(self.xu)
        yuprobs = tf.nn.softmax(yulogits)

        unsu_loss_kl_z_list = []
        unsu_loss_recon_list = []

        for i in range(self.nb_classes):
            # One fake label per example in the batch, all equal to class i.
            yu_fake = tf.ones([tf.shape(self.xu)[0]], dtype=tf.int32) * i
            yu_fake = tf.one_hot(yu_fake, depth=self.nb_classes)

            mean_hzu, log_var_hzu = self.hx_y_encoder(
                tf.concat([hxu, yu_fake], axis=1))
            sample_hzu = self.draw_sample(mean_hzu, log_var_hzu)
            decode_xu = self.hz_y_decoder(
                tf.concat([sample_hzu, yu_fake], axis=1))

            unsu_loss_kl_z_list.append(
                get_loss(
                    'kl', 'gaussian', {
                        'mean': mean_hzu,
                        'log_var': log_var_hzu,
                        'instance_weight': yuprobs[:, i]
                    }))

            unsu_loss_recon_list.append(
                get_loss('reconstruction', 'mse', {
                    'x': self.xu,
                    'y': decode_xu,
                    'instance_weight': yuprobs[:, i]
                }))

        self.unsu_loss_kl_y = (
            get_loss('kl', 'bernoulli', {'probs': yuprobs}) *
            self.loss_weights.get('kl y loss weight', 1.0))
        self.unsu_loss_kl_z = (tf.reduce_sum(unsu_loss_kl_z_list) *
                               self.loss_weights.get('kl z loss weight', 1.0))
        self.unsu_loss_recon = (
            tf.reduce_sum(unsu_loss_recon_list) *
            self.loss_weights.get('reconstruction loss weight', 1.0))

        self.unsu_loss_reg = (
            get_loss('regularization', 'l2',
                     {'var_list': self.x_classifier.vars}) *
            self.loss_weights.get('regularization loss weight', 0.0001))

        self.unsu_loss = (
            (self.unsu_loss_kl_z + self.unsu_loss_recon + self.unsu_loss_kl_y +
             self.unsu_loss_reg) *
            self.loss_weights.get('unsupervised loss weight', 1.0))

        self.xt = tf.placeholder(tf.float32,
                                 shape=[None] + self.input_shape,
                                 name='xt_input')

        ###########################################################################
        # for test models
        #
        #   xt --x_encoder--> hxt ; xt --x_classifier--> ytlogits --> ytprobs
        #   [hxt, ytprobs] --> mean_hzt, log_var_hzt
        #
        hxt = self.x_encoder(self.xt)
        ytlogits = self.x_classifier(self.xt)
        self.ytprobs = tf.nn.softmax(ytlogits)
        self.mean_hzt, self.log_var_hzt = self.hx_y_encoder(
            tf.concat([hxt, self.ytprobs], axis=1))

        ###########################################################################
        # optimizer configure
        #
        # Store the step on self: the saver below checkpoints it.
        self.global_step, self.global_step_update = get_global_step()

        (self.supervised_train_op, self.supervised_learning_rate,
         _) = get_optimizer_by_config(self.config['optimizer'],
                                      self.config['optimizer params'],
                                      self.su_loss, self.vars,
                                      self.global_step,
                                      self.global_step_update)
        # Same optimizer config as the supervised op. The key was previously
        # misspelled 'optimizer parmas', which raised KeyError.
        (self.unsupervised_train_op, self.unsupervised_learning_rate,
         _) = get_optimizer_by_config(self.config['optimizer'],
                                      self.config['optimizer params'],
                                      self.unsu_loss, self.vars,
                                      self.global_step,
                                      self.global_step_update)

        ###########################################################################
        # model saver
        self.saver = tf.train.Saver(self.vars + [self.global_step])
Пример #5
0
    def build_model(self):
        """Build a conditional VAE graph over flattened inputs.

        Components:
          - x encoder: maps the concatenation [x, y] to ``z_mean`` /
            ``z_log_var``.
          - y encoder: maps y to ``z_mean_y``, which is subtracted from
            ``z_mean`` inside the KL term.
          - decoder: maps a z sample back to a flattened x.

        Also builds a decoder path from the ``z_test`` placeholder, the
        optimizer/train op, and ``self.saver``.
        """
        # Flattened size of one input sample; also the decoder's output size.
        # np.prod replaces np.product, which was removed in NumPy 2.0.
        self.encoder_input_shape = int(np.prod(self.input_shape))

        self.x_real = tf.placeholder(
            tf.float32,
            shape=[None, self.encoder_input_shape],
            name='x_input')
        self.y_real = tf.placeholder(tf.float32,
                                     shape=[None, self.nb_classes],
                                     name='y_input')

        # x encoder: [x, y] -> (z_mean, z_log_var). The key was previously
        # misspelled 'encoder parmas'.
        self.config['encoder params']['name'] = 'EncoderX'
        self.config['encoder params']["output dims"] = self.z_dim
        self.x_encoder = get_encoder(self.config['x encoder'],
                                     self.config['encoder params'],
                                     self.is_training)
        # The rest of this method uses ``x_encoder``; keep the old ``encoder``
        # attribute as an alias for external callers.
        self.encoder = self.x_encoder

        self.config['decoder params']['name'] = 'Decoder'
        self.config['decoder params']["output dims"] = self.encoder_input_shape
        self.decoder = get_decoder(self.config['decoder'],
                                   self.config['decoder params'],
                                   self.is_training)

        # y encoder: y -> z_mean_y. It was referenced but never created.
        # NOTE(review): this requires 'y encoder' / 'y encoder params' in the
        # config — confirm.
        self.y_encoder = get_encoder(self.config['y encoder'],
                                     self.config['y encoder params'],
                                     self.is_training)

        # build encoder (both inputs are rank-2, so join on the feature axis;
        # tf.concatenate does not exist in TensorFlow)
        self.z_mean, self.z_log_var = self.x_encoder(
            tf.concat([self.x_real, self.y_real], axis=1))
        self.z_mean_y = self.y_encoder(self.y_real)

        # sample z from z_mean and z_log_var
        self.z_sample = self.draw_sample(self.z_mean, self.z_log_var)

        # build decoder
        self.x_decode = self.decoder(self.z_sample)

        # build test decoder
        # NOTE(review): other models call decoders with a single argument;
        # confirm this decoder accepts ``reuse``.
        self.z_test = tf.placeholder(tf.float32,
                                     shape=[None, self.z_dim],
                                     name='z_test')
        self.x_test = self.decoder(self.z_test, reuse=True)

        # loss function
        self.kl_loss = (get_loss(
            'kl', self.config['kl loss'], {
                'z_mean': (self.z_mean - self.z_mean_y),
                'z_log_var': self.z_log_var
            }) * self.config.get('kl loss prod', 1.0))
        self.xent_loss = (
            get_loss('reconstruction', self.config['reconstruction loss'], {
                'x': self.x_real,
                'y': self.x_decode
            }) * self.config.get('reconstruction loss prod', 1.0))
        self.loss = self.kl_loss + self.xent_loss

        # optimizer configure
        train_vars = self.decoder.vars + self.x_encoder.vars + self.y_encoder.vars
        self.global_step, self.global_step_update = get_global_step()
        if 'lr' in self.config:
            self.learning_rate = get_learning_rate(self.config['lr_scheme'],
                                                   float(self.config['lr']),
                                                   self.global_step,
                                                   self.config['lr_params'])
            self.optimizer = get_optimizer(
                self.config['optimizer'],
                {'learning_rate': self.learning_rate}, self.loss, train_vars)
        else:
            self.optimizer = get_optimizer(
                self.config['optimizer'], {}, self.loss, train_vars)

        self.train_update = tf.group([self.optimizer, self.global_step_update])

        # model saver: tf.train.Saver's first argument is the full var_list;
        # the second positional argument is ``reshape``, so every variable
        # must go into a single list.
        self.saver = tf.train.Saver(self.x_encoder.vars + self.y_encoder.vars +
                                    self.decoder.vars + [self.global_step])
Пример #6
0
 def _build_decoder(self, name, params=None):
     """Create the decoder registered under ``name`` in ``self.config``.

     The network type is read from ``self.config[name]``. Its settings come
     from a shallow copy of ``self.config[name + ' params']``, so the stored
     config is not mutated. Entries in ``params`` (when given) override the
     copied settings.
     """
     settings = self.config[name + ' params'].copy()
     if params is None:
         return get_decoder(self.config[name], settings, self.is_training)
     settings.update(params)
     return get_decoder(self.config[name], settings, self.is_training)
def recognize(image_path, weights_path, is_vis):
    """Run the trained recognizer on one image and log the predicted text.

    :param image_path: path of the image to recognize
    :param weights_path: checkpoint path passed to ``tf.train.Saver.restore``
    :param is_vis: if true, show the resized input image with matplotlib
    :return: None
    :raises IOError: if OpenCV cannot read ``image_path``
    :raises ValueError: if ``config.DECODER_MODEL`` is not a supported decoder
    """
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError('Could not read image: {:s}'.format(image_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # The network input size depends on whether the STN front-end is enabled.
    size_cfg = config.STN if config.USE_STN else config.ENCODER
    new_height = size_cfg.IH
    new_width = size_cfg.IW
    image = cv2.resize(image,
                       dsize=(new_width, new_height),
                       interpolation=cv2.INTER_LINEAR)
    image_vis = image

    # Set up data placeholder (batch of exactly one image)
    char_dict_path = config.CHAR_DICT
    ord_map_dict_path = config.ORD_MAP_DICT
    inputdata = tf.placeholder(
        dtype=tf.float32,
        shape=[1, new_height, new_width, config.ENCODER.INPUT_CHANNELS],
        name='input')

    # Converter from sparse label tensors to strings
    convert = tf_io_pipline_fast_tools.FeatureReader(
        char_dict_path=char_dict_path, ord_map_dict_path=ord_map_dict_path)

    # Set up network graph
    encoder_net = get_encoder('test')
    decoder_net = get_decoder('test')

    gpu_list = config.GPUS.split(',')

    device_name = '/gpu:{}'.format(gpu_list[0])
    with tf.device(device_name):
        with tf.name_scope('Test'):
            encoder_out = encoder_net.forward(inputdata,
                                              name=config.ENCODER.NETWORKTYPE +
                                              'Encoder')
            decoder_out = decoder_net.predict(encoder_out)

    # Config tf saver
    saver = tf.train.Saver()

    # config tf session
    sess_config = tf.ConfigProto(
        allow_soft_placement=True)  # let TF pick an available device for ops that cannot run on the requested one
    sess_config.gpu_options.per_process_gpu_memory_fraction = config.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = config.TEST.TF_ALLOW_GROWTH

    sess = tf.Session(config=sess_config)
    try:
        with sess.as_default():
            saver.restore(sess=sess, save_path=weights_path)

            decoder_out_value = sess.run(decoder_out,
                                         feed_dict={inputdata: [image]})

            decoder_model = config.DECODER_MODEL
            if decoder_model not in config.VALID_DECODER_MODEL:
                raise ValueError('Unknown decoder model: {}'.format(
                    config.DECODER_MODEL))

            if decoder_model in ('normal_ctc', 'normal_sts'):
                preds = convert.sparse_tensor_to_str(
                    decoder_out_value['labels'])
            elif decoder_model == 'reverse_sts':
                preds = convert.sparse_tensor_to_str(
                    decoder_out_value['labels'], reverse=True)
            elif decoder_model == 'bidirection_sts':
                normal_preds = convert.sparse_tensor_to_str(
                    decoder_out_value['normal']['labels'])
                reverse_preds = convert.sparse_tensor_to_str(
                    decoder_out_value['reverse']['labels'], reverse=True)
                # Keep whichever direction reports the higher score.
                if decoder_out_value['normal']['scores'] > decoder_out_value[
                        'reverse']['scores']:
                    preds = normal_preds
                else:
                    preds = reverse_preds
            else:
                raise ValueError('Unknown decoder model: {}'.format(
                    config.DECODER_MODEL))

            logger.info('Predict image {:s} result {:s}'.format(
                ops.split(image_path)[1], preds[0]))

            if is_vis:
                plt.figure('CRNN Model Demo')
                plt.imshow(image_vis)
                plt.show()
    finally:
        # Always release the session, including on restore/decoder errors.
        sess.close()