コード例 #1
0
def main(_):
    """Train using the data, log-directory and hyper-parameter settings in FLAGS.

    Reads ``train.tfrecords``, ``val.tfrecords`` and ``meta.json`` from
    ``FLAGS.data_dir``. Derives a per-configuration log directory under
    ``FLAGS.train_logdir``, prints it, and hands everything to ``_train``.

    Args:
        _: Unused positional argument (presumably the argv list supplied by
            the application runner -- TODO confirm).
    """
    path_to_train_tfrecords_file = os.path.join(FLAGS.data_dir,
                                                'train.tfrecords')
    path_to_val_tfrecords_file = os.path.join(FLAGS.data_dir, 'val.tfrecords')
    path_to_tfrecords_meta_file = os.path.join(FLAGS.data_dir, 'meta.json')
    # The log sub-directory name encodes the SSIM weight, the defended layer
    # and the attacker type, so each configuration gets its own directory.
    path_to_train_log_dir = os.path.join(
        FLAGS.train_logdir,
        "ssim_{:.2f}-defend_{}-attacker_{}".format(FLAGS.ssim_weight,
                                                   FLAGS.defend_layer,
                                                   FLAGS.attacker_type))
    print("log path: {}".format(path_to_train_log_dir))
    path_to_restore_model_checkpoint_file = FLAGS.restore_checkpoint
    training_options = {
        'batch_size': FLAGS.batch_size,
        'learning_rate': FLAGS.learning_rate,
        'patience': FLAGS.patience,
        'decay_steps': FLAGS.decay_steps,
        'decay_rate': FLAGS.decay_rate
    }

    # Example counts come from the meta file rather than from the records.
    meta = Meta()
    meta.load(path_to_tfrecords_meta_file)

    _train(path_to_train_tfrecords_file, meta.num_train_examples,
           path_to_val_tfrecords_file, meta.num_val_examples,
           path_to_train_log_dir, path_to_restore_model_checkpoint_file,
           training_options)
コード例 #2
0
def main(_):
    """Evaluate the checkpoints in FLAGS.checkpoint_dir on generated records.

    The records file is ``generated.tfrecords`` in ``FLAGS.data_dir``. The
    example count passed to ``_eval`` is the ``num_test_examples`` value
    loaded from ``meta.json`` in the same directory. The log directory is
    the ``test`` sub-directory of ``FLAGS.eval_logdir``. The literal
    ``'result_generated.txt'`` is forwarded to ``_eval`` as its save-file
    argument.

    Args:
        _: Unused positional argument.
    """
    data_dir = FLAGS.data_dir
    records_path = os.path.join(data_dir, 'generated.tfrecords')
    meta_path = os.path.join(data_dir, 'meta.json')
    checkpoint_dir = FLAGS.checkpoint_dir
    result_file = 'result_generated.txt'
    log_dir = os.path.join(FLAGS.eval_logdir, 'test')

    meta = Meta()
    meta.load(meta_path)

    _eval(checkpoint_dir,
          records_path,
          meta.num_test_examples,
          log_dir,
          result_file)
コード例 #3
0
def main(_):
    """Start training from FLAGS, with the validation split for monitoring.

    Reads ``train.tfrecords``, ``val.tfrecords`` and ``tfrecords_meta.json``
    from ``FLAGS.data_dir``. Only the validation example count from the meta
    file is forwarded to this variant of ``_train``.

    Args:
        _: Unused positional argument.
    """
    data_dir = FLAGS.data_dir
    train_records = os.path.join(data_dir, 'train.tfrecords')
    val_records = os.path.join(data_dir, 'val.tfrecords')
    meta_json = os.path.join(data_dir, 'tfrecords_meta.json')
    log_dir = FLAGS.train_logdir
    checkpoint = FLAGS.restore_checkpoint

    meta = Meta()
    meta.load(meta_json)

    _train(train_records, val_records, meta.num_val_examples,
           log_dir, checkpoint)
