Example #1
0
def main():
    """Build ImageNet train/valid datasets and run ``eval_model`` on them.

    The results folder ``FLAGS.results/FLAGS.id`` is resolved to an absolute,
    symlink-free path but is NOT created here.  When ``FLAGS.logs`` is set,
    ``FLAGS.logs/FLAGS.id`` is resolved the same way and created if missing;
    otherwise ``None`` is passed as the logs folder.
    """
    config = get_config()
    experiment = FLAGS.id

    def _resolve(root):
        # Absolute, symlink-resolved path of <root>/<experiment>.
        return os.path.realpath(
            os.path.abspath(os.path.join(root, experiment)))

    save_folder = _resolve(FLAGS.results)

    logs_folder = None
    if FLAGS.logs is not None:
        logs_folder = _resolve(FLAGS.logs)
        if not os.path.exists(logs_folder):
            os.makedirs(logs_folder)

    log.info("Building dataset")
    # Both splits share the same loader options: no cycling, no augmentation,
    # and the config's validation batch size and preprocessor.
    shared_opts = dict(cycle=False,
                       data_aug=False,
                       batch_size=config.valid_batch_size,
                       preprocessor=config.preprocessor)
    # The train split is additionally limited via num_batches=100.
    train_data = get_dataset("imagenet", "train", num_batches=100,
                             **shared_opts)
    test_data = get_dataset("imagenet", "valid", **shared_opts)

    eval_model(config, train_data, test_data, save_folder, logs_folder)
Example #2
0
def main():
  """Train a model on ``FLAGS.dataset`` and log the final test accuracy.

  Split names: "traintrain"/"trainval" when ``FLAGS.validation`` is set,
  otherwise "train"/"test".  The experiment id is ``FLAGS.id``, or an id
  produced by ``gen_id`` from ``"exp_<dataset>_<model>"`` when it is None.
  For each of ``FLAGS.results`` and ``FLAGS.logs`` that is set, the folder
  ``<root>/<exp_id>`` is resolved to a real absolute path and created if
  missing; an unset root yields ``None``.
  """
  config = _get_config()

  # Choose the train/test split names for this run.
  if not FLAGS.validation:
    train_str, test_str = "train", "test"
  else:
    train_str, test_str = "traintrain", "trainval"
    log.warning("Running validation set")

  exp_id = FLAGS.id
  if exp_id is None:
    exp_id = gen_id("exp_" + FLAGS.dataset + "_" + FLAGS.model)

  def _prepare(root):
    # <root>/<exp_id> as a real absolute path, created on demand;
    # None when the root flag is unset.
    if root is None:
      return None
    folder = os.path.realpath(os.path.abspath(os.path.join(root, exp_id)))
    if not os.path.exists(folder):
      os.makedirs(folder)
    return folder

  save_folder = _prepare(FLAGS.results)
  logs_folder = _prepare(FLAGS.logs)

  log.info("Building dataset")
  # Options shared by the two non-training loaders.
  eval_opts = dict(data_aug=False, cycle=False, prefetch=False)
  train_data = get_dataset(FLAGS.dataset, train_str)
  trainval_data = get_dataset(FLAGS.dataset, train_str, num_batches=100,
                              **eval_opts)
  test_data = get_dataset(FLAGS.dataset, test_str, **eval_opts)

  acc = train_model(exp_id, config, train_data, test_data, trainval_data,
                    save_folder=save_folder, logs_folder=logs_folder)
  log.info("Final test accuracy = {:.3f}".format(acc * 100))
