fine_tune(sess)
    # Epoch loop: one training pass and one validation pass per epoch; on a
    # new (or tied) best validation accuracy a checkpoint is saved.
    # The enclosing function and the rest of the `if` body are outside this view.

    # Resume from args.continue_from_epoch unless it is -1, in which case start at epoch 0.
    start_epoch = args.continue_from_epoch if args.continue_from_epoch != -1 else 0
    # Best validation accuracy mean seen so far, and the epoch where it occurred.
    # NOTE(review): best_val_acc_mean restarts at 0. even when start_epoch > 0,
    # so the first epoch after a resume always counts as a new best.
    best_val_acc_mean = 0.
    # NOTE(review): the initial value 6 is not explained by the visible code;
    # it is overwritten on the first epoch, since any non-negative accuracy
    # satisfies the `>=` test below — confirm intent.
    best_val_epoch = 6
    # NOTE(review): the bar total is args.total_epochs while the loop below runs
    # only from start_epoch; pbar_e is not updated within the lines visible here.
    with tqdm.tqdm(total=args.total_epochs) as pbar_e:
        for e in range(start_epoch, args.total_epochs):
            # Training pass; returns loss mean/std and accuracy mean/std for the epoch.
            total_train_c_loss_mean, total_train_c_loss_std, total_train_accuracy_mean, total_train_accuracy_std =\
                experiment.run_training_epoch(total_train_batches=total_train_batches,
                                                                                sess=sess)
            print(
                "Epoch {}: train_loss_mean: {}, train_loss_std: {}, train_accuracy_mean: {}, train_accuracy_std: {}"
                .format(e, total_train_c_loss_mean, total_train_c_loss_std,
                        total_train_accuracy_mean, total_train_accuracy_std))

            # Validation pass; same four statistics as the training pass.
            total_val_c_loss_mean, total_val_c_loss_std, total_val_accuracy_mean, total_val_accuracy_std = \
                experiment.run_validation_epoch(total_val_batches=total_val_batches,
                                              sess=sess)
            print(
                "Epoch {}: val_loss_mean: {}, val_loss_std: {}, val_accuracy_mean: {}, val_accuracy_std: {}"
                .format(e, total_val_c_loss_mean, total_val_c_loss_std,
                        total_val_accuracy_mean, total_val_accuracy_std))

            # `>=` means a tie with the current best is also treated as a new best.
            if total_val_accuracy_mean >= best_val_acc_mean:  #if new best val accuracy -> produce test statistics
                best_val_acc_mean = total_val_accuracy_mean
                best_val_epoch = e

                # Save the session through val_saver; the file name embeds the
                # experiment title and the epoch number, so each best epoch
                # gets its own checkpoint path.
                val_save_path = val_saver.save(
                    sess,
                    "{}/best_val_{}_{}.ckpt".format(saved_models_filepath,
                                                    args.experiment_title, e))

                total_test_c_loss_mean, total_test_c_loss_std, total_test_accuracy_mean, total_test_accuracy_std \
Exemplo n.º 2
0
                                                   variables_to_restore,
                                                   ignore_missing_vars=True)
        fine_tune(sess)

    # Epoch loop: one training pass and one validation pass per epoch; the test
    # pass runs only when validation accuracy matches or beats the best so far.
    # The enclosing function is outside this view.

    # Best validation accuracy seen so far.
    best_val = 0.
    # NOTE(review): pbar_e is not updated within the lines visible here.
    with tqdm.tqdm(total=total_epochs) as pbar_e:
        for e in range(0, total_epochs):
            # Training pass; returns the epoch's loss and accuracy.
            total_c_loss, total_accuracy = experiment.run_training_epoch(
                total_train_batches=total_train_batches, sess=sess)
            # tf.summary.scalar("loss_train", total_c_loss)
            # tf.summary.scalar("acc_train", total_accuracy)

            print("Epoch {}: train_loss: {}, train_accuracy: {}".format(
                e, total_c_loss, total_accuracy))

            # Validation pass; returns the epoch's loss and accuracy.
            total_val_c_loss, total_val_accuracy = experiment.run_validation_epoch(
                total_val_batches=total_val_batches, sess=sess)
            # NOTE(review): if the disabled summaries below are re-enabled, the
            # second one reuses the "loss_val" tag for the accuracy value.
            # tf.summary.scalar("loss_val", total_val_c_loss)
            # tf.summary.scalar("loss_val", total_val_accuracy)

            print("Epoch {}: val_loss: {}, val_accuracy: {}".format(
                e, total_val_c_loss, total_val_accuracy))

            # `>=` means a tie with the current best also triggers a test pass.
            if total_val_accuracy >= best_val:  #if new best val accuracy -> produce test statistics
                best_val = total_val_accuracy
                total_test_c_loss, total_test_accuracy = experiment.run_testing_epoch(
                    total_test_batches=total_test_batches, sess=sess)
                print("Epoch {}: test_loss: {}, test_accuracy: {}".format(
                    e, total_test_c_loss, total_test_accuracy))
            else:
                # No test pass this epoch; both test metrics are set to -1.
                # Presumably a "not evaluated" placeholder for code after this
                # point — verify against the consumer, which is not visible here.
                total_test_c_loss = -1
                total_test_accuracy = -1