Пример #1
0
# Module-level setup: load the sample set and build the three-input graph
# (placeholders, network, loss and accuracy ops) used by the script below.
# `config`, `height`, `width`, `num_classes` and `arch_model` are not defined
# in this excerpt and must already be in scope.
batch_size = config.batch_size
sample_dir = 'gender'
# Split ratio handed to load_image.
# NOTE(review): presumably 1.0 puts every sample in the first ("test")
# split and leaves the validation split empty -- confirm in load_image.
train_rate = 1.0
test_data, test_label, valid_data, valid_label, test_n, valid_n, note_label = load_image(
    sample_dir, train_rate).gen_train_valid()
print('test_n', test_n)
print('valid_n', valid_n)
# Convert the class-index labels to one-hot matrices with num_classes columns.
test_label = np_utils.to_categorical(test_label, num_classes)
valid_label = np_utils.to_categorical(valid_label, num_classes)

# Three image inputs, the label input, and the train-mode / keep-probability
# inputs that the network builder consumes.
X1, X2, X3, Y, is_train, keep_prob_fc = input_placeholder3(
    height, width, num_classes)
net, net_vis = build_net3(X1, X2, X3, num_classes, keep_prob_fc, is_train,
                          arch_model)
loss = cost(Y, net)
accuracy = model_accuracy(net, Y, num_classes)
#predict = tf.reshape(net, [-1, num_classes], name='predictions')

# Script entry point: restore the newest checkpoint from ./model and iterate
# over the test set in full batches.
if __name__ == '__main__':
    train_dir = 'model'
    latest = tf.train.latest_checkpoint(train_dir)
    if not latest:
        # Nothing to evaluate without a saved model; exit with status 1.
        print("No checkpoint to continue from in", train_dir)
        sys.exit(1)
    print("resume", latest)
    sess = tf.Session()
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, latest)
    # Running totals for loss and accuracy.
    test_ls = 0
    test_acc = 0
    # Only full batches are visited; a trailing partial batch is dropped.
    # NOTE(review): the loop body is not visible in this excerpt.
    for batch_i in range(int(test_n / batch_size)):