Example #3
0
def main():
    """Run the adversarial-example workflow selected by ``FLAGS.mode``.

    Modes (compared case-insensitively):
      * ``eval``     -- calls ``only_adv_eval`` with the train and test data.
      * ``save``     -- calls ``gen_and_save_adv_examples`` with the test data.
      * ``transfer`` -- loads examples from ``FLAGS.results/FLAGS.bbox_id``
        with ``load_adv_examples`` and passes them to
        ``transfer_adv_examples``.

    The results folder ``FLAGS.results/FLAGS.id`` is resolved but not
    created; the logs folder ``FLAGS.logs/FLAGS.id`` is created when
    ``FLAGS.logs`` is set, otherwise ``None`` is passed on.

    Raises:
        ValueError: if ``FLAGS.mode`` is not one of the modes above, or if
            the mode is ``transfer`` and ``FLAGS.bbox_id`` is None.
    """
    config = _get_config()
    exp_id = FLAGS.id

    # Validate the mode up front: previously an unknown mode built the
    # datasets and then silently did nothing.
    mode = FLAGS.mode.lower()
    if mode not in ("eval", "save", "transfer"):
        raise ValueError(
            "Unknown mode {!r}; expected 'eval', 'save' or 'transfer'".format(
                FLAGS.mode))
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if mode == "transfer" and FLAGS.bbox_id is None:
        raise ValueError("FLAGS.bbox_id must be set when mode is 'transfer'")

    save_folder = os.path.realpath(
        os.path.abspath(os.path.join(FLAGS.results, exp_id)))

    if FLAGS.logs is not None:
        logs_folder = os.path.realpath(
            os.path.abspath(os.path.join(FLAGS.logs, exp_id)))
        if not os.path.exists(logs_folder):
            os.makedirs(logs_folder)
    else:
        logs_folder = None

    # Configures dataset objects.
    log.info("Building dataset")
    train_data = get_dataset(FLAGS.dataset,
                             "train",
                             cycle=False,
                             data_aug=False,
                             prefetch=False)
    test_data = get_dataset(FLAGS.dataset,
                            "test",
                            cycle=False,
                            data_aug=False,
                            prefetch=False)

    if mode == "eval":
        only_adv_eval(config, train_data, test_data, save_folder, logs_folder)
    elif mode == "save":
        gen_and_save_adv_examples(config, test_data, save_folder, logs_folder)
    else:  # mode == "transfer", validated above.
        bbox_save_folder = os.path.realpath(
            os.path.abspath(os.path.join(FLAGS.results, FLAGS.bbox_id)))
        adv_examples = load_adv_examples(bbox_save_folder, FLAGS.fgm_eps,
                                         FLAGS.fgm_norm, FLAGS.targeted)
        transfer_adv_examples(config, adv_examples, save_folder, logs_folder)
Example #4
0
def main():
    """Train a model on the "train" split of the module-level ``DATASET``.

    The experiment id is ``FLAGS.id`` or, when that is None, an id produced
    by ``gen_id`` from ``"exp_<DATASET>_<model>"``.  For each of
    ``FLAGS.results`` and ``FLAGS.logs`` that is set, ``<root>/<exp_id>`` is
    resolved to a real absolute path and created if missing; an unset root
    yields ``None``.
    """
    config = _get_config()

    if FLAGS.id is not None:
        exp_id = FLAGS.id
    else:
        exp_id = gen_id("exp_" + DATASET + "_" + FLAGS.model)

    def _experiment_dir(root):
        # Real absolute <root>/<exp_id>, created on demand; None if unset.
        if root is None:
            return None
        path = os.path.realpath(os.path.abspath(os.path.join(root, exp_id)))
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    save_folder = _experiment_dir(FLAGS.results)
    logs_folder = _experiment_dir(FLAGS.logs)

    log.info("Building dataset")
    train_data = get_dataset(DATASET, "train",
                             batch_size=config.batch_size,
                             preprocessor=config.preprocessor)

    train_model(exp_id, config, train_data,
                save_folder=save_folder, logs_folder=logs_folder)