コード例 #4
0
ファイル: eval.py プロジェクト: gjmulder/meter-pop
def main(_):
    """Evaluate FLAGS.checkpoint_dir on the train, val and test splits.

    For each split, in the order train, val, test, ``_eval`` receives:

    * the split's ``.tfrecords`` file from ``FLAGS.data_dir``;
    * the split's example count from ``meta.json`` in that directory;
    * the split-named sub-directory of ``FLAGS.eval_logdir`` as its log dir.

    Args:
        _: Unused positional argument.
    """
    data_dir = FLAGS.data_dir
    meta_file = os.path.join(data_dir, 'meta.json')
    checkpoint_dir = FLAGS.checkpoint_dir
    eval_logdir = FLAGS.eval_logdir

    meta = Meta()
    meta.load(meta_file)

    # (records file name, example count, log sub-directory) for each split.
    jobs = [
        ('train.tfrecords', meta.num_train_examples, 'train'),
        ('val.tfrecords', meta.num_val_examples, 'val'),
        ('test.tfrecords', meta.num_test_examples, 'test'),
    ]
    for records_name, num_examples, log_subdir in jobs:
        _eval(checkpoint_dir,
              os.path.join(data_dir, records_name),
              num_examples,
              os.path.join(eval_logdir, log_subdir))
コード例 #5
0
def main_eval(_):
    """Evaluate a checkpoint directory on the train, val and test splits.

    Command-line options are parsed by ``argparse`` from ``sys.argv``; the
    positional parameter is ignored.

    The per-split paths are optional. When one is omitted, it is derived from
    a base directory:

    * from ``--data_dir``: ``train.tfrecords``, ``val.tfrecords``,
      ``test.tfrecords`` and ``meta.json``;
    * from ``--eval_logdir``: the ``train``, ``val`` and ``test``
      sub-directories.

    Explicitly supplied values are used unchanged, so existing command lines
    behave exactly as before.

    Args:
        _: Unused positional argument.
    """
    import os

    parser = argparse.ArgumentParser(
        description="Evaluation Routine for SVHNClassifier")
    parser.add_argument("--data_dir",
                        required=True,
                        help="Directory to read TFRecords files")
    parser.add_argument("--path_to_checkpoint_dir",
                        required=True,
                        help="Directory to read checkpoint files")
    parser.add_argument("--eval_logdir",
                        required=True,
                        help="Directory to write evaluation logs")
    parser.add_argument("--path_to_train_tfrecords_file",
                        help="Tfrecords file in train directory "
                        "(default: <data_dir>/train.tfrecords)")
    parser.add_argument("--path_to_val_tfrecords_file",
                        help="Tfrecords file in val directory "
                        "(default: <data_dir>/val.tfrecords)")
    parser.add_argument("--path_to_test_tfrecords_file",
                        help="Tfrecords file in test directory "
                        "(default: <data_dir>/test.tfrecords)")
    parser.add_argument("--path_to_tfrecords_meta_file",
                        help="Meta file in directory "
                        "(default: <data_dir>/meta.json)")
    parser.add_argument("--path_to_train_eval_log_dir",
                        help="Training and evaluating log directory "
                        "(default: <eval_logdir>/train)")
    parser.add_argument("--path_to_val_eval_log_dir",
                        help="Validating and evaluating log directory "
                        "(default: <eval_logdir>/val)")
    parser.add_argument("--path_to_test_eval_log_dir",
                        help="Testing and evaluating log directory "
                        "(default: <eval_logdir>/test)")
    args = parser.parse_args()

    # Derive any path that was not passed explicitly from the two base
    # directories; an explicitly given value always takes precedence.
    derived_defaults = (
        ('path_to_train_tfrecords_file', args.data_dir, 'train.tfrecords'),
        ('path_to_val_tfrecords_file', args.data_dir, 'val.tfrecords'),
        ('path_to_test_tfrecords_file', args.data_dir, 'test.tfrecords'),
        ('path_to_tfrecords_meta_file', args.data_dir, 'meta.json'),
        ('path_to_train_eval_log_dir', args.eval_logdir, 'train'),
        ('path_to_val_eval_log_dir', args.eval_logdir, 'val'),
        ('path_to_test_eval_log_dir', args.eval_logdir, 'test'),
    )
    for attr_name, base_dir, leaf in derived_defaults:
        if getattr(args, attr_name) is None:
            setattr(args, attr_name, os.path.join(base_dir, leaf))

    meta = Meta()
    meta.load(args.path_to_tfrecords_meta_file)

    _eval(args.path_to_checkpoint_dir, args.path_to_train_tfrecords_file,
          meta.num_train_examples, args.path_to_train_eval_log_dir)
    _eval(args.path_to_checkpoint_dir, args.path_to_val_tfrecords_file,
          meta.num_val_examples, args.path_to_val_eval_log_dir)
    _eval(args.path_to_checkpoint_dir, args.path_to_test_tfrecords_file,
          meta.num_test_examples, args.path_to_test_eval_log_dir)
