Example #1
def eval_model(config,
               trn_model,
               val_model,
               save_folder,
               logs_folder=None,
               ckpt_num=-1):
    log.info("Config: {}".format(config.__dict__))
    exp_logger = ExperimentLogger(logs_folder)
    # Initializes variables.
    with tf.Session() as sess:
        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()
        if ckpt_num == -1:
            ckpt = tf.train.latest_checkpoint(save_folder)
        elif ckpt_num >= 0:
            ckpt = os.path.join(save_folder, "model.ckpt-{}".format(ckpt_num))
        else:
            raise ValueError("Invalid checkpoint number {}".format(ckpt_num))
        log.info("Restoring from {}".format(ckpt))
        if not os.path.exists(ckpt + ".meta"):
            raise ValueError("Checkpoint not exists")
        saver.restore(sess, ckpt)
        train_acc = evaluate(sess, trn_model, num_batch=100)
        val_acc = evaluate(sess, val_model, num_batch=NUM_BATCH)
        niter = int(ckpt.split("-")[-1])
        exp_logger.log_train_acc(niter, train_acc)
        exp_logger.log_valid_acc(niter, val_acc)

        # Stop queues.
        coord.request_stop()
        coord.join(threads)
    return val_acc
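The `evaluate` helper called above is not defined in this snippet. A minimal sketch of what it might look like, assuming the model is queue-fed and exposes a scalar accuracy tensor as `model.acc` (the attribute name and the simple averaging are assumptions, not taken from the original code):

def evaluate(sess, model, num_batch=100):
    """Averages the model's accuracy over `num_batch` mini-batches."""
    acc_sum = 0.0
    for _ in range(num_batch):
        # No feed_dict here: inputs are assumed to arrive from the input queue.
        acc_sum += sess.run(model.acc)
    return acc_sum / num_batch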
Example #2
def eval_model(config, train_data, test_data, save_folder, logs_folder=None):
    log.info("Config: {}".format(config.__dict__))

    with tf.Graph().as_default():
        np.random.seed(0)
        tf.set_random_seed(1234)
        exp_logger = ExperimentLogger(logs_folder)

        # Builds models.
        log.info("Building models")
        mvalid = get_model(config)

        # # A hack to load compatible models.
        # variables = tf.global_variables()
        # names = map(lambda x: x.name, variables)
        # names = map(lambda x: x.replace("Model/", "Model/Towers/"), names)
        # names = map(lambda x: x.replace(":0", ""), names)
        # var_dict = dict(zip(names, variables))

        # Initializes variables.
        with tf.Session() as sess:
            # saver = tf.train.Saver(var_dict)
            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(save_folder)
            # log.fatal(ckpt)
            saver.restore(sess, ckpt)
            train_acc = evaluate(sess, mvalid, train_data)
            val_acc = evaluate(sess, mvalid, test_data)
            niter = int(ckpt.split("-")[-1])
            exp_logger.log_train_acc(niter, train_acc)
            exp_logger.log_valid_acc(niter, val_acc)
        return val_acc
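`ExperimentLogger` appears in every example but is never defined here. As a rough sketch consistent with how it is called (`log_train_ce`, `log_train_acc`, `log_valid_acc`, `log_learn_rate`), it could append (step, value) rows to one CSV file per metric; the file layout and the None-folder behavior are assumptions:

import csv
import os


class ExperimentLogger(object):
    """Appends (step, value) rows to one CSV file per metric (assumed layout)."""

    def __init__(self, logs_folder):
        self.folder = logs_folder

    def _log(self, name, step, value):
        if self.folder is None:
            return  # Logging disabled when no folder is given.
        with open(os.path.join(self.folder, name + ".csv"), "a") as f:
            csv.writer(f).writerow([step, value])

    def log_train_ce(self, step, ce):
        self._log("train_ce", step, ce)

    def log_train_acc(self, step, acc):
        self._log("train_acc", step, acc)

    def log_valid_acc(self, step, acc):
        self._log("valid_acc", step, acc)

    def log_learn_rate(self, step, lr):
        self._log("learn_rate", step, lr)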
Example #3
def train_model(exp_id,
                config,
                train_iter,
                test_iter,
                trainval_iter=None,
                save_folder=None,
                logs_folder=None):
    """Trains a CIFAR model.

  Args:
      exp_id: String. Experiment ID.
      config: Config object
      train_iter: Training set iterator.
      test_iter: Test set iterator.
      trainval_iter: Optional training subset iterator used for evaluation.
      save_folder: Folder to save checkpoints.
      logs_folder: Folder to save training logs.

  Returns:
      acc: Final test accuracy
  """

    log.info("Config: {}".format(config.__dict__))
    exp_logger = ExperimentLogger(logs_folder)

    # Initializes variables.
    with tf.Graph().as_default():
        np.random.seed(0)
        if not hasattr(config, "seed"):
            tf.set_random_seed(1234)
            log.info("Setting tensorflow random seed={:d}".format(1234))
        else:
            log.info("Setting tensorflow random seed={:d}".format(config.seed))
            tf.set_random_seed(config.seed)
        m, mvalid = _get_models(config)

        with tf.Session() as sess:
            saver = tf.train.Saver()
            if FLAGS.restore:
                log.info("Restore checkpoint \"{}\"".format(save_folder))
                saver.restore(sess, tf.train.latest_checkpoint(save_folder))
            else:
                sess.run(tf.global_variables_initializer())
            niter_start = int(m.global_step.eval())
            w_list = tf.trainable_variables()
            log.info("Model initialized.")
            num_params = np.array([
                np.prod(np.array([int(ss) for ss in w.get_shape()]))
                for w in w_list
            ]).sum()
            log.info('\033[92m' +
                     "Number of parameters {}".format(num_params) + '\033[00m')

            # Set up learning rate schedule.
            if config.lr_scheduler_type == "fixed":
                lr_scheduler = FixedLearnRateScheduler(sess,
                                                       m,
                                                       config.base_learn_rate,
                                                       config.lr_decay_steps,
                                                       lr_list=config.lr_list)
            else:
                raise Exception("Unknown learning rate scheduler {}".format(
                    config.lr_scheduler))

            for niter in tqdm(range(niter_start, config.max_train_iter),
                              desc=exp_id):
                lr_scheduler.step(niter)
                ce = train_step(sess, m, train_iter.next())

                if (niter + 1) % config.disp_iter == 0 or niter == 0:
                    exp_logger.log_train_ce(niter, ce)

                if (niter + 1) % config.valid_iter == 0 or niter == 0:
                    if trainval_iter is not None:
                        trainval_iter.reset()
                        acc = evaluate(sess, mvalid, trainval_iter)
                        exp_logger.log_train_acc(niter, acc)
                    test_iter.reset()
                    acc = evaluate(sess, mvalid, test_iter)
                    exp_logger.log_valid_acc(niter, acc)

                if (niter + 1) % config.save_iter == 0 or niter == 0:
                    save(sess, saver, m.global_step, config, save_folder)
                    exp_logger.log_learn_rate(niter, m.lr.eval())

            test_iter.reset()
            acc = evaluate(sess, mvalid, test_iter)
    return acc
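`train_step` is also assumed by these examples. A sketch consistent with its call sites in Examples 3, 4, 6 and 7 (one optimization step on a batch, returning the cross-entropy) might look like the following; the batch keys and the model attributes `cross_ent`, `train_op`, `input`, and `label` are guesses, not taken from the original code:

def train_step(sess, model, batch):
    """Runs one optimization step and returns the batch cross-entropy."""
    # The "img"/"label" keys and the model attribute names are assumptions.
    ce, _ = sess.run([model.cross_ent, model.train_op],
                     feed_dict={model.input: batch["img"],
                                model.label: batch["label"]})
    return ce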
Example #4
def train_model(exp_id,
                config,
                train_iter,
                save_folder=None,
                logs_folder=None):
    """Trains a CIFAR model.

  Args:
    exp_id: String. Experiment ID.
    config: Config object
    train_iter: Training set iterator.
    save_folder: Folder to save checkpoints.
    logs_folder: Folder to save training logs.

  Returns:
    acc: Final test accuracy
  """
    log.info("Config: {}".format(config.__dict__))
    exp_logger = ExperimentLogger(logs_folder)

    # Initializes variables.
    with tf.Graph().as_default():
        np.random.seed(0)
        tf.set_random_seed(1234)

        # Builds models.
        log.info("Building models")
        with tf.name_scope("Train"):
            with tf.variable_scope("Model", reuse=None):
                m = get_model(config,
                              num_replica=FLAGS.num_gpu,
                              num_pass=FLAGS.num_pass,
                              is_training=True)

        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True)) as sess:
            saver = tf.train.Saver(
                max_to_keep=None)  ### Keep all checkpoints here!
            if FLAGS.restore:
                log.info("Restore checkpoint \"{}\"".format(save_folder))
                saver.restore(sess, tf.train.latest_checkpoint(save_folder))
            else:
                sess.run(tf.global_variables_initializer())

            max_train_iter = config.max_train_iter
            niter_start = int(m.global_step.eval())

            # Add upper bound to the number of steps.
            if FLAGS.max_num_steps > 0:
                max_train_iter = min(max_train_iter,
                                     niter_start + FLAGS.max_num_steps)

            # Set up learning rate schedule.
            if config.lr_scheduler == "fixed":
                lr_scheduler = FixedLearnRateScheduler(sess,
                                                       m,
                                                       config.base_learn_rate,
                                                       config.lr_decay_steps,
                                                       lr_list=config.lr_list)
            elif config.lr_scheduler == "exponential":
                lr_scheduler = ExponentialLearnRateScheduler(
                    sess, m, config.base_learn_rate, config.lr_decay_offset,
                    config.max_train_iter, config.final_learn_rate,
                    config.lr_decay_interval)
            else:
                raise Exception("Unknown learning rate scheduler {}".format(
                    config.lr_scheduler))

            for niter in tqdm(range(niter_start, max_train_iter),
                              desc=exp_id):
                lr_scheduler.step(niter)
                ce = train_step(sess, m, train_iter.next())

                if (niter + 1) % config.disp_iter == 0 or niter == 0:
                    exp_logger.log_train_ce(niter, ce)

                if (niter + 1) % config.save_iter == 0 or niter == 0:
                    if save_folder is not None:
                        save(sess, saver, m.global_step, config, save_folder)
                    exp_logger.log_learn_rate(niter, m.lr.eval())
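`FixedLearnRateScheduler` is referenced but not defined in any of these snippets. A plausible sketch, inferred only from its constructor arguments and the `step(niter)` calls above, switches to `lr_list[i]` once training reaches `lr_decay_steps[i]`; treating `model.lr` as a `tf.Variable` is an assumption:

import tensorflow as tf


class FixedLearnRateScheduler(object):
    """Piecewise-constant schedule (a sketch; the real class may differ)."""

    def __init__(self, sess, model, base_lr, lr_decay_steps, lr_list=None):
        self.sess = sess
        self.model = model
        self.decay_steps = list(lr_decay_steps)
        self.lr_list = list(lr_list) if lr_list is not None else []
        # Start from the base learning rate.
        sess.run(tf.assign(model.lr, base_lr))

    def step(self, niter):
        # Assign lr_list[i] exactly when the step counter hits boundary i.
        for i, boundary in enumerate(self.decay_steps):
            if niter == boundary and i < len(self.lr_list):
                self.sess.run(tf.assign(self.model.lr, self.lr_list[i]))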
Example #5
def train_model(exp_id, config, model, save_folder=None, logs_folder=None):
    """Trains an ImageNet model.

  Args:
    exp_id: String. Experiment ID.
    config: Config object.
    model: Model object.
    save_folder: Folder to save all checkpoints.
    logs_folder: Folder to save all training logs.

  Returns:
    acc: Final test accuracy
  """
    log.info("Config: {}".format(config.__dict__))
    exp_logger = ExperimentLogger(logs_folder)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:

        found_old_name = False
        if FLAGS.restore:
            ckpt = tf.train.latest_checkpoint(save_folder)
            from tensorflow.python import pywrap_tensorflow
            reader = pywrap_tensorflow.NewCheckpointReader(ckpt)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in sorted(var_to_shape_map):
                if key == "Train/Model/learn_rate":
                    found_old_name = True
                    break

        # A hack to load compatible models.
        if found_old_name:
            variables = tf.global_variables()
            names = map(lambda x: x.name, variables)
            names = map(lambda x: x.replace(":0", ""), names)
            names = map(
                lambda x: x.replace("Model/learn_rate",
                                    "Train/Model/learn_rate"), names)
            var_dict = dict(zip(names, variables))
        else:
            var_dict = None

        ### Keep all checkpoints here!
        #saver = tf.train.Saver(max_to_keep=None)
        saver = tf.train.Saver(var_dict, max_to_keep=None)
        if FLAGS.restore:
            log.info("Restore checkpoint \"{}\"".format(save_folder))
            saver.restore(sess, tf.train.latest_checkpoint(save_folder))
        else:
            sess.run(tf.global_variables_initializer())

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        # Count parameters.
        w_list = tf.trainable_variables()
        num_params = np.array([
            np.prod(np.array([int(ss) for ss in w.get_shape()]))
            for w in w_list
        ]).sum()
        log.info("Number of parameters {}".format(num_params))

        max_train_iter = config.max_train_iter
        niter_start = int(model.global_step.eval())

        # Add upper bound to the number of steps.
        if FLAGS.max_num_steps > 0:
            max_train_iter = min(max_train_iter,
                                 niter_start + FLAGS.max_num_steps)

        # Set up learning rate schedule.
        if config.lr_scheduler == "fixed":
            lr_scheduler = FixedLearnRateScheduler(sess,
                                                   model,
                                                   config.base_learn_rate,
                                                   config.lr_decay_steps,
                                                   lr_list=config.lr_list)
        else:
            raise Exception("Unknown learning rate scheduler {}".format(
                config.lr_scheduler))

        for niter in tqdm(range(niter_start, max_train_iter),
                          desc=exp_id):
            lr_scheduler.step(niter)
            ce = train_step(sess, model)

            if (niter + 1) % config.disp_iter == 0 or niter == 0:
                exp_logger.log_train_ce(niter, ce)

            if (niter + 1) % config.save_iter == 0 or niter == 0:
                if save_folder is not None:
                    save(sess, saver, model.global_step, config, save_folder)
                exp_logger.log_learn_rate(niter, model.lr.eval())
Example #6
def train_model(exp_id,
                config,
                train_iter,
                test_iter,
                trainval_iter=None,
                save_folder=None,
                logs_folder=None):
  """Trains a CIFAR model.

  Args:
      exp_id: String. Experiment ID.
      config: Config object
      train_iter: Training set iterator.
      test_iter: Test set iterator.
      trainval_iter: Optional training subset iterator used for evaluation.
      save_folder: Folder to save checkpoints.
      logs_folder: Folder to save training logs.

  Returns:
      acc: Final test accuracy
  """
  log.info("Config: {}".format(config.__dict__))
  exp_logger = ExperimentLogger(logs_folder)

  # Initializes variables.
  with tf.Graph().as_default():
    np.random.seed(0)
    if not hasattr(config, "seed"):
      tf.set_random_seed(1234)
      log.info("Setting tensorflow random seed={:d}".format(1234))
    else:
      log.info("Setting tensorflow random seed={:d}".format(config.seed))
      tf.set_random_seed(config.seed)
    m, mvalid = get_models(config)

    with tf.Session() as sess:
      saver = tf.train.Saver()
      sess.run(tf.global_variables_initializer())

      # Set up learning rate schedule.
      if config.lr_scheduler_type == "fixed":
        lr_scheduler = FixedLearnRateScheduler(
            sess,
            m,
            config.base_learn_rate,
            config.lr_decay_steps,
            lr_list=config.lr_list)
      elif config.lr_scheduler_type == "exponential":
        lr_scheduler = ExponentialLearnRateScheduler(
            sess, m, config.base_learn_rate, config.lr_decay_offset,
            config.max_train_iter, config.final_learn_rate,
            config.lr_decay_interval)
      else:
        raise Exception("Unknown learning rate scheduler {}".format(
            config.lr_scheduler))

      for niter in tqdm(range(config.max_train_iter), desc=exp_id):
        lr_scheduler.step(niter)
        ce = train_step(sess, m, train_iter.next())

        if (niter + 1) % config.disp_iter == 0 or niter == 0:
          exp_logger.log_train_ce(niter, ce)

        if (niter + 1) % config.valid_iter == 0 or niter == 0:
          if trainval_iter is not None:
            trainval_iter.reset()
            acc = evaluate(sess, mvalid, trainval_iter)
            exp_logger.log_train_acc(niter, acc)
          test_iter.reset()
          acc = evaluate(sess, mvalid, test_iter)
          exp_logger.log_valid_acc(niter, acc)

        if (niter + 1) % config.save_iter == 0 or niter == 0:
          save(sess, saver, m.global_step, config, save_folder)
          exp_logger.log_learn_rate(niter, m.lr.eval())

      test_iter.reset()
      acc = evaluate(sess, mvalid, test_iter)
  return acc
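The `save` helper is likewise undefined here. A minimal sketch matching its call signature `save(sess, saver, global_step, config, save_folder)` could write a checkpoint named `model.ckpt-<step>` (as Example 1 expects when restoring) and, as an extra assumption, dump the config next to it:

import json
import os


def save(sess, saver, global_step, config, save_folder):
    """Writes a checkpoint and (as an assumption) the config as JSON."""
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
    # `global_step` may be an int or the model's global_step variable.
    saver.save(sess, os.path.join(save_folder, "model.ckpt"),
               global_step=global_step)
    with open(os.path.join(save_folder, "config.json"), "w") as f:
        json.dump(config.__dict__, f, indent=2, default=str)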
Example #7
def train_model(exp_id,
                config,
                train_iter,
                test_iter,
                trainval_iter=None,
                save_folder=None,
                logs_folder=None):
    log.info("congif:{}".format(config.__dict__))  #加入日志
    exp_logger = ExperimentLogger(logs_folder)
    # print(config.__dict__)
    #初始化所有变量
    with tf.Graph().as_default():
        np.random.seed(0)
        # Set the TensorFlow random seed.
        if not hasattr(config, "seed"):  # Check whether config provides a seed.
            tf.set_random_seed(1234)
            log.info("setting tesorflow random seed={:d}".format(1234))
        else:
            log.info("setting tensorflow random seed={:d}".format(config.seed))
            tf.set_random_seed(config.seed)

        # Builds the training and validation models.
        m, mvalid = _get_models(config)

        with tf.Session() as sess:
            saver = tf.train.Saver()

            # Optionally restore a previously trained model.
            if FLAGS.restore:
                log.info("Restore checkpoint \"{}\"".format(save_folder))
                saver.restore(
                    sess, tf.train.latest_checkpoint(save_folder))  # Load the saved checkpoint.
            else:
                sess.run(
                    tf.global_variables_initializer())  # First run: initialize all variables.

            niter_start = int(m.global_step.eval())
            w_list = tf.trainable_variables()
            log.info("Model initialized")

            # Compute the total number of parameters.
            num_params = np.array([
                np.prod(np.array([int(ss) for ss in w.get_shape()]))
                for w in w_list
            ]).sum()
            log.info("Number of parameters {}".format(num_params))

            if config.lr_scheduler_type == "fixed":
                lr_scheduler = FixedLearnRateScheduler(sess,
                                                       m,
                                                       config.base_learn_rate,
                                                       config.lr_decay_steps,
                                                       lr_list=config.lr_list)
            else:
                raise Exception(
                    "Unknown learning rate scheduler {}".format(
                        config.lr_scheduler_type))

            for niter in tqdm(range(niter_start, config.max_train_iter),
                              desc=exp_id):
                lr_scheduler.step(niter)
                ce = train_step(sess, m, train_iter.next())

                if (niter + 1) % config.disp_iter == 0 or niter == 0:
                    exp_logger.log_train_ce(niter, ce)
                if (niter + 1) % config.valid_iter == 0 or niter == 0:
                    if trainval_iter is not None:
                        trainval_iter.reset()
                        acc = evaluate(sess, mvalid, trainval_iter)
                        exp_logger.log_train_acc(niter, acc)
                    test_iter.reset()
                    acc = evaluate(sess, mvalid, test_iter)
                    exp_logger.log_valid_acc(niter, acc)
                if (niter + 1) % config.save_iter == 0 or niter == 0:
                    save(sess, saver, m.global_step, config, save_folder)
                    exp_logger.log_learn_rate(niter, m.lr.eval())

            test_iter.reset()
            acc = evaluate(sess, mvalid, test_iter)
        return acc