def main():
    """Train a model on CIFAR-100 and log the final test accuracy.

    Split names: "traintrain"/"trainval" when ``FLAGS.validation`` is set,
    otherwise "train"/"test".  When ``FLAGS.id`` is None a new id is
    generated from ``"exp_<dataset>_<model>"``; otherwise the dataset name is
    read from the second "_"-separated field of ``FLAGS.id``.  Results and
    logs folders (``<root>/<exp_id>``) are created when their root flag is
    set, else ``None`` is passed on.

    Raises:
        ValueError: if ``FLAGS.dataset`` is not ``'cifar-100'``.
    """
    # Loads parameters.
    config = _get_config()
    # Explicit check instead of ``assert`` so the guard still runs under -O.
    if FLAGS.dataset != 'cifar-100':
        raise ValueError(
            "Only 'cifar-100' is supported, got {!r}".format(FLAGS.dataset))

    if FLAGS.validation:
        train_str = "traintrain"
        test_str = "trainval"
        log.warning("Running validation set")
    else:
        train_str = "train"
        test_str = "test"

    if FLAGS.id is None:
        dataset_name = FLAGS.dataset
        exp_id = "exp_" + dataset_name + "_" + FLAGS.model
        exp_id = gen_id(exp_id)
    else:
        exp_id = FLAGS.id
        # Assumes FLAGS.id follows the "exp_<dataset>_..." form built above.
        dataset_name = exp_id.split("_")[1]

    if FLAGS.results is not None:
        save_folder = os.path.realpath(
            os.path.abspath(os.path.join(FLAGS.results, exp_id)))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
    else:
        save_folder = None

    if FLAGS.logs is not None:
        logs_folder = os.path.realpath(
            os.path.abspath(os.path.join(FLAGS.logs, exp_id)))
        if not os.path.exists(logs_folder):
            os.makedirs(logs_folder)
    else:
        logs_folder = None

    # Configures dataset objects.
    log.info("Building dataset")
    train_data = get_dataset(dataset_name, train_str)
    trainval_data = get_dataset(dataset_name,
                                train_str,
                                num_batches=100,
                                data_aug=False,
                                cycle=False,
                                prefetch=False)
    test_data = get_dataset(dataset_name,
                            test_str,
                            data_aug=False,
                            cycle=False,
                            prefetch=False)

    # Trains a model.
    acc = train_model(exp_id,
                      config,
                      train_data,
                      test_data,
                      trainval_data,
                      save_folder=save_folder,
                      logs_folder=logs_folder)
    log.info("Final test accuracy = {:.3f}".format(acc * 100))
Example #6
0
def main():
    """Train a model on CIFAR-10 or CIFAR-100 and log the final test accuracy.

    Sets ``config.num_classes`` from ``FLAGS.dataset``.  Split names are
    "traintrain"/"trainval" when ``FLAGS.validation`` is set, otherwise
    "train"/"test".  When ``FLAGS.id`` is None a new id is generated from
    ``"exp_<dataset>_<model>"``; otherwise the dataset name is read from the
    second "_"-separated field of ``FLAGS.id``.  Results and logs folders
    (``<root>/<exp_id>``) are created when their root flag is set, else
    ``None`` is passed on.

    Raises:
        ValueError: if ``FLAGS.dataset`` is neither "cifar-10" nor
            "cifar-100".
    """
    config = _get_config()  # Load configuration parameters.

    # Set the number of classes for the selected dataset; reject anything else.
    if FLAGS.dataset == "cifar-10":
        config.num_classes = 10
    elif FLAGS.dataset == "cifar-100":
        config.num_classes = 100
    else:
        # Use str.format so the dataset name actually fills the placeholder.
        raise ValueError("Unknown dataset name {}".format(FLAGS.dataset))

    # Select split names depending on whether the validation split is used.
    if FLAGS.validation:
        train_str = "traintrain"
        test_str = "trainval"
        log.warning("Running validation set")
    else:
        train_str = "train"
        test_str = "test"

    # Experiment id, used to name the results and logs folders.
    if FLAGS.id is None:
        dataset_name = FLAGS.dataset
        exp_id = "exp_" + dataset_name + "_" + FLAGS.model
        exp_id = gen_id(exp_id)
    else:
        exp_id = FLAGS.id
        # Assumes FLAGS.id follows the "exp_<dataset>_..." form built above.
        dataset_name = exp_id.split("_")[1]

    # Create the folder for saving training results.
    if FLAGS.results is not None:
        save_folder = os.path.realpath(
            os.path.abspath(os.path.join(FLAGS.results, exp_id)))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
    else:
        save_folder = None

    # Create the folder for saving logs.
    if FLAGS.logs is not None:
        logs_folder = os.path.realpath(
            os.path.abspath(os.path.join(FLAGS.logs, exp_id)))
        if not os.path.exists(logs_folder):
            os.makedirs(logs_folder)
    else:
        logs_folder = None

    # Build the training, train-evaluation and test datasets.
    log.info("Building dataset")
    train_data = get_dataset(dataset_name, train_str)
    trainval_data = get_dataset(dataset_name,
                                train_str,
                                num_batches=100,
                                data_aug=False,
                                cycle=False,
                                prefetch=False)
    test_data = get_dataset(dataset_name,
                            test_str,
                            data_aug=False,
                            cycle=False,
                            prefetch=False)

    # Train the model.
    acc = train_model(exp_id,
                      config,
                      train_data,
                      test_data,
                      trainval_data,
                      save_folder=save_folder,
                      logs_folder=logs_folder)
    log.info("final test accuracy = {:.3f}".format(acc * 100))