コード例 #6
0
def main(_):
    """Train with the epoch-count schedule configured through FLAGS.

    Reads ``train.tfrecords``, ``val.tfrecords`` and ``meta.json`` from
    ``FLAGS.data_dir``. The hyper-parameters batch size, learning rate,
    ``epoches``, decay steps and decay rate are copied from FLAGS into the
    options dict passed to ``_train``.

    Args:
        _: Unused positional argument.
    """
    data_dir = FLAGS.data_dir
    train_records, val_records, meta_json = (
        os.path.join(data_dir, filename)
        for filename in ('train.tfrecords', 'val.tfrecords', 'meta.json'))
    log_dir = FLAGS.train_logdir
    checkpoint = FLAGS.restore_checkpoint
    option_names = ('batch_size', 'learning_rate', 'epoches',
                    'decay_steps', 'decay_rate')
    training_options = {name: getattr(FLAGS, name) for name in option_names}

    meta = Meta()
    meta.load(meta_json)

    _train(train_records, meta.num_train_examples,
           val_records, meta.num_val_examples,
           log_dir, checkpoint, training_options)
コード例 #7
0
ファイル: train.py プロジェクト: Ankita-Das/SVHNClassifier
def main(_):
    """Train with the patience-based schedule configured through FLAGS.

    Reads ``train.tfrecords``, ``val.tfrecords`` and ``meta.json`` from
    ``FLAGS.data_dir``. The train/val example counts come from the meta
    file. Batch size, learning rate, patience, decay steps and decay rate
    are copied from FLAGS into the options dict passed to ``_train``.

    Args:
        _: Unused positional argument.
    """
    root = FLAGS.data_dir
    train_file = os.path.join(root, 'train.tfrecords')
    val_file = os.path.join(root, 'val.tfrecords')
    meta_file = os.path.join(root, 'meta.json')
    logdir = FLAGS.train_logdir
    restore_from = FLAGS.restore_checkpoint

    hyperparams = ('batch_size', 'learning_rate', 'patience',
                   'decay_steps', 'decay_rate')
    options = {}
    for key in hyperparams:
        options[key] = getattr(FLAGS, key)

    meta = Meta()
    meta.load(meta_file)

    _train(train_file, meta.num_train_examples,
           val_file, meta.num_val_examples,
           logdir, restore_from, options)
コード例 #8
0
def main(_):
    """Train on SVHN plus MNIST-converted-to-SVHN records, configured via FLAGS.

    Two data sources are combined:

    * the SVHN ``train``/``val`` TFRecords and ``meta.json`` in
      ``FLAGS.data_dir``;
    * the ``MNIST_Converted_*`` files in the ``MNIST2SVHN`` directory that
      sits next to ``FLAGS.data_dir``.

    The train/val example counts handed to ``_train`` are the sums over both
    meta files.

    Args:
        _: Unused positional argument.
    """
    # normpath drops a trailing separator, so "data/svhn/" and "data/svhn"
    # both yield "data" as the parent holding the MNIST2SVHN directory.
    parent_dir = os.path.dirname(os.path.normpath(FLAGS.data_dir))

    path_to_train_tfrecords_files = [
        os.path.join(FLAGS.data_dir, 'train.tfrecords'),
        os.path.join(parent_dir,
                     'MNIST2SVHN/MNIST_Converted_train.tfrecords')
    ]
    path_to_val_tfrecords_files = [
        os.path.join(FLAGS.data_dir, 'val.tfrecords'),
        os.path.join(parent_dir, 'MNIST2SVHN/MNIST_Converted_val.tfrecords')
    ]
    path_to_tfrecords_meta_files = [
        os.path.join(FLAGS.data_dir, 'meta.json'),
        os.path.join(parent_dir, 'MNIST2SVHN/MNIST_Converted_meta.json')
    ]
    path_to_train_log_dir = FLAGS.train_logdir
    path_to_restore_checkpoint_file = FLAGS.restore_checkpoint
    training_options = {
        'batch_size': FLAGS.batch_size,
        'learning_rate': FLAGS.learning_rate,
        'epoches': FLAGS.epoches,
        'decay_steps': FLAGS.decay_steps,
        'decay_rate': FLAGS.decay_rate
    }

    # Load each meta file into its own fresh Meta and add the counts up, so
    # the totals cover both datasets whether Meta.load replaces or adds to
    # its counts. NOTE(review): assumes a fresh Meta starts with no counts.
    metas = []
    for path_to_tfrecords_meta_file in path_to_tfrecords_meta_files:
        meta = Meta()
        meta.load(path_to_tfrecords_meta_file)
        metas.append(meta)
    num_train_examples = sum(m.num_train_examples for m in metas)
    num_val_examples = sum(m.num_val_examples for m in metas)

    _train(path_to_train_tfrecords_files, num_train_examples,
           path_to_val_tfrecords_files, num_val_examples,
           path_to_train_log_dir, path_to_restore_checkpoint_file,
           training_options)
