Beispiel #1
0
def main():
    """Train the configured model and log its final test accuracy.

    The split names depend on ``FLAGS.validation``. When it is set, training
    uses "traintrain" and evaluation uses "trainval". Otherwise the splits are
    "train" and "test". When ``FLAGS.results`` / ``FLAGS.logs`` are given, a
    subfolder named after the experiment id is created under each.
    """

    def _experiment_dir(root, experiment):
        # A None root means "no folder requested". Otherwise resolve
        # <root>/<experiment> to a canonical path and create it if it is missing.
        if root is None:
            return None
        path = os.path.realpath(os.path.abspath(os.path.join(root, experiment)))
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    # Load the experiment configuration.
    config = _get_config()

    # Choose which dataset splits to train and evaluate on.
    if not FLAGS.validation:
        train_split, eval_split = "train", "test"
    else:
        train_split, eval_split = "traintrain", "trainval"
        log.warning("Running validation set")

    # Reuse an explicit experiment id, or derive a fresh one.
    if FLAGS.id is not None:
        exp_id = FLAGS.id
    else:
        exp_id = gen_id("exp_" + FLAGS.dataset + "_" + FLAGS.model)

    save_folder = _experiment_dir(FLAGS.results, exp_id)
    logs_folder = _experiment_dir(FLAGS.logs, exp_id)

    log.info("Building dataset")
    train_data = get_dataset(FLAGS.dataset, train_split)
    # The evaluation pipelines disable augmentation, cycling and prefetching.
    eval_options = dict(data_aug=False, cycle=False, prefetch=False)
    # A 100-batch view of the training split, presumably for tracking training
    # accuracy -- confirm against train_model.
    trainval_data = get_dataset(
        FLAGS.dataset, train_split, num_batches=100, **eval_options)
    test_data = get_dataset(FLAGS.dataset, eval_split, **eval_options)

    acc = train_model(exp_id, config, train_data, test_data, trainval_data,
                      save_folder=save_folder, logs_folder=logs_folder)
    log.info("Final test accuracy = {:.3f}".format(acc * 100))
def main():
    """Build a fresh TF graph for the configured model and train it.

    The experiment id is ``FLAGS.id`` when given. Otherwise it is generated
    from ``DATASET`` and ``FLAGS.model``. When ``FLAGS.results`` /
    ``FLAGS.logs`` are given, a subfolder named after the id is created under
    each. NumPy and TF random seeds are fixed inside the new graph.
    """

    def _experiment_dir(root, experiment):
        # None root -> no folder; otherwise canonicalize and create on demand.
        if root is None:
            return None
        folder = os.path.realpath(os.path.abspath(os.path.join(root, experiment)))
        if not os.path.exists(folder):
            os.makedirs(folder)
        return folder

    config = _get_config()

    exp_id = FLAGS.id
    if exp_id is None:
        exp_id = gen_id("exp_" + DATASET + "_" + FLAGS.model)

    save_folder = _experiment_dir(FLAGS.results, exp_id)
    logs_folder = _experiment_dir(FLAGS.logs, exp_id)

    graph = tf.Graph()
    with graph.as_default():
        # Fixed seeds for NumPy and TF graph-level randomness.
        np.random.seed(0)
        tf.set_random_seed(1234)

        log.info("Building dataset")
        inputs, targets = _get_dataset(config)

        log.info("Building models")
        model = _get_model(
            config, inputs, targets,
            num_replica=FLAGS.num_gpu,
            num_pass=FLAGS.num_pass,
            is_training=True)

        train_model(exp_id, config, model,
                    save_folder=save_folder, logs_folder=logs_folder)
Beispiel #3
0
def main():
    """Train the configured model on the "train" split of ``DATASET``.

    The input pipeline uses the configuration's batch size and preprocessor.
    When ``FLAGS.results`` / ``FLAGS.logs`` are given, a subfolder named after
    the experiment id is created under each.
    """

    def _experiment_dir(root, experiment):
        # None root -> no folder requested; otherwise resolve and create it.
        if root is None:
            return None
        target = os.path.realpath(os.path.abspath(os.path.join(root, experiment)))
        if not os.path.exists(target):
            os.makedirs(target)
        return target

    config = _get_config()

    exp_id = FLAGS.id
    if exp_id is None:
        exp_id = gen_id("exp_" + DATASET + "_" + FLAGS.model)

    save_folder = _experiment_dir(FLAGS.results, exp_id)
    logs_folder = _experiment_dir(FLAGS.logs, exp_id)

    log.info("Building dataset")
    train_data = get_dataset(
        DATASET, "train",
        batch_size=config.batch_size,
        preprocessor=config.preprocessor)

    train_model(exp_id, config, train_data,
                save_folder=save_folder, logs_folder=logs_folder)