Пример #2
0
def train3(train_data, train_label, valid_data, valid_label, train_dir,
           num_classes, batch_size, arch_model, learning_r_decay,
           learning_rate_base, decay_rate, dropout_prob, epoch, height, width,
           checkpoint_exclude_scopes, early_stop, EARLY_STOP_PATIENCE,
           fine_tune, train_all_layers, checkpoint_path, train_n, valid_n,
           g_parameter):
    """Train the three-input network, validating and checkpointing per epoch.

    The graph is built with ``input_placeholder3`` / ``build_net3``.  Each
    epoch runs ``int(train_n / batch_size)`` optimisation steps, then
    evaluates every full validation batch.  A checkpoint is written to
    ``train_dir`` whenever the mean validation accuracy exceeds 0.90.

    Args:
        train_data, train_label: training samples and labels; passed to
            ``get_next_batch_from_path3`` and reshuffled with
            ``shuffle_train_data`` after every epoch.
        valid_data, valid_label: validation samples and labels.
        train_dir: directory checkpoints are saved to and, when
            ``fine_tune`` is true, resumed from.
        num_classes: number of classes, forwarded to the graph builders.
        batch_size: samples per step; must be positive.
        arch_model: architecture selector forwarded to ``build_net3``.
        learning_r_decay: if true, the learning rate follows a staircase
            exponential decay that steps once every ``train_n`` samples;
            otherwise ``learning_rate_base`` is used unchanged.
        learning_rate_base: initial (or constant) learning rate.
        decay_rate: decay factor for the exponential schedule.
        dropout_prob: value fed to ``keep_prob_fc`` on optimisation steps
            (evaluation steps always feed 1.0).
        epoch: number of epochs to run.
        height, width: forwarded to the placeholder factory and batch loader.
        checkpoint_exclude_scopes: forwarded to ``g_parameter``.
        early_stop: if true, stop once more than ``EARLY_STOP_PATIENCE``
            epochs have passed since the best validation loss.
        EARLY_STOP_PATIENCE: patience, in epochs, for early stopping.
        fine_tune: if true, resume all global variables from the latest
            checkpoint in ``train_dir``.
        train_all_layers: if true, an empty variable list is handed to
            ``train_op`` and the ``checkpoint_path`` restore is skipped; if
            false, the variables returned by ``g_parameter`` are trained and
            restored from ``checkpoint_path``.
        checkpoint_path: checkpoint to restore from when
            ``train_all_layers`` is false.
        train_n, valid_n: number of training / validation samples.
        g_parameter: callable returning
            ``(variables_to_restore, variables_to_train)``.

    Raises:
        ValueError: if ``batch_size`` is not positive or the validation set
            does not contain at least one full batch.
        SystemExit: if ``fine_tune`` is set and ``train_dir`` holds no
            checkpoint.
    """
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {}'.format(batch_size))
    # Only full batches are used; trailing partial batches are dropped.
    train_batches = int(train_n / batch_size)
    valid_batches = int(valid_n / batch_size)
    if valid_batches < 1:
        # Validation statistics are averaged over (and indexed modulo) the
        # number of validation batches, so zero batches cannot be handled.
        raise ValueError(
            'valid_n ({}) must be at least batch_size ({}) so that one full '
            'validation batch exists'.format(valid_n, batch_size))

    # ------------------------------ graph ------------------------------ #
    X1, X2, X3, Y, is_train, keep_prob_fc = input_placeholder3(
        height, width, num_classes)
    net, _ = build_net3(X1, X2, X3, num_classes, keep_prob_fc, is_train,
                        arch_model)
    variables_to_restore, variables_to_train = g_parameter(
        checkpoint_exclude_scopes)
    loss = cost(Y, net)
    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        # global_step * batch_size counts samples seen, so with
        # decay_steps=train_n the rate steps down once per epoch.
        learning_rate = tf.train.exponential_decay(learning_rate_base,
                                                   global_step * batch_size,
                                                   train_n,
                                                   decay_rate,
                                                   staircase=True)
    else:
        learning_rate = learning_rate_base
    if train_all_layers:
        variables_to_train = []
    optimizer = train_op(learning_rate, loss, variables_to_train, global_step)
    accuracy = model_accuracy(net, Y, num_classes)

    def load_batch(data, label, index, training):
        """Fetch batch ``index`` of ``data``/``label`` from the loader."""
        return get_next_batch_from_path3(data,
                                         label,
                                         index,
                                         height,
                                         width,
                                         batch_size=batch_size,
                                         training=training)

    def make_feed(batch, training):
        """Build the feed dict; dropout is only active on training steps."""
        images1, images2, images3, labels = batch
        return {
            X1: images1,
            X2: images2,
            X3: images3,
            Y: labels,
            is_train: training,
            keep_prob_fc: dropout_prob if training else 1.0
        }

    # ----------------------------- session ----------------------------- #
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        saver2 = tf.train.Saver(tf.global_variables())
        if not train_all_layers:
            saver_net = tf.train.Saver(variables_to_restore)
            saver_net.restore(sess, checkpoint_path)

        if fine_tune:
            latest = tf.train.latest_checkpoint(train_dir)
            if not latest:
                print("No checkpoint to continue from in", train_dir)
                sys.exit(1)
            print("resume", latest)
            saver2.restore(sess, latest)

        # Early-stopping bookkeeping.
        best_valid = np.inf
        best_valid_epoch = 0

        for epoch_i in range(epoch):
            for batch_i in range(train_batches):
                batch = load_batch(train_data, train_label, batch_i, True)
                los, _ = sess.run([loss, optimizer],
                                  feed_dict=make_feed(batch, True))
                print(los)
                if batch_i % 20 == 0:
                    # Re-score the same batch in inference mode.
                    loss_, acc_ = sess.run([loss, accuracy],
                                           feed_dict=make_feed(batch, False))
                    print(
                        'Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'
                        .format(batch_i, loss_, acc_))

                if batch_i % 100 == 0:
                    # Spot-check one validation batch, cycling through them.
                    batch = load_batch(valid_data, valid_label,
                                       batch_i % valid_batches, False)
                    ls, acc = sess.run([loss, accuracy],
                                       feed_dict=make_feed(batch, False))
                    print(
                        'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                        .format(batch_i, ls, acc))

            print(
                'Epoch===================================>: {:>2}'.format(
                    epoch_i))
            # Full validation pass: accumulate per-batch loss and accuracy.
            valid_ls = 0
            valid_acc = 0
            for batch_i in range(valid_batches):
                batch = load_batch(valid_data, valid_label, batch_i, False)
                epoch_ls, epoch_acc = sess.run(
                    [loss, accuracy], feed_dict=make_feed(batch, False))
                valid_ls = valid_ls + epoch_ls
                valid_acc = valid_acc + epoch_acc
            mean_valid_loss = valid_ls / valid_batches
            mean_valid_acc = valid_acc / valid_batches
            print(
                'Epoch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                .format(epoch_i, mean_valid_loss, mean_valid_acc))

            if mean_valid_acc > 0.90:
                # Separate name so the checkpoint_path argument is not
                # overwritten.
                save_path = os.path.join(train_dir, 'model.ckpt')
                saver2.save(sess,
                            save_path,
                            global_step=epoch_i,
                            write_meta_graph=True)

            if early_stop:
                if mean_valid_loss < best_valid:
                    best_valid = mean_valid_loss
                    best_valid_epoch = epoch_i
                elif best_valid_epoch + EARLY_STOP_PATIENCE < epoch_i:
                    print("Early stopping.")
                    print("Best valid loss was {:.6f} at epoch {}.".format(
                        best_valid, best_valid_epoch))
                    break
            train_data, train_label = shuffle_train_data(
                train_data, train_label)
    finally:
        # Release the session on every exit path, including sys.exit().
        sess.close()
Пример #3
0
def train(train_data, train_label, valid_data, valid_label, train_dir,
          num_classes, batch_size, arch_model, learning_r_decay,
          learning_rate_base, decay_rate, dropout_prob, epoch, height, width,
          checkpoint_exclude_scopes, early_stop, EARLY_STOP_PATIENCE,
          fine_tune, train_all_layers, checkpoint_path, train_n, valid_n,
          g_parameter):
    """Train the single-input network on two GPU towers with averaged grads.

    Every fed batch is split in half: the first ``int(batch_size / 2)`` rows
    go to a tower on ``/gpu:0`` and the rest (up to ``batch_size``) to a
    tower on ``/gpu:1``.  Both towers share variables through
    ``tf.AUTO_REUSE``.  Their gradients are combined by ``average_gradients``
    and applied by one Adam optimizer.  After each epoch every full
    validation batch is evaluated, and a checkpoint is written to
    ``train_dir`` when the mean validation accuracy exceeds 0.90.

    Args:
        train_data, train_label: training samples and labels; passed to
            ``get_next_batch_from_path`` and reshuffled with
            ``shuffle_train_data`` after every epoch.
        valid_data, valid_label: validation samples and labels.
        train_dir: directory checkpoints are saved to and, when
            ``fine_tune`` is true, resumed from.
        num_classes: number of classes, forwarded to the graph builders.
        batch_size: samples per step; must be positive.
        arch_model: architecture selector forwarded to ``build_net``.
        learning_r_decay: if true, the learning rate follows a staircase
            exponential decay that steps once every ``train_n`` samples;
            otherwise ``learning_rate_base`` is used unchanged.
        learning_rate_base: initial (or constant) learning rate.
        decay_rate: decay factor for the exponential schedule.
        dropout_prob: value fed to ``keep_prob_fc`` on optimisation steps
            (evaluation steps always feed 1.0).
        epoch: number of epochs to run.
        height, width: forwarded to the placeholder factory and batch loader.
        checkpoint_exclude_scopes: forwarded to ``g_parameter``.
        early_stop: if true, stop once more than ``EARLY_STOP_PATIENCE``
            epochs have passed since the best validation loss.
        EARLY_STOP_PATIENCE: patience, in epochs, for early stopping.
        fine_tune: if true, resume all global variables from the latest
            checkpoint in ``train_dir``.
        train_all_layers: accepted for signature compatibility; it has no
            effect here because gradients are computed without a variable
            list and the ``checkpoint_path`` restore is unconditional.
        checkpoint_path: checkpoint the variables returned by
            ``g_parameter`` are restored from.
        train_n, valid_n: number of training / validation samples.
        g_parameter: callable returning
            ``(variables_to_restore, variables_to_train)``; only the first
            element is used.

    Raises:
        ValueError: if ``batch_size`` is not positive or the validation set
            does not contain at least one full batch.
        SystemExit: if ``fine_tune`` is set and ``train_dir`` holds no
            checkpoint.
    """
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {}'.format(batch_size))
    # Only full batches are used; trailing partial batches are dropped.
    train_batches = int(train_n / batch_size)
    valid_batches = int(valid_n / batch_size)
    if valid_batches < 1:
        # Validation statistics are averaged over (and indexed modulo) the
        # number of validation batches, so zero batches cannot be handled.
        raise ValueError(
            'valid_n ({}) must be at least batch_size ({}) so that one full '
            'validation batch exists'.format(valid_n, batch_size))

    # ------------------------------ graph ------------------------------ #
    X, Y, is_train, keep_prob_fc = input_placeholder(height, width,
                                                     num_classes)

    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        # global_step * batch_size counts samples seen, so with
        # decay_steps=train_n the rate steps down once per epoch.
        learning_rate = tf.train.exponential_decay(learning_rate_base,
                                                   global_step * batch_size,
                                                   train_n,
                                                   decay_rate,
                                                   staircase=True)
    else:
        learning_rate = learning_rate_base

    opt = tf.train.AdamOptimizer(learning_rate)

    half = int(batch_size / 2)
    # (device, rows of the fed batch handled by that tower)
    towers = (('/gpu:0', slice(0, half)), ('/gpu:1', slice(half, batch_size)))

    with tf.device('/cpu:0'):
        tower_grads = []
        tower_losses = []
        tower_accuracies = []
        for device, rows in towers:
            with tf.device(device):
                # AUTO_REUSE makes the second tower share the first
                # tower's variables.
                with tf.variable_scope('', reuse=tf.AUTO_REUSE):
                    net, _ = build_net(X[rows], num_classes, keep_prob_fc,
                                       is_train, arch_model)
                    tower_loss = tf.reduce_mean(
                        tf.nn.softmax_cross_entropy_with_logits(
                            labels=Y[rows], logits=net))
                    tower_accuracies.append(
                        model_accuracy(net, Y[rows], num_classes))
                    tower_grads.append(opt.compute_gradients(tower_loss))
                    tower_losses.append(tower_loss)

        variables_to_restore, _ = g_parameter(checkpoint_exclude_scopes)
        grads = average_gradients(tower_grads)
        # Run the collected UPDATE_OPS together with each optimizer step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer = opt.apply_gradients(grads, global_step=global_step)
        loss = tf.reduce_mean(tower_losses)
        accuracy = tf.reduce_mean(tower_accuracies)

        init = tf.global_variables_initializer()
        saver2 = tf.train.Saver(tf.global_variables())
        saver_net = tf.train.Saver(variables_to_restore)

    def load_batch(data, label, index, training):
        """Fetch batch ``index`` of ``data``/``label`` from the loader."""
        return get_next_batch_from_path(data,
                                        label,
                                        index,
                                        height,
                                        width,
                                        batch_size=batch_size,
                                        training=training)

    def make_feed(images, labels, training):
        """Build the feed dict; dropout is only active on training steps."""
        return {
            X: images,
            Y: labels,
            is_train: training,
            keep_prob_fc: dropout_prob if training else 1.0
        }

    # ----------------------------- session ----------------------------- #
    sess = tf.Session()
    try:
        sess.run(init)
        saver_net.restore(sess, checkpoint_path)

        if fine_tune:
            latest = tf.train.latest_checkpoint(train_dir)
            if not latest:
                print("No checkpoint to continue from in", train_dir)
                sys.exit(1)
            print("resume", latest)
            saver2.restore(sess, latest)

        # Early-stopping bookkeeping.
        best_valid = np.inf
        best_valid_epoch = 0

        for epoch_i in range(epoch):
            for batch_i in range(train_batches):
                images, labels = load_batch(train_data, train_label, batch_i,
                                            True)
                los, _ = sess.run([loss, optimizer],
                                  feed_dict=make_feed(images, labels, True))
                print(los)
                if batch_i % 20 == 0:
                    # Re-score the same batch in inference mode.
                    loss_, acc_ = sess.run([loss, accuracy],
                                           feed_dict=make_feed(
                                               images, labels, False))
                    print(
                        'Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'
                        .format(batch_i, loss_, acc_))

                if batch_i % 100 == 0:
                    # Spot-check one validation batch, cycling through them.
                    images, labels = load_batch(valid_data, valid_label,
                                                batch_i % valid_batches,
                                                False)
                    ls, acc = sess.run([loss, accuracy],
                                       feed_dict=make_feed(
                                           images, labels, False))
                    print(
                        'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                        .format(batch_i, ls, acc))

            print('Epoch===================================>: {:>2}'.format(
                epoch_i))
            # Full validation pass: accumulate per-batch loss and accuracy.
            valid_ls = 0
            valid_acc = 0
            for batch_i in range(valid_batches):
                images_valid, labels_valid = load_batch(
                    valid_data, valid_label, batch_i, False)
                epoch_ls, epoch_acc = sess.run(
                    [loss, accuracy],
                    feed_dict=make_feed(images_valid, labels_valid, False))
                valid_ls = valid_ls + epoch_ls
                valid_acc = valid_acc + epoch_acc
            mean_valid_loss = valid_ls / valid_batches
            mean_valid_acc = valid_acc / valid_batches
            print(
                'Epoch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                .format(epoch_i, mean_valid_loss, mean_valid_acc))

            if mean_valid_acc > 0.90:
                # Separate name so the checkpoint_path argument is not
                # overwritten.
                save_path = os.path.join(train_dir, 'model.ckpt')
                saver2.save(sess,
                            save_path,
                            global_step=epoch_i,
                            write_meta_graph=False)

            if early_stop:
                if mean_valid_loss < best_valid:
                    best_valid = mean_valid_loss
                    best_valid_epoch = epoch_i
                elif best_valid_epoch + EARLY_STOP_PATIENCE < epoch_i:
                    print("Early stopping.")
                    print("Best valid loss was {:.6f} at epoch {}.".format(
                        best_valid, best_valid_epoch))
                    break
            train_data, train_label = shuffle_train_data(
                train_data, train_label)
    finally:
        # Release the session on every exit path, including sys.exit().
        sess.close()
Пример #4
0
def train(train_data, train_label, valid_data, valid_label, train_n, valid_n,
          train_dir, num_classes, batch_size, arch_model, learning_r_decay,
          learning_rate_base, decay_rate, dropout_prob, epoch, height, width,
          checkpoint_exclude_scopes, early_stop, EARLY_STOP_PATIENCE,
          fine_tune, train_all_layers, checkpoint_path, g_parameter):
    """Train a network whose inputs are wired straight into the graph.

    Unlike a placeholder-fed loop, ``train_data`` / ``train_label`` and
    ``valid_data`` / ``valid_label`` are handed directly to ``build_net`` and
    ``cost``, and ``sess.run`` is called without a feed dict.
    NOTE(review): presumably these are tensors produced by an input queue,
    since queue runners are started before the loop -- confirm at the caller.

    A training graph (dropout ``dropout_prob``, training mode) and a
    validation graph (keep probability 1.0, inference mode) share variables
    through ``tf.AUTO_REUSE``.  Training loss/accuracy is printed every 100
    steps.  Every 500 steps the validation loss/accuracy is printed and a
    checkpoint is written to ``train_dir`` with the epoch index as its step.

    Args:
        train_data, train_label: training inputs and labels.
        valid_data, valid_label: validation inputs and labels.
        train_n: number of training samples; each epoch runs
            ``int(train_n / batch_size)`` steps.
        valid_n: unused by this function.
        train_dir: directory checkpoints are saved to and, when
            ``fine_tune`` is true, resumed from.
        num_classes: number of classes, forwarded to the graph builders.
        batch_size: samples per step; must be positive.
        arch_model: architecture selector forwarded to ``build_net``.
        learning_r_decay: if true, use a staircase exponential decay that
            steps every 1000 samples; otherwise ``learning_rate_base`` is
            used unchanged.
        learning_rate_base: initial (or constant) learning rate.
        decay_rate: decay factor for the exponential schedule.
        dropout_prob: keep probability passed to the training graph.
        epoch: number of epochs to run.
        height, width: unused by this function.
        checkpoint_exclude_scopes: forwarded to ``g_parameter``.
        early_stop, EARLY_STOP_PATIENCE: accepted for signature
            compatibility; no early stopping is performed here.
        fine_tune: if true, resume all global variables from the latest
            checkpoint in ``train_dir``.
        train_all_layers: if true, an empty variable list is handed to
            ``train_op``; the ``checkpoint_path`` restore runs regardless.
        checkpoint_path: checkpoint the variables returned by
            ``g_parameter`` are restored from.
        g_parameter: callable returning
            ``(variables_to_restore, variables_to_train)``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        SystemExit: if ``fine_tune`` is set and ``train_dir`` holds no
            checkpoint.
    """
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {}'.format(batch_size))
    train_batches = int(train_n / batch_size)

    # ---------------------------- train graph --------------------------- #
    net, _ = build_net(train_data, num_classes, dropout_prob, True, arch_model)
    variables_to_restore, variables_to_train = g_parameter(
        checkpoint_exclude_scopes)
    loss = cost(train_label, net)
    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        learning_rate = tf.train.exponential_decay(
            learning_rate_base,
            global_step * batch_size,
            1000,  # decay_steps: how often the rate is decayed
            decay_rate,
            staircase=True)
    else:
        learning_rate = learning_rate_base
    if train_all_layers:
        variables_to_train = []
    optimizer = train_op(learning_rate, loss, variables_to_train, global_step)
    accuracy = model_accuracy(net, train_label, num_classes)

    # ---------------------------- valid graph --------------------------- #
    # Reuse the training variables; keep probability 1.0 disables dropout.
    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
        valid_net, _ = build_net(valid_data, num_classes, 1.0, False,
                                 arch_model)
    valid_loss = cost(valid_label, valid_net)
    valid_accuracy = model_accuracy(valid_net, valid_label, num_classes)

    # ------------------------------ session ----------------------------- #
    # InteractiveSession installs itself as the default session, which the
    # bare ``.run()`` calls below rely on.
    sess = tf.InteractiveSession()
    coord = tf.train.Coordinator()
    queue_threads = []
    try:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()
        saver2 = tf.train.Saver(tf.global_variables())
        saver_net = tf.train.Saver(variables_to_restore)
        saver_net.restore(sess, checkpoint_path)

        if fine_tune:
            latest = tf.train.latest_checkpoint(train_dir)
            if not latest:
                print("No checkpoint to continue from in", train_dir)
                sys.exit(1)
            print("resume", latest)
            saver2.restore(sess, latest)

        # Start the input queue threads under a coordinator so they can be
        # stopped cleanly before the session is closed.
        queue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for epoch_i in range(epoch):
            for batch_i in range(train_batches):
                los, _ = sess.run([loss, optimizer])
                if batch_i % 100 == 0:
                    loss_, acc_ = sess.run([loss, accuracy])
                    print(
                        'Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'
                        .format(batch_i, loss_, acc_))

                if batch_i % 500 == 0:
                    ls, acc = sess.run([valid_loss, valid_accuracy])
                    print(
                        'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                        .format(batch_i, ls, acc))
                    # Separate name so the checkpoint_path argument is not
                    # overwritten.
                    save_path = os.path.join(train_dir, 'model.ckpt')
                    saver2.save(sess,
                                save_path,
                                global_step=epoch_i,
                                write_meta_graph=False)
    finally:
        # Stop the queue threads first, then release the session, on every
        # exit path (normal completion, exception, or sys.exit()).
        coord.request_stop()
        try:
            coord.join(queue_threads)
        finally:
            sess.close()