コード例 #1
0
    def setup_lrn_rate(self, global_step):
        """Setup the learning rate (and number of training iterations).

        The schedule is selected by ``FLAGS.mobilenet_version``:

        * version 1: exponential decay whose ``epoch_step`` and ``decay_rate``
          arguments are themselves piecewise-constant schedules sharing the
          boundaries ``[12000, 20000]``; the iteration budget is 30000.
        * version 2: exponential decay with ``epoch_step=500`` and
          ``decay_rate=0.9``; the iteration budget is 15000.

        Args:
          global_step: global training step, forwarded to the schedule helpers.

        Returns:
          A ``(lrn_rate, nb_iters)`` tuple. ``nb_iters`` is a fixed constant per
          version; it is not derived from an epoch count.

        Raises:
          ValueError: if ``FLAGS.mobilenet_version`` is neither 1 nor 2.
        """

        # Effective batch size: per-worker batch times the number of workers
        # when multi-GPU training is enabled.
        batch_size = FLAGS.batch_size * (1 if not FLAGS.enbl_multi_gpu else
                                         mgw.size())
        if FLAGS.mobilenet_version == 1:
            # NOTE(review): these boundaries look like iteration counts (the
            # budget below is 30000) even though they are named `idxs_epoch`;
            # confirm the units setup_lrn_rate_piecewise_constant expects.
            idxs_epoch = [12000, 20000]
            step_rate = [200, 200, 4000]
            epoch_step = setup_lrn_rate_piecewise_constant(
                global_step, batch_size, idxs_epoch, step_rate)
            decay_rates = [0.985, 0.980, 0.505]
            decay_rate = setup_lrn_rate_piecewise_constant(
                global_step, batch_size, idxs_epoch, decay_rates)
            lrn_rate = setup_lrn_rate_exponential_decay(
                global_step, batch_size, epoch_step, decay_rate)
            nb_iters = 30000
        elif FLAGS.mobilenet_version == 2:
            epoch_step = 500
            decay_rate = 0.9
            lrn_rate = setup_lrn_rate_exponential_decay(
                global_step, batch_size, epoch_step, decay_rate)
            nb_iters = 15000
        else:
            raise ValueError('invalid MobileNet version: {}'.format(
                FLAGS.mobilenet_version))

        return lrn_rate, nb_iters
コード例 #2
0
    def setup_lrn_rate(self, global_step):
        """Build the learning-rate schedule and the total training-iteration count.

        ``FLAGS.mobilenet_version`` picks the schedule:

        * 1 -- piecewise-constant, boundaries ``[30, 60, 80, 90]`` and values
          ``[1.0, 0.1, 0.01, 0.001, 0.0001]``, over a 100-epoch budget;
        * 2 -- exponential decay with ``epoch_step=2.5`` and
          ``decay_rate=0.98 ** 2.5``, over a 412-epoch budget.

        In both cases the iteration count is
        ``nb_smpls_train * nb_epochs * nb_epochs_rat / batch_size`` truncated to int.

        Args:
          global_step: global training step, forwarded to the schedule helpers.

        Returns:
          A ``(lrn_rate, nb_iters)`` tuple.

        Raises:
          ValueError: if ``FLAGS.mobilenet_version`` is neither 1 nor 2.
        """

        nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
        batch_size = FLAGS.batch_size * nb_workers

        version = FLAGS.mobilenet_version
        if version == 1:
            nb_epochs = 100
            lrn_rate = setup_lrn_rate_piecewise_constant(
                global_step, batch_size, [30, 60, 80, 90],
                [1.0, 0.1, 0.01, 0.001, 0.0001])
        elif version == 2:
            nb_epochs = 412
            epoch_step = 2.5
            lrn_rate = setup_lrn_rate_exponential_decay(
                global_step, batch_size, epoch_step, 0.98 ** epoch_step)
        else:
            raise ValueError('invalid MobileNet version: {}'.format(version))

        # Shared by both versions; only the epoch budget differs.
        nb_iters = int(FLAGS.nb_smpls_train * nb_epochs *
                       FLAGS.nb_epochs_rat / batch_size)
        return lrn_rate, nb_iters
コード例 #3
0
  def setup_lrn_rate(self, global_step):
    """Return the learning-rate schedule and the number of training iterations.

    Uses a piecewise-constant schedule with boundaries ``[100, 150, 200]`` and
    values ``[1.0, 0.1, 0.01, 0.001]`` over a 250-epoch budget. The batch size is
    ``FLAGS.batch_size`` as-is (this variant applies no multi-GPU scaling).

    Args:
      global_step: global training step, forwarded to the schedule helper.

    Returns:
      A ``(lrn_rate, nb_iters)`` tuple.
    """

    batch_size = FLAGS.batch_size
    lrn_rate = setup_lrn_rate_piecewise_constant(
        global_step, batch_size, [100, 150, 200], [1.0, 0.1, 0.01, 0.001])

    nb_epochs = 250
    total = FLAGS.nb_smpls_train * nb_epochs * FLAGS.nb_epochs_rat / batch_size
    return lrn_rate, int(total)
コード例 #4
0
  def setup_lrn_rate(self, global_step):
    """Return ``(lrn_rate, nb_iters)`` for a 100-epoch piecewise-constant schedule.

    The schedule has boundaries ``[30, 60, 80, 90]`` and values
    ``[1.0, 0.1, 0.01, 0.001, 0.0001]``. When ``FLAGS.enbl_multi_gpu`` is set, the
    effective batch size is ``FLAGS.batch_size * mgw.size()``.

    Args:
      global_step: global training step, forwarded to the schedule helper.

    Returns:
      The learning-rate schedule and the total number of training iterations.
    """

    nb_workers = mgw.size() if FLAGS.enbl_multi_gpu else 1
    batch_size = FLAGS.batch_size * nb_workers
    lrn_rate = setup_lrn_rate_piecewise_constant(
        global_step, batch_size, [30, 60, 80, 90], [1.0, 0.1, 0.01, 0.001, 0.0001])

    nb_epochs = 100
    nb_iters = int(FLAGS.nb_smpls_train * nb_epochs * FLAGS.nb_epochs_rat / batch_size)
    return lrn_rate, nb_iters
コード例 #5
0
    def setup_lrn_rate(self, global_step):
        """Setup the learning rate (and number of training iterations).

        Uses a piecewise-constant schedule with boundaries ``[0.4, 0.8]`` and
        values ``[0.001, 0.0005, 0.0001]``. The iteration budget is a fixed 12000;
        it is not derived from an epoch count.

        Args:
          global_step: global training step, forwarded to the schedule helper.

        Returns:
          A ``(lrn_rate, nb_iters)`` tuple.
        """

        # Effective batch size: per-worker batch times the number of workers
        # when multi-GPU training is enabled.
        batch_size = FLAGS.batch_size * (1 if not FLAGS.enbl_multi_gpu else
                                         mgw.size())
        idxs_epoch = [0.4, 0.8]
        decay_rates = [0.001, 0.0005, 0.0001]
        lrn_rate = setup_lrn_rate_piecewise_constant(global_step, batch_size,
                                                     idxs_epoch, decay_rates)
        nb_iters = 12000
        return lrn_rate, nb_iters