コード例 #9
0
def main(unused_argv):
    """Train with paths and patience-based hyper-parameters taken from FLAGS.

    The original inline notes give example values:

    * ``FLAGS.data_dir``: ``./data``, containing ``train.tfrecords``,
      ``val.tfrecords`` and ``meta.json``;
    * ``FLAGS.train_logdir``: ``./logs/train``;
    * ``FLAGS.restore_checkpoint``: ``None``, or a checkpoint path such as
      ``./logs/train/latest.ckpt``.

    Args:
        unused_argv: Ignored.
    """
    def in_data_dir(filename):
        # Resolve a file name relative to FLAGS.data_dir.
        return os.path.join(FLAGS.data_dir, filename)

    train_records = in_data_dir('train.tfrecords')
    val_records = in_data_dir('val.tfrecords')
    meta_json = in_data_dir('meta.json')
    log_dir = FLAGS.train_logdir
    restore_from = FLAGS.restore_checkpoint

    options = dict()
    options['batch_size'] = FLAGS.batch_size
    options['learning_rate'] = FLAGS.learning_rate
    options['patience'] = FLAGS.patience
    options['decay_steps'] = FLAGS.decay_steps
    options['decay_rate'] = FLAGS.decay_rate

    meta = Meta()
    meta.load(meta_json)

    _train(train_records, meta.num_train_examples,
           val_records, meta.num_val_examples,
           log_dir, restore_from,
           options)
コード例 #10
0
ファイル: train.py プロジェクト: kdg1016/SVHNClassifier
def main_train(_):
    """Parse command-line options with argparse and launch ``_train``.

    Options are parsed from ``sys.argv``; the positional parameter is ignored.
    The hyper-parameter options are optional and fall back to the defaults
    stated in their help text:

    * batch size 32;
    * learning rate 1e-2;
    * patience 100;
    * decay steps 10000;
    * decay rate 0.9.

    Explicitly passed values behave exactly as before.

    Args:
        _: Unused positional argument.
    """
    parser = argparse.ArgumentParser(
        description="Training Routine for SVHNClassifier")
    # NOTE(review): --data_dir is required but is not read anywhere in this
    # function; the TFRecords/meta paths are passed separately below.
    parser.add_argument("--data_dir",
                        required=True,
                        help="Path to SVHN (format 1) folders")
    parser.add_argument("--path_to_train_log_dir",
                        required=True,
                        help="Directory to write training logs")
    parser.add_argument(
        "--path_to_restore_checkpoint_file",
        required=False,
        help=
        "Path to restore checkpoint (without postfix), e.g. ./logs/train/model.ckpt-100"
    )
    parser.add_argument("--path_to_train_tfrecords_file",
                        required=True,
                        help="Tfrecords file in train directory")
    parser.add_argument("--path_to_val_tfrecords_file",
                        required=True,
                        help="Tfrecords file in val directory")
    parser.add_argument("--path_to_tfrecords_meta_file",
                        required=True,
                        help="Meta file in directory")

    # These used to be required=True with no default, contradicting their
    # "Default ..." help text; the advertised defaults are now real.
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Default 32")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=1e-2,
                        help="Default 1e-2")
    parser.add_argument("--patience",
                        type=int,
                        default=100,
                        help="Default 100, set -1 to train infinitely")
    parser.add_argument("--decay_steps",
                        type=int,
                        default=10000,
                        help="Default 10000")
    parser.add_argument("--decay_rate",
                        type=float,
                        default=0.9,
                        help="Default 0.9")
    args = parser.parse_args()

    training_options = {
        'batch_size': args.batch_size,
        'learning_rate': args.learning_rate,
        'patience': args.patience,
        'decay_steps': args.decay_steps,
        'decay_rate': args.decay_rate
    }

    meta = Meta()
    meta.load(args.path_to_tfrecords_meta_file)

    _train(args.path_to_train_tfrecords_file, meta.num_train_examples,
           args.path_to_val_tfrecords_file, meta.num_val_examples,
           args.path_to_train_log_dir, args.path_to_restore_checkpoint_file,
           training_options)