Example #1
def eval_model(config, train_data, test_data, save_folder, logs_folder=None):
    log.info("Config: {}".format(config.__dict__))

    with tf.Graph().as_default():
        np.random.seed(0)
        tf.set_random_seed(1234)
        exp_logger = ExperimentLogger(logs_folder)

        # Builds models.
        log.info("Building models")
        mvalid = get_model(config)

        # # A hack to load compatible models.
        # variables = tf.global_variables()
        # names = map(lambda x: x.name, variables)
        # names = map(lambda x: x.replace("Model/", "Model/Towers/"), names)
        # names = map(lambda x: x.replace(":0", ""), names)
        # var_dict = dict(zip(names, variables))

        # Initializes variables.
        with tf.Session() as sess:
            # saver = tf.train.Saver(var_dict)
            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(save_folder)
            # log.fatal(ckpt)
            saver.restore(sess, ckpt)
            train_acc = evaluate(sess, mvalid, train_data)
            val_acc = evaluate(sess, mvalid, test_data)
            niter = int(ckpt.split("-")[-1])
            exp_logger.log_train_acc(niter, train_acc)
            exp_logger.log_valid_acc(niter, val_acc)
        return val_acc
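
Note: the evaluate() helper called above is not shown on this page. A minimal sketch of a compatible implementation, assuming the model exposes inputs/labels placeholders and a num_correct tensor (hypothetical names, not from the original source):

def evaluate(sess, model, data_iter):
    """Runs one pass over data_iter and returns the mean accuracy (sketch)."""
    num_correct = 0.0
    num_total = 0.0
    for batch in data_iter:
        # Hypothetical attributes; the real model may name these differently.
        feed_dict = {model.inputs: batch["img"], model.labels: batch["label"]}
        num_correct += sess.run(model.num_correct, feed_dict=feed_dict)
        num_total += len(batch["label"])
    return num_correct / num_total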
Example #2
def eval_model(config,
               trn_model,
               val_model,
               save_folder,
               logs_folder=None,
               ckpt_num=-1):
    log.info("Config: {}".format(config.__dict__))
    exp_logger = ExperimentLogger(logs_folder)
    # Initializes variables.
    with tf.Session() as sess:
        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver()
        if ckpt_num == -1:
            ckpt = tf.train.latest_checkpoint(save_folder)
        elif ckpt_num >= 0:
            ckpt = os.path.join(save_folder, "model.ckpt-{}".format(ckpt_num))
        else:
            raise ValueError("Invalid checkpoint number {}".format(ckpt_num))
        log.info("Restoring from {}".format(ckpt))
        if not os.path.exists(ckpt + ".meta"):
            raise ValueError("Checkpoint not exists")
        saver.restore(sess, ckpt)
        train_acc = evaluate(sess, trn_model, num_batch=100)
        val_acc = evaluate(sess, val_model, num_batch=NUM_BATCH)
        niter = int(ckpt.split("-")[-1])
        exp_logger.log_train_acc(niter, train_acc)
        exp_logger.log_valid_acc(niter, val_acc)

        # Stop queues.
        coord.request_stop()
        coord.join(threads)
    return val_acc
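
Unlike Example #1, this variant feeds data through TF queue runners, so its evaluate() takes a batch count rather than a dataset iterator. A minimal sketch under that assumption (model.acc is a hypothetical per-batch accuracy tensor driven by the input queue):

import numpy as np

def evaluate(sess, model, num_batch=100):
    """Averages a queue-fed per-batch accuracy tensor over num_batch batches."""
    accs = []
    for _ in range(num_batch):
        # No feed_dict: the queue runners supply each batch.
        accs.append(sess.run(model.acc))
    return float(np.mean(accs))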
Example #3
def train_model(exp_id,
                config,
                train_iter,
                test_iter,
                trainval_iter=None,
                save_folder=None,
                logs_folder=None):
    """Trains a CIFAR model.

  Args:
      exp_id: String. Experiment ID.
      config: Config object
      train_data: Dataset iterator.
      test_data: Dataset iterator.

  Returns:
      acc: Final test accuracy
  """
    # log.info("Config: {}".format(config.__dict__))

    log.info("Config: {}".format(config.__dict__))
    exp_logger = ExperimentLogger(logs_folder)

    # Initializes variables.
    with tf.Graph().as_default():
        np.random.seed(0)
        if not hasattr(config, "seed"):
            tf.set_random_seed(1234)
            log.info("Setting tensorflow random seed={:d}".format(1234))
        else:
            log.info("Setting tensorflow random seed={:d}".format(config.seed))
            tf.set_random_seed(config.seed)
        m, mvalid = _get_models(config)

        with tf.Session() as sess:
            saver = tf.train.Saver()
            if FLAGS.restore:
                log.info("Restore checkpoint \"{}\"".format(save_folder))
                saver.restore(sess, tf.train.latest_checkpoint(save_folder))
            else:
                sess.run(tf.global_variables_initializer())
            niter_start = int(m.global_step.eval())
            w_list = tf.trainable_variables()
            log.info("Model initialized.")
            num_params = np.array([
                np.prod(np.array([int(ss) for ss in w.get_shape()]))
                for w in w_list
            ]).sum()
            log.info('\033[92m' +
                     "Number of parameters {}".format(num_params) + '\033[00m')

            # Set up learning rate schedule.
            if config.lr_scheduler_type == "fixed":
                lr_scheduler = FixedLearnRateScheduler(sess,
                                                       m,
                                                       config.base_learn_rate,
                                                       config.lr_decay_steps,
                                                       lr_list=config.lr_list)
            else:
                raise ValueError("Unknown learning rate scheduler {}".format(
                    config.lr_scheduler_type))

            for niter in tqdm(range(niter_start, config.max_train_iter),
                              desc=exp_id):
                lr_scheduler.step(niter)
                ce = train_step(sess, m, train_iter.next())

                if (niter + 1) % config.disp_iter == 0 or niter == 0:
                    exp_logger.log_train_ce(niter, ce)

                if (niter + 1) % config.valid_iter == 0 or niter == 0:
                    if trainval_iter is not None:
                        trainval_iter.reset()
                        acc = evaluate(sess, mvalid, trainval_iter)
                        exp_logger.log_train_acc(niter, acc)
                    test_iter.reset()
                    acc = evaluate(sess, mvalid, test_iter)
                    exp_logger.log_valid_acc(niter, acc)

                if (niter + 1) % config.save_iter == 0 or niter == 0:
                    save(sess, saver, m.global_step, config, save_folder)
                    exp_logger.log_learn_rate(niter, m.lr.eval())

            test_iter.reset()
            acc = evaluate(sess, mvalid, test_iter)
    return acc
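
FixedLearnRateScheduler is referenced but not defined on this page. A plausible minimal sketch of a piecewise-constant schedule, assuming the model exposes an assign_lr(sess, lr) helper (a hypothetical name, not from the original source):

class FixedLearnRateScheduler(object):
    """Piecewise-constant schedule: at each step listed in lr_decay_steps,
    switch to the corresponding entry of lr_list (a sketch, not the original)."""

    def __init__(self, sess, model, base_lr, lr_decay_steps, lr_list=None):
        self.sess = sess
        self.model = model
        self.lr_decay_steps = list(lr_decay_steps)
        self.lr_list = list(lr_list) if lr_list is not None else []
        model.assign_lr(sess, base_lr)  # Hypothetical model helper.

    def step(self, niter):
        # Decay exactly once when the iteration hits a scheduled step.
        if niter in self.lr_decay_steps:
            idx = self.lr_decay_steps.index(niter)
            self.model.assign_lr(self.sess, self.lr_list[idx])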
Example #4
def train_model(exp_id,
                config,
                train_iter,
                test_iter,
                trainval_iter=None,
                save_folder=None,
                logs_folder=None):
  """Trains a CIFAR model.

  Args:
      exp_id: String. Experiment ID.
      config: Config object
      train_data: Dataset iterator.
      test_data: Dataset iterator.

  Returns:
      acc: Final test accuracy
  """
  log.info("Config: {}".format(config.__dict__))
  exp_logger = ExperimentLogger(logs_folder)

  # Initializes variables.
  with tf.Graph().as_default():
    np.random.seed(0)
    if not hasattr(config, "seed"):
      tf.set_random_seed(1234)
      log.info("Setting tensorflow random seed={:d}".format(1234))
    else:
      log.info("Setting tensorflow random seed={:d}".format(config.seed))
      tf.set_random_seed(config.seed)
    m, mvalid = get_models(config)

    with tf.Session() as sess:
      saver = tf.train.Saver()
      sess.run(tf.global_variables_initializer())

      # Set up learning rate schedule.
      if config.lr_scheduler_type == "fixed":
        lr_scheduler = FixedLearnRateScheduler(
            sess,
            m,
            config.base_learn_rate,
            config.lr_decay_steps,
            lr_list=config.lr_list)
      elif config.lr_scheduler_type == "exponential":
        lr_scheduler = ExponentialLearnRateScheduler(
            sess, m, config.base_learn_rate, config.lr_decay_offset,
            config.max_train_iter, config.final_learn_rate,
            config.lr_decay_interval)
      else:
        raise ValueError("Unknown learning rate scheduler {}".format(
            config.lr_scheduler_type))

      for niter in tqdm(range(config.max_train_iter), desc=exp_id):
        lr_scheduler.step(niter)
        ce = train_step(sess, m, train_iter.next())

        if (niter + 1) % config.disp_iter == 0 or niter == 0:
          exp_logger.log_train_ce(niter, ce)

        if (niter + 1) % config.valid_iter == 0 or niter == 0:
          if trainval_iter is not None:
            trainval_iter.reset()
            acc = evaluate(sess, mvalid, trainval_iter)
            exp_logger.log_train_acc(niter, acc)
          test_iter.reset()
          acc = evaluate(sess, mvalid, test_iter)
          exp_logger.log_valid_acc(niter, acc)

        if (niter + 1) % config.save_iter == 0 or niter == 0:
          save(sess, saver, m.global_step, config, save_folder)
          exp_logger.log_learn_rate(niter, m.lr.eval())

      test_iter.reset()
      acc = evaluate(sess, mvalid, test_iter)
  return acc
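
The train_step() helper is likewise assumed by these examples. A minimal sketch: run the training op once and return the batch cross-entropy (placeholder and op names are hypothetical):

def train_step(sess, model, batch):
    """Runs one optimizer step and returns the cross-entropy loss (sketch)."""
    feed_dict = {model.inputs: batch["img"], model.labels: batch["label"]}
    ce, _ = sess.run([model.cross_ent, model.train_op], feed_dict=feed_dict)
    return ce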
Example #5
def train_model(exp_id,
                config,
                train_iter,
                test_iter,
                trainval_iter=None,
                save_folder=None,
                logs_folder=None):
    log.info("congif:{}".format(config.__dict__))  #加入日志
    exp_logger = ExperimentLogger(logs_folder)
    # print(config.__dict__)
    # Initializes all variables.
    with tf.Graph().as_default():
        np.random.seed(0)
        # Sets the tensorflow random seed.
        if not hasattr(config, "seed"):  # Does the config carry a seed attribute?
            tf.set_random_seed(1234)
            log.info("setting tesorflow random seed={:d}".format(1234))
        else:
            log.info("setting tensorflow random seed={:d}".format(config.seed))
            tf.set_random_seed(config.seed)

        # Returns the training and validation models.
        m, mvalid = _get_models(config)

        with tf.Session() as sess:
            saver = tf.train.Saver()

            # Whether to restore a previously trained model.
            if FLAGS.restore:
                log.info("Restore checkpoint \"{}\"".format(save_folder))
                # Load the saved checkpoint.
                saver.restore(sess, tf.train.latest_checkpoint(save_folder))
            else:
                # First-time training: initialize all variables.
                sess.run(tf.global_variables_initializer())

            niter_start = int(m.global_step.eval())
            w_list = tf.trainable_variables()
            log.info("Model initialized")

            # Computes the total number of parameters.
            num_params = np.array([
                np.prod(np.array([int(ss) for ss in w.get_shape()]))
                for w in w_list
            ]).sum()
            log.info("Number of parameters {}".format(num_params))

            if config.lr_scheduler_type == "fixed":
                lr_scheduler = FixedLearnRateScheduler(sess,
                                                       m,
                                                       config.base_learn_rate,
                                                       config.lr_decay_steps,
                                                       lr_list=config.lr_list)
            else:
                # Raise on an unrecognized scheduler type.
                raise ValueError("Unknown learning rate scheduler {}".format(
                    config.lr_scheduler_type))

            for niter in tqdm(range(niter_start, config.max_train_iter),
                              desc=exp_id):
                lr_scheduler.step(niter)
                ce = train_step(sess, m, train_iter.next())

                if (niter + 1) % config.disp_iter == 0 or niter == 0:
                    exp_logger.log_train_ce(niter, ce)
                if (niter + 1) % config.valid_iter == 0 or niter == 0:
                    if trainval_iter is not None:
                        trainval_iter.reset()
                        acc = evaluate(sess, mvalid, trainval_iter)
                        exp_logger.log_train_acc(niter, acc)
                    test_iter.reset()
                    acc = evaluate(sess, mvalid, test_iter)
                    exp_logger.log_valid_acc(niter, acc)
                if (niter + 1) % config.save_iter == 0 or niter == 0:
                    save(sess, saver, m.global_step, config, save_folder)
                    exp_logger.log_learn_rate(niter, m.lr.eval())

            test_iter.reset()
            acc = evaluate(sess, mvalid, test_iter)
        return acc
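
For reference, a hypothetical call site for train_model; get_config and get_iter are illustrative factory helpers, not functions from the original source:

config = get_config("resnet-32")                  # Hypothetical config factory.
train_iter = get_iter("cifar-10", split="train")  # Hypothetical data factory.
test_iter = get_iter("cifar-10", split="test")
acc = train_model(
    "exp001",
    config,
    train_iter,
    test_iter,
    save_folder="results/exp001",
    logs_folder="logs/exp001")
print("Final test accuracy: {:.4f}".format(acc))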