# Command-line flags defined by this launcher script. Other flags it reads
# (local, id, results) are defined elsewhere.
flags.DEFINE_integer("num_pass", 1, "Number of forward-backwad passes")
flags.DEFINE_integer("min_interval", 7200, "Minimum number of seconds")
flags.DEFINE_string("model", "resnet-50", "Model name")
flags.DEFINE_string("machine", None, "Preferred machine")
FLAGS = flags.FLAGS
# Dataset name embedded in generated experiment ids.
DATASET = "imagenet"

# Get dispatcher factory.
# --local selects a local dispatcher. Otherwise a Slurm dispatcher is built
# from slurm_config.json (presumably submitting jobs to a Slurm cluster).
if FLAGS.local:
  dispatch_factory = LocalCommandDispatcherFactory()
else:
  dispatch_factory = SlurmCommandDispatcherFactory("slurm_config.json")

# Generate experiment ID.
# Without --id a new id is generated and restore is False. An explicit --id
# sets restore to True.
if FLAGS.id is None:
  exp_id = gen_id("exp_" + DATASET + "_" + FLAGS.model)
  restore = False
  # raise Exception("You need to specify model ID.")
else:
  exp_id = FLAGS.id
  restore = True

# Canonical folder for this experiment's checkpoints.
# NOTE(review): assumes FLAGS.results is set; os.path.join fails on None.
save_folder = os.path.realpath(
    os.path.abspath(os.path.join(FLAGS.results, exp_id)))

# NOTE(review): the rest of this loop body is not visible here; as shown it has
# no exit condition.
while True:

  # Check if we need to launch another job.
  if os.path.exists(save_folder):
    # The step count is parsed from the text after the last "-" in the
    # checkpoint path.
    # NOTE(review): latest_checkpoint may return None when the folder exists
    # but holds no checkpoint yet, which would make .split raise -- confirm
    # and guard.
    latest_ckpt = tf.train.latest_checkpoint(save_folder)
    cur_steps = int(latest_ckpt.split("-")[-1])
def main():
    """Train the configured model on CIFAR-100 and log its final test accuracy.

    Flags read: ``dataset``, ``validation``, ``id``, ``model``, ``results``
    and ``logs``.

    Split names depend on ``FLAGS.validation``: "traintrain"/"trainval" when
    it is set, "train"/"test" otherwise. When ``FLAGS.id`` is given, the
    dataset name is taken from the id's second underscore-separated field
    rather than from ``FLAGS.dataset``.

    Raises:
        ValueError: if ``FLAGS.dataset`` is not ``"cifar-100"``, or if
            ``FLAGS.id`` has no underscore-separated dataset field.
    """

    def _experiment_dir(root, experiment):
        """Return the canonical <root>/<experiment> path, creating it if needed.

        Returns None when ``root`` is None (no folder requested).
        """
        if root is None:
            return None
        # realpath already returns an absolute, symlink-resolved path.
        path = os.path.realpath(os.path.join(root, experiment))
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    # Loads parameters.
    config = _get_config()
    # Explicit check rather than `assert`, which `python -O` strips out.
    # config.num_classes is not overridden here; the config is assumed to
    # already target CIFAR-100 -- TODO confirm.
    if FLAGS.dataset != "cifar-100":
        raise ValueError(
            "This script only supports dataset 'cifar-100', got {!r}".format(
                FLAGS.dataset))

    if FLAGS.validation:
        train_str = "traintrain"
        test_str = "trainval"
        log.warning("Running validation set")
    else:
        train_str = "train"
        test_str = "test"

    if FLAGS.id is None:
        dataset_name = FLAGS.dataset
        exp_id = gen_id("exp_" + dataset_name + "_" + FLAGS.model)
    else:
        exp_id = FLAGS.id
        # Ids built above start with "exp_<dataset>_"; recover the dataset.
        id_fields = exp_id.split("_")
        if len(id_fields) < 2:
            raise ValueError(
                "Cannot infer dataset name from experiment id {!r}".format(
                    exp_id))
        dataset_name = id_fields[1]

    save_folder = _experiment_dir(FLAGS.results, exp_id)
    logs_folder = _experiment_dir(FLAGS.logs, exp_id)

    # Configures dataset objects.
    log.info("Building dataset")
    train_data = get_dataset(dataset_name, train_str)
    # Evaluation pipelines disable augmentation, cycling and prefetching.
    trainval_data = get_dataset(dataset_name,
                                train_str,
                                num_batches=100,
                                data_aug=False,
                                cycle=False,
                                prefetch=False)
    test_data = get_dataset(dataset_name,
                            test_str,
                            data_aug=False,
                            cycle=False,
                            prefetch=False)

    # Trains a model.
    acc = train_model(exp_id,
                      config,
                      train_data,
                      test_data,
                      trainval_data,
                      save_folder=save_folder,
                      logs_folder=logs_folder)
    log.info("Final test accuracy = {:.3f}".format(acc * 100))
