示例#1
0
def train3(train_data, train_label, valid_data, valid_label, train_dir,
           num_classes, batch_size, arch_model, learning_r_decay,
           learning_rate_base, decay_rate, dropout_prob, epoch, height, width,
           checkpoint_exclude_scopes, early_stop, EARLY_STOP_PATIENCE,
           fine_tune, train_all_layers, checkpoint_path, train_n, valid_n,
           g_parameter):
    """Train the three-input network with feed_dict batches.

    Builds the graph (placeholders, network, loss, optimizer, accuracy),
    optionally restores weights, then runs ``epoch`` passes over the
    training set. After every epoch the whole validation set is evaluated;
    a checkpoint is written to ``train_dir`` when the mean validation
    accuracy exceeds 0.90, and training stops early when ``early_stop`` is
    set and the mean validation loss has not improved for more than
    ``EARLY_STOP_PATIENCE`` epochs.

    Args:
        train_data, train_label: Training set, passed to
            ``get_next_batch_from_path3`` and ``shuffle_train_data``.
        valid_data, valid_label: Validation set, passed to
            ``get_next_batch_from_path3``.
        train_dir: Directory checkpoints are saved to and, when
            ``fine_tune`` is set, resumed from.
        num_classes: Forwarded to the placeholder, network and accuracy
            builders.
        batch_size: Number of samples requested per batch.
        arch_model: Forwarded to ``build_net3``.
        learning_r_decay: If true, use a staircase exponential decay of
            the learning rate (one decay period per ``train_n`` samples).
        learning_rate_base: Initial (or constant) learning rate.
        decay_rate: Decay factor for the exponential schedule.
        dropout_prob: Value fed to ``keep_prob_fc`` on training steps.
        epoch: Number of epochs to run.
        height, width: Forwarded to the placeholder and batch helpers.
        checkpoint_exclude_scopes: Forwarded to ``g_parameter``.
        early_stop: Enable early stopping on validation loss.
        EARLY_STOP_PATIENCE: Epochs to wait without improvement.
        fine_tune: If true, resume from the latest checkpoint in
            ``train_dir``; exits the process with status 1 if none exists.
        train_all_layers: If true, the pretrained checkpoint is not
            restored and ``train_op`` receives an empty variable list
            (presumably meaning "all trainable variables" -- confirm
            against ``train_op``).
        checkpoint_path: Pretrained checkpoint restored when
            ``train_all_layers`` is false.
        train_n, valid_n: Number of training / validation samples.
        g_parameter: Callable returning
            ``(variables_to_restore, variables_to_train)``.

    Raises:
        ZeroDivisionError: If ``valid_n < batch_size`` (no complete
            validation batch) or ``batch_size`` is 0.
    """
    # ------------------------------ graph ------------------------------ #
    X1, X2, X3, Y, is_train, keep_prob_fc = input_placeholder3(
        height, width, num_classes)
    net, _ = build_net3(X1, X2, X3, num_classes, keep_prob_fc, is_train,
                        arch_model)
    variables_to_restore, variables_to_train = g_parameter(
        checkpoint_exclude_scopes)
    loss = cost(Y, net)
    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        # global_step * batch_size counts samples seen, so the rate decays
        # once per train_n samples.
        learning_rate = tf.train.exponential_decay(learning_rate_base,
                                                   global_step * batch_size,
                                                   train_n,
                                                   decay_rate,
                                                   staircase=True)
    else:
        learning_rate = learning_rate_base
    if train_all_layers:
        variables_to_train = []
    optimizer = train_op(learning_rate, loss, variables_to_train, global_step)
    accuracy = model_accuracy(net, Y, num_classes)

    def make_feed(batch, training):
        """Build the feed_dict for one (images1, images2, images3, labels) batch."""
        images1, images2, images3, labels = batch
        return {
            X1: images1,
            X2: images2,
            X3: images3,
            Y: labels,
            is_train: training,
            # Dropout is only active on training steps.
            keep_prob_fc: dropout_prob if training else 1.0,
        }

    # ----------------------------- session ----------------------------- #
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        saver2 = tf.train.Saver(tf.global_variables())
        if not train_all_layers:
            saver_net = tf.train.Saver(variables_to_restore)
            saver_net.restore(sess, checkpoint_path)

        if fine_tune:
            latest = tf.train.latest_checkpoint(train_dir)
            if not latest:
                print("No checkpoint to continue from in", train_dir)
                sys.exit(1)
            print("resume", latest)
            saver2.restore(sess, latest)

        # Loop-invariant batch counts (partial trailing batches are dropped).
        train_batches = int(train_n / batch_size)
        valid_batches = int(valid_n / batch_size)

        # early stopping
        best_valid = np.inf
        best_valid_epoch = 0

        for epoch_i in range(epoch):
            for batch_i in range(train_batches):
                train_batch = get_next_batch_from_path3(
                    train_data,
                    train_label,
                    batch_i,
                    height,
                    width,
                    batch_size=batch_size,
                    training=True)
                los, _ = sess.run([loss, optimizer],
                                  feed_dict=make_feed(train_batch, True))
                print(los)
                if batch_i % 20 == 0:
                    # Re-evaluate the same batch in inference mode.
                    loss_, acc_ = sess.run(
                        [loss, accuracy],
                        feed_dict=make_feed(train_batch, False))
                    print(
                        'Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'
                        .format(batch_i, loss_, acc_))

                if batch_i % 100 == 0:
                    # Spot-check a single validation batch, cycling through
                    # the validation set as training progresses.
                    valid_batch = get_next_batch_from_path3(
                        valid_data,
                        valid_label,
                        batch_i % valid_batches,
                        height,
                        width,
                        batch_size=batch_size,
                        training=False)
                    ls, acc = sess.run(
                        [loss, accuracy],
                        feed_dict=make_feed(valid_batch, False))
                    print(
                        'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                        .format(batch_i, ls, acc))

            print(
                'Epoch===================================>: {:>2}'.format(
                    epoch_i))

            # Full pass over the validation set.
            valid_ls = 0
            valid_acc = 0
            for batch_i in range(valid_batches):
                valid_batch = get_next_batch_from_path3(
                    valid_data,
                    valid_label,
                    batch_i,
                    height,
                    width,
                    batch_size=batch_size,
                    training=False)
                epoch_ls, epoch_acc = sess.run(
                    [loss, accuracy],
                    feed_dict=make_feed(valid_batch, False))
                valid_ls = valid_ls + epoch_ls
                valid_acc = valid_acc + epoch_acc
            mean_valid_loss = valid_ls / valid_batches
            mean_valid_acc = valid_acc / valid_batches
            print(
                'Epoch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                .format(epoch_i, mean_valid_loss, mean_valid_acc))

            if mean_valid_acc > 0.90:
                # Separate local name so the checkpoint_path argument is
                # not overwritten.
                save_path = os.path.join(train_dir, 'model.ckpt')
                saver2.save(sess,
                            save_path,
                            global_step=epoch_i,
                            write_meta_graph=True)

            if early_stop:
                if mean_valid_loss < best_valid:
                    best_valid = mean_valid_loss
                    best_valid_epoch = epoch_i
                elif best_valid_epoch + EARLY_STOP_PATIENCE < epoch_i:
                    print("Early stopping.")
                    print("Best valid loss was {:.6f} at epoch {}.".format(
                        best_valid, best_valid_epoch))
                    break
            train_data, train_label = shuffle_train_data(
                train_data, train_label)
    finally:
        # Release the session on every exit path, including sys.exit()
        # and exceptions raised while restoring or training.
        sess.close()
示例#2
0
def train(train_data, train_label, valid_data, valid_label, train_n, valid_n,
          train_dir, num_classes, batch_size, arch_model, learning_r_decay,
          learning_rate_base, decay_rate, dropout_prob, epoch, height, width,
          checkpoint_exclude_scopes, early_stop, EARLY_STOP_PATIENCE,
          fine_tune, train_all_layers, checkpoint_path, g_parameter,
          decay_steps=1000):
    """Train the network on graph-side inputs driven by queue runners.

    The training and validation networks are built directly on
    ``train_data`` / ``valid_data`` (no placeholders or feed_dict are
    used), sharing variables through ``tf.AUTO_REUSE``. The pretrained
    checkpoint is always restored, queue runners are started, and
    ``epoch * int(train_n / batch_size)`` optimizer steps are run. Every
    500 batches the validation metrics are printed and a checkpoint is
    written to ``train_dir``.

    Args:
        train_data, train_label: Training inputs and labels passed to
            ``build_net`` / ``cost`` (presumably tensors produced by an
            input queue -- confirm against the caller).
        valid_data, valid_label: Validation inputs and labels, used the
            same way for the validation branch of the graph.
        train_n: Number of training samples; sets the steps per epoch.
        valid_n: Accepted for interface compatibility; not used here.
        train_dir: Directory checkpoints are saved to and, when
            ``fine_tune`` is set, resumed from.
        num_classes: Forwarded to ``build_net`` and ``model_accuracy``.
        batch_size: Batch size, used for steps per epoch and for the
            learning-rate schedule.
        arch_model: Forwarded to ``build_net``.
        learning_r_decay: If true, use a staircase exponential decay.
        learning_rate_base: Initial (or constant) learning rate.
        decay_rate: Decay factor for the exponential schedule.
        dropout_prob: Passed to ``build_net`` for the training network;
            the validation network receives 1.0.
        epoch: Number of epochs to run.
        height, width: Accepted for interface compatibility; not used.
        checkpoint_exclude_scopes: Forwarded to ``g_parameter``.
        early_stop, EARLY_STOP_PATIENCE: Accepted for interface
            compatibility; early stopping is not implemented here.
        fine_tune: If true, resume from the latest checkpoint in
            ``train_dir``; exits the process with status 1 if none exists.
        train_all_layers: If true, ``train_op`` receives an empty
            variable list (presumably meaning "all trainable variables"
            -- confirm against ``train_op``).
        checkpoint_path: Pretrained checkpoint restored before training.
        g_parameter: Callable returning
            ``(variables_to_restore, variables_to_train)``.
        decay_steps: Decay period of the learning-rate schedule, measured
            in ``global_step * batch_size`` units. Defaults to 1000, the
            value that was previously hard-coded.
    """
    # ------------------------------ train ------------------------------ #
    net, _ = build_net(train_data, num_classes, dropout_prob, True, arch_model)
    variables_to_restore, variables_to_train = g_parameter(
        checkpoint_exclude_scopes)
    loss = cost(train_label, net)
    global_step = tf.Variable(0, trainable=False)
    if learning_r_decay:
        learning_rate = tf.train.exponential_decay(
            learning_rate_base,
            global_step * batch_size,
            decay_steps,  # how often the rate is decayed
            decay_rate,
            staircase=True)
    else:
        learning_rate = learning_rate_base
    if train_all_layers:
        variables_to_train = []
    optimizer = train_op(learning_rate, loss, variables_to_train, global_step)
    accuracy = model_accuracy(net, train_label, num_classes)

    # ------------------------------ valid ------------------------------ #
    # Reuse the training variables; dropout keep probability is 1.0.
    with tf.variable_scope("", reuse=tf.AUTO_REUSE):
        valid_net, _ = build_net(valid_data, num_classes, 1.0, False,
                                 arch_model)
    valid_loss = cost(valid_label, valid_net)
    valid_accuracy = model_accuracy(valid_net, valid_label, num_classes)

    # ----------------------------- session ----------------------------- #
    # InteractiveSession installs itself as the default session, which the
    # ``.run()`` calls on the initializer ops below rely on.
    sess = tf.InteractiveSession()
    coord = tf.train.Coordinator()
    threads = []
    try:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()
        saver2 = tf.train.Saver(tf.global_variables())
        saver_net = tf.train.Saver(variables_to_restore)
        saver_net.restore(sess, checkpoint_path)

        if fine_tune:
            latest = tf.train.latest_checkpoint(train_dir)
            if not latest:
                print("No checkpoint to continue from in", train_dir)
                sys.exit(1)
            print("resume", latest)
            saver2.restore(sess, latest)

        # Start the queue runners under a coordinator so the threads can
        # be stopped and joined before the session is closed.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        steps_per_epoch = int(train_n / batch_size)
        for epoch_i in range(epoch):
            for batch_i in range(steps_per_epoch):
                los, _ = sess.run([loss, optimizer])
                if batch_i % 100 == 0:
                    loss_, acc_ = sess.run([loss, accuracy])
                    print(
                        'Batch: {:>2}: Training loss: {:>3.5f}, Training accuracy: {:>3.5f}'
                        .format(batch_i, loss_, acc_))

                if batch_i % 500 == 0:
                    ls, acc = sess.run([valid_loss, valid_accuracy])
                    print(
                        'Batch: {:>2}: Validation loss: {:>3.5f}, Validation accuracy: {:>3.5f}'
                        .format(batch_i, ls, acc))
                    # Separate local name so the checkpoint_path argument
                    # is not overwritten.
                    save_path = os.path.join(train_dir, 'model.ckpt')
                    saver2.save(sess,
                                save_path,
                                global_step=epoch_i,
                                write_meta_graph=False)
    finally:
        # Stop the input threads first, then release the session, on every
        # exit path (normal completion, sys.exit(), or an exception).
        try:
            coord.request_stop()
            coord.join(threads)
        finally:
            sess.close()