Exemplo n.º 1
0
def tower_loss(scope, images, labels):
    """Build the CIFAR model for one tower and return that tower's total loss.

    Args:
      scope: name-scope prefix of this tower (e.g. 'tower_0'); only entries of
        the 'losses' collection created under this prefix are summed.
      images: batch of images, 4D tensor [batch_size, height, width, 3].
      labels: batch of labels, 1D tensor [batch_size].

    Returns:
      Scalar tensor named 'total_loss': the sum of this tower's loss terms.

    Side effects:
      Adds one scalar summary per loss term, plus one for the total. The
      tower prefix is stripped from each summary name.
    """
    tower_logits = inference.inference(images)

    # Called for its graph side effects only. The loss terms are collected
    # from the 'losses' collection below, presumably populated by this call.
    inference.loss(tower_logits, labels)

    tower_losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(tower_losses, name='total_loss')

    # Strip 'tower_N/' so summaries from different GPUs share one TensorBoard tag.
    tower_prefix = '%s_[0-9]*/' % inference.TOWER_NAME
    for loss_tensor in tower_losses + [total_loss]:
        display_name = re.sub(tower_prefix, '', loss_tensor.op.name)
        tf.summary.scalar(display_name, loss_tensor)

    return total_loss
Exemplo n.º 2
0
def train():
    """Train CIFAR-10 for FLAGS.max_steps steps.

    Everything is built in a fresh graph:
      - the distorted input pipeline, pinned to the CPU;
      - the inference network, the loss and the training op.

    Training is driven by a MonitoredTrainingSession that:
      - checkpoints into FLAGS.train_dir;
      - stops at FLAGS.max_steps;
      - aborts when the loss becomes NaN;
      - prints loss and throughput every FLAGS.log_frequency steps.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Pin the input pipeline to the CPU. Otherwise some of its ops can
        # land on a GPU and slow training down.
        with tf.device('/cpu:0'):
            images, labels = inference.distorted_inputs()

        logits = inference.inference(images)
        loss = inference.loss(logits, labels)
        # Op that runs one training step on a batch and advances global_step.
        train_op = inference.train(loss, global_step)

        class _ProgressLogger(tf.train.SessionRunHook):
            """Print loss and throughput every FLAGS.log_frequency steps."""

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Fetch the loss alongside the train op so after_run sees it.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return

                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / elapsed
                sec_per_batch = float(elapsed / FLAGS.log_frequency)
                print('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)'
                      % (datetime.now(), self._step, run_values.results,
                         examples_per_sec, sec_per_batch))

        hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss),
            _ProgressLogger(),
        ]
        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)

        with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
                                               hooks=hooks,
                                               config=session_config) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
Exemplo n.º 3
0
def main(_):
    """Restore the saved network and evaluate it on the data in Data_o.mat.

    Steps:
      1. Rebuild the graph and restore it from './model/train.ckpt-99999'.
      2. Feed 'Train_in' as samples and 'Train_out' as labels, both read from
         FLAGS.data_dir/Data_o.mat.
      3. Print the evaluation value.
      4. Cast the network output to int32 and save it to
         FLAGS.data_dir/test_result.mat under 'Predict'.

    Raises:
      ValueError: if Data_o.mat does not exist in FLAGS.data_dir.
    """
    # Graph construction order is kept as-is. tf.train.Saver restores
    # variables by name, and unnamed variables get order-dependent names.
    x_input = tf.placeholder(dtype=tf.float32, shape=[None, 89], name='X_input')
    y_input = tf.placeholder(dtype=tf.float32, shape=[None, 5], name='Y_input')
    global_step = tf.Variable(0, trainable=False)
    logits = inference.inference(x_input)
    loss = inference.loss(logits, y_input)
    train_op = inference.train(loss, global_step)
    evaluation = inference.evaluation(logits, y_input)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Hard-coded checkpoint. tf.train.latest_checkpoint('./logfile') was an
        # alternative considered by the original author.
        checkpoint = './model/train.ckpt-99999'
        saver.restore(sess, checkpoint)

        mat_path = os.path.join(FLAGS.data_dir, 'Data_o.mat')
        if not tf.gfile.Exists(mat_path):
            raise ValueError('Failed to find file: ' + mat_path)

        mat = scio.loadmat(mat_path)
        samples = np.array(mat['Train_in'], dtype=np.float32)
        labels = np.array(mat['Train_out'], dtype=np.float32)
        feed = {x_input: samples, y_input: labels}

        try:
            eval_value, predict = sess.run([evaluation, logits], feed_dict=feed)
            print("Loss of testing NN")
            print(eval_value)
            result_path = os.path.join(FLAGS.data_dir, 'test_result.mat')
            # astype truncates toward zero. NOTE(review): confirm truncation
            # (rather than rounding) is intended.
            scio.savemat(result_path, {'Predict': predict.astype(np.int32)})
        except tf.errors.OutOfRangeError:
            print('Done testing --epoch limit reached')
        finally:
            coord.request_stop()
            coord.join(threads)
Exemplo n.º 4
0
def train():
    """Train the model on every ``.pickle`` dataset found in ``datasets_dir``.

    Graph components:
      - the inference network on ``image_holder``;
      - a gradient-descent optimiser whose learning rate decays exponentially
        with ``global_step``;
      - an exponential moving average over all trainable variables.

    For each of ``MAX_STEPS`` outer steps, every dataset file is read with
    ``dataset.read`` and used for one training run. Every second outer step,
    the most recent loss is printed and a checkpoint is saved to
    ``models_file``. ``models_dir`` is created first if it is missing.

    Raises:
        ValueError: if ``datasets_dir`` contains no ``.pickle`` files.
        TypeError: if ``dataset.read`` does not return two lists.
    """
    # Keep only pickle datasets. Build a new list instead of calling
    # list.remove() while iterating the same list. That pattern skips the
    # element after every removal and can leave non-pickle files behind.
    filenames = [filename for filename in os.listdir(datasets_dir)
                 if os.path.splitext(filename)[1] == '.pickle']
    if not filenames:
        # With no data, loss_value would never be bound and the progress print
        # below would fail with a NameError.
        raise ValueError('No .pickle datasets found in %s' % datasets_dir)

    logits = inference.inference(image_holder, reuse=False)
    global_step = tf.Variable(0, trainable=False)
    # Exponential moving average of every trainable variable.
    variable_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    loss = inference.loss(logits, label_holder)
    # Learning rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / MAX_STEPS).
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,
                                               global_step,
                                               MAX_STEPS,
                                               decay_rate=LEARNING_RATE_DECAY)
    # Backpropagation step. minimize() also increments global_step.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)
    # One op that runs both the gradient step and the moving-average update.
    train_op = tf.group(train_step, variable_averages_op)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.train.start_queue_runners()
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        for step in range(MAX_STEPS):
            for filename in filenames:
                # NOTE(review): dataset.read receives the bare file name, not a
                # path joined with datasets_dir. Confirm it resolves the
                # directory itself.
                train_image, train_label = dataset.read(filename)
                # Explicit checks instead of assert, which is stripped under -O.
                if not isinstance(train_image, list):
                    raise TypeError('dataset.read returned non-list images for %s'
                                    % filename)
                if not isinstance(train_label, list):
                    raise TypeError('dataset.read returned non-list labels for %s'
                                    % filename)
                _, loss_value = sess.run([train_op, loss],
                                         feed_dict={
                                             image_holder: train_image,
                                             label_holder: train_label
                                         })
            if step % 2 == 0:
                print("after %d steps, the loss value is %g" %
                      (step, loss_value))
                saver.save(sess, models_file, global_step=step)
Exemplo n.º 5
0
def main(_):
    """Train the 89-input / 5-output network on Data_o.mat and save results.

    Data comes from ``Train_data`` (samples) and ``Train_label`` (labels) in
    ``FLAGS.data_dir/Data_o.mat``. ``FLAGS.max_steps`` full-batch training
    steps are run on it.

    Progress output:
      - every 10 steps, the training loss and step time;
      - after the first step, every 1000th step and the final step, the
        evaluation value.

    After training:
      - a checkpoint ``train.ckpt-<last step index>`` is saved in
        ``FLAGS.train_dir``;
      - the network outputs for the samples are written to
        ``FLAGS.data_dir/predict.mat`` under ``'Predict'``.

    The graph is also written to ``FLAGS.train_dir`` for TensorBoard.

    Raises:
        ValueError: if Data_o.mat does not exist in FLAGS.data_dir.
    """
    Samples_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, 89], name='X_input')
    Labels_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, 5], name='Y_input')

    global_step = tf.Variable(0, trainable=False)
    logits = inference.inference(Samples_placeholder)
    loss = inference.loss(logits, Labels_placeholder)
    train_op = inference.train(loss, global_step)
    evaluation = inference.evaluation(logits, Labels_placeholder)

    saver = tf.train.Saver()
    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        # Writes the graph definition for TensorBoard. Closed in `finally` so
        # pending events are flushed even if training fails.
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
        try:
            sess.run(init)

            train_file = os.path.join(FLAGS.data_dir, 'Data_o.mat')
            if not tf.gfile.Exists(train_file):
                raise ValueError('Failed to find file: ' + train_file)

            data = scio.loadmat(train_file)
            samples = np.array(data['Train_data'], dtype=np.float32)
            labels = np.array(data['Train_label'], dtype=np.float32)
            # The same data is used for training, evaluation and prediction.
            feed = {Samples_placeholder: samples, Labels_placeholder: labels}

            # range (not xrange) so the script also runs on Python 3.
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                loss_value, _ = sess.run([loss, train_op], feed_dict=feed)
                duration = time.time() - start_time
                if (step + 1) % 10 == 0:
                    print("Loss of training NN, step: %d, loss: %f, and time:%f" % (step + 1, loss_value, duration))
                if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                    eval_value = sess.run(evaluation, feed_dict=feed)
                    print(eval_value)

            checkpoint_path = os.path.join(FLAGS.train_dir, 'train.ckpt')
            # Index of the last step. Matches the loop variable the original
            # used, but stays defined when max_steps is 0.
            last_step = max(FLAGS.max_steps - 1, 0)
            saver.save(sess, checkpoint_path, global_step=last_step)
            predict = sess.run(logits, feed_dict=feed)
            test_file = os.path.join(FLAGS.data_dir, 'predict.mat')
            scio.savemat(test_file, {'Predict': predict})
        finally:
            summary_writer.close()
Exemplo n.º 6
0
# Column positions of the fetched database row that are dropped before the
# row is fed to the network. The remaining values are reshaped to the
# 89-wide X_input placeholder.
_EXCLUDED_COLUMNS = frozenset(
    [102, 101, 100, 99, 93, 92, 89, 80, 77, 15, 14, 13, 10, 0])


def _fetch_feature_vector(nickname):
    """Return the model-input values stored in MySQL for one Weibo nickname.

    Queries the local ``weibodata`` database for the first row whose
    nickname column equals ``nickname``. Returns that row's values, minus the
    positions listed in ``_EXCLUDED_COLUMNS``, in column order.

    Raises:
        ValueError: if no row matches ``nickname``.
    """
    # Credentials are redacted in this copy of the source.
    connection = pymysql.connect(
        host='localhost',
        user='******',
        password='******',
        db='weibodata',
        charset='utf8mb4',
    )
    try:
        with connection.cursor() as cursor:
            # Parameterised query: the nickname comes from the command line.
            sql = "SELECT * FROM `向量结果` WHERE `微博昵称`=%s"
            cursor.execute(sql, (nickname, ))
            row = cursor.fetchone()
    finally:
        connection.close()

    # fetchone() returns None when nothing matches. Report that clearly
    # instead of failing later with list(None).
    if row is None:
        raise ValueError('No database record found for nickname: %s' % nickname)
    return [value for index, value in enumerate(row)
            if index not in _EXCLUDED_COLUMNS]


def main(_):
    """Restore the trained network and predict for one Weibo user.

    Inputs:
      - the user's feature row is read from MySQL, using the nickname given
        as the first command-line argument;
      - labels are 'Train_out' from FLAGS.data_dir/Data_o.mat.

    Steps:
      1. Restore the network from './model/train.ckpt-99999'.
      2. Run the network on the feature row.
      3. Print the evaluation value.
      4. Cast the output to int32 and save it to
         FLAGS.data_dir/test_result.mat under 'Predict'.

    Raises:
        ValueError: if no nickname argument is given, Data_o.mat is missing,
            or no database row matches the nickname.
    """
    if len(sys.argv) < 2:
        raise ValueError('Usage: pass the Weibo nickname as the first '
                         'command-line argument')

    Samples_placeholder = tf.placeholder(dtype=tf.float32,
                                         shape=[None, 89],
                                         name='X_input')
    Labels_placeholder = tf.placeholder(dtype=tf.float32,
                                        shape=[None, 5],
                                        name='Y_input')

    global_step = tf.Variable(0, trainable=False)

    logits = inference.inference(Samples_placeholder)  # network structure

    loss = inference.loss(logits, Labels_placeholder)  # loss function

    train_op = inference.train(loss, global_step)  # training op

    # Evaluation of predictions against the actual labels.
    evaluation = inference.evaluation(logits, Labels_placeholder)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Hard-coded checkpoint (tf.train.latest_checkpoint('./logfile') was
        # an alternative considered by the original author).
        model_dir = './model/train.ckpt-99999'

        saver.restore(sess, model_dir)

        test_data_file = os.path.join(FLAGS.data_dir, 'Data_o.mat')
        if not tf.gfile.Exists(test_data_file):
            raise ValueError('Failed to find file: ' + test_data_file)

        data = scio.loadmat(test_data_file)
        # NOTE(review): labels hold every row of Train_out while samples is a
        # single row. Confirm inference.evaluation tolerates, or intends, this
        # shape mismatch.
        labels = np.array(data['Train_out'], dtype=np.float32)

        samples = np.array(_fetch_feature_vector(sys.argv[1]), dtype=np.float32)
        samples = samples.reshape((1, 89))

        try:
            eval_value, predict = sess.run([evaluation, logits],
                                           feed_dict={
                                               Samples_placeholder: samples,
                                               Labels_placeholder: labels
                                           })
            print("Loss of testing NN")
            print(eval_value)
            test_file = os.path.join(FLAGS.data_dir, 'test_result.mat')
            # astype truncates toward zero. NOTE(review): confirm this is
            # intended rather than rounding.
            predict = predict.astype(np.int32)
            scio.savemat(test_file, {'Predict': predict})

        except tf.errors.OutOfRangeError:
            print('Done testing --epoch limit reached')

        finally:
            coord.request_stop()
            coord.join(threads)