Beispiel #6
0
def main():
    """Train the configured model on CIFAR-10/100 and log its final test accuracy.

    Sets ``config.num_classes`` from ``FLAGS.dataset``. Split names depend on
    ``FLAGS.validation``: "traintrain"/"trainval" when it is set,
    "train"/"test" otherwise. When ``FLAGS.id`` is given, the dataset name is
    taken from the id's second underscore-separated field.

    Raises:
        ValueError: if ``FLAGS.dataset`` is neither "cifar-10" nor
            "cifar-100", or if ``FLAGS.id`` has no underscore-separated
            dataset field.
    """

    def _experiment_dir(root, experiment):
        """Return the canonical <root>/<experiment> path, creating it if needed.

        Returns None when ``root`` is None (no folder requested).
        """
        if root is None:
            return None
        # realpath already returns an absolute, symlink-resolved path.
        path = os.path.realpath(os.path.join(root, experiment))
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    # Load configuration parameters.
    config = _get_config()

    # Validate the dataset flag and set the class count to match it.
    # NOTE(review): num_classes follows FLAGS.dataset, but when FLAGS.id is
    # given the data comes from the id's dataset field; the two can disagree.
    num_classes = {"cifar-10": 10, "cifar-100": 100}.get(FLAGS.dataset)
    if num_classes is None:
        # The original passed the template and value as two arguments, so the
        # message was never formatted.
        raise ValueError("Unknown dataset name {}".format(FLAGS.dataset))
    config.num_classes = num_classes

    # Select the train/evaluation split names.
    if FLAGS.validation:
        train_str = "traintrain"
        test_str = "trainval"
        log.warning("Running validation set")
    else:
        train_str = "train"
        test_str = "test"

    # Experiment id, used to name the result and log folders.
    if FLAGS.id is None:
        dataset_name = FLAGS.dataset
        exp_id = gen_id("exp_" + dataset_name + "_" + FLAGS.model)
    else:
        exp_id = FLAGS.id
        # Ids built above start with "exp_<dataset>_"; recover the dataset.
        id_fields = exp_id.split("_")
        if len(id_fields) < 2:
            raise ValueError(
                "Cannot infer dataset name from experiment id {!r}".format(
                    exp_id))
        dataset_name = id_fields[1]

    # Folders for training results and logs.
    save_folder = _experiment_dir(FLAGS.results, exp_id)
    logs_folder = _experiment_dir(FLAGS.logs, exp_id)

    # Build the training, train-evaluation and test datasets.
    log.info("Building dataset")
    train_data = get_dataset(dataset_name, train_str)
    trainval_data = get_dataset(dataset_name,
                                train_str,
                                num_batches=100,
                                data_aug=False,
                                cycle=False,
                                prefetch=False)
    test_data = get_dataset(dataset_name,
                            test_str,
                            data_aug=False,
                            cycle=False,
                            prefetch=False)

    # Train the model.
    acc = train_model(exp_id,
                      config,
                      train_data,
                      test_data,
                      trainval_data,
                      save_folder=save_folder,
                      logs_folder=logs_folder)
    log.info("final test accuracy = {:.3f}".format(acc * 100))