def get_opt(self, init_lr, num_train_steps, **kargs):
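        """Build the learning-rate schedule and the (optionally distributed) optimizer.

        Applies lr decay and/or warmup when enabled in the config, then wraps the
        base optimizer according to config["opt_type"]: Horovod ("hvd"), PAI-Soar
        ("pai_soar"), synchronous PS ("ps_sync"), asynchronous PS ("ps"), or a
        plain single-node optimizer otherwise. Stores the result in self.opt and,
        where applicable, session hooks in self.distributed_hooks.
        """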

        learning_rate = init_lr
        if self.config.get("decay", "no") == "decay":
            print("==apply lr decay==")
            learning_rate = self.lr_decay_fn(learning_rate, num_train_steps,
                                             **kargs)
        if self.config.get("warmup", "no") == "warmup":
            print("==apply warmup==")
            learning_rate = self.warm_up(learning_rate, init_lr, **kargs)
        self.learning_rate = learning_rate  #* (self.config.get('gpu_count', 1) / 2)
        # self.learning_rate = learning_rate / np.sqrt(self.config.get('gpu_count', 1) / 2)
        # self.learning_rate = learning_rate * np.sqrt(self.config.get('gpu_count', 1)) * 2
        self.single_node_learning = learning_rate

        # add Uber Horovod distributed optimizer
        if hvd and self.config["opt_type"] == "hvd":
            print("==optimizer hvd size=={}".format(
                self.config.get("worker_count", hvd.size())))
            opt = self.optimizer_op(
                self.learning_rate *
                self.config.get("worker_count", hvd.size()), **kargs)
            self.opt = hvd.DistributedOptimizer(opt)
            self.distributed_hooks = [hvd.BroadcastGlobalVariablesHook(0)]
        # add PAI-Soar distributed optimizer
        elif pai and self.config["opt_type"] == "pai_soar":
            print("==optimizer pai_soar size=={}".format(
                self.config.get("worker_count", 4)))
            opt = self.optimizer_op(
                self.learning_rate * self.config.get("worker_count", 4),
                **kargs)
            self.opt = pai.ReplicatedVarsOptimizer(opt,
                                                   clip_norm=self.config.get(
                                                       "clip_norm", 1.0))
            self.distributed_hooks = []
        # add TensorFlow PS sync distributed optimizer
        elif self.config["opt_type"] == "ps_sync":
            print("==optimizer ps_sync size=={}".format(
                self.config.get("worker_count", 4)))
            opt = self.optimizer_op(
                self.learning_rate * self.config.get("worker_count", 4),
                **kargs)
            self.opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=self.config.get("worker_count", 4),
                total_num_replicas=self.config.get("worker_count", 4))
            self.distributed_hooks = [
                self.opt.make_session_run_hook(self.config["is_chief"],
                                               num_tokens=0)
            ]
        elif self.config["opt_type"] == "ps":
            print("==optimizer ps_async size=={}".format(
                self.config.get("worker_count", 4)))
            self.opt = self.optimizer_op(
                self.learning_rate * self.config.get("worker_count", 4),
                **kargs)
        else:
            print("==initialization of single node optimizer==")
            self.opt = self.optimizer_op(self.learning_rate, **kargs)
            self.distributed_hooks = []
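
# Usage sketch for get_opt (illustrative, not from the source): assumes a
# hypothetical OptimizerManager class that owns this method and a config dict
# carrying the keys read above; `loss` is assumed to be built elsewhere.
config = {
    "opt_type": "ps_sync",   # "hvd", "pai_soar", "ps_sync", "ps", or anything else for single node
    "decay": "decay",        # enable the lr-decay branch
    "warmup": "warmup",      # enable the warmup branch
    "worker_count": 4,
    "is_chief": True,
}
manager = OptimizerManager(config)                     # hypothetical owner of get_opt
manager.get_opt(init_lr=1e-4, num_train_steps=100000)
train_op = manager.opt.minimize(loss)                  # loss defined by the model graph
hooks = getattr(manager, "distributed_hooks", [])      # pass to MonitoredTrainingSession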
Example #2
def run_model(target, num_workers, global_step):
    ##########################
    #  Config learning_rate  #
    ##########################
    learning_rate = optimizer_utils.configure_learning_rate(
        FLAGS.num_sample_per_epoch, global_step)

    ##########################################################
    #  Config optimizer and Wrapper optimizer with PAI-Soar  #
    ##########################################################
    samples_per_step = FLAGS.batch_size
    optimizer = optimizer_utils.configure_optimizer(learning_rate)
    if FLAGS.enable_paisoar:
        import paisoar
        optimizer = paisoar.ReplicatedVarsOptimizer(
            optimizer, clip_norm=FLAGS.max_gradient_norm)
        ctx = paisoar.Config.get()
        samples_per_step *= len(ctx.device_indices) * num_workers

    #######################
    #  Config model func  #
    #######################
    model_fn = model_factory.get_model_fn(FLAGS.model_name,
                                          num_classes=FLAGS.num_classes,
                                          weight_decay=FLAGS.weight_decay,
                                          is_training=True)

    #############################
    #  Config dataset iterator  #
    #############################
    with tf.device('/cpu:0'):
        train_image_size = model_fn.default_image_size

        # split dataset by worker
        data_sources = get_tfrecord_files(
            _DATASET_TRAIN_FILES[FLAGS.dataset_name] or FLAGS.train_files,
            num_workers)

        # select the preprocessing func
        preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
        preprocessing_fn = (preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True) if preprocessing_name else None)

        dataset_iterator = dataset_factory.get_dataset_iterator(
            FLAGS.dataset_name, train_image_size, preprocessing_fn,
            data_sources, FLAGS.reader)
    ###############################################
    #  Config loss_func and Wrapper with PAI-Soar #
    ###############################################
    accuracy = []

    def loss_fn():
        with tf.device('/cpu:0'):
            images, labels = dataset_iterator.get_next()
        logits, end_points = model_fn(images)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                      logits=tf.cast(
                                                          logits, tf.float32),
                                                      weights=1.0)
        if 'AuxLogits' in end_points:
            loss += tf.losses.sparse_softmax_cross_entropy(
                labels=labels,
                logits=tf.cast(end_points['AuxLogits'], tf.float32),
                weights=0.4)
        per_accuracy = tf.reduce_mean(
            tf.cast(tf.equal(tf.argmax(logits, axis=1), labels), tf.float32))
        accuracy.append(per_accuracy)
        return loss

    # wrap loss_fn with PAI-Soar 2.0 when enabled
    if FLAGS.enable_paisoar:
        loss = optimizer.compute_loss(loss_fn, loss_scale=FLAGS.loss_scale)
    else:
        loss = loss_fn()

    ########################
    #  Config train tensor #
    ########################
    train_op = optimizer.minimize(loss, global_step=global_step)

    ###############################################
    #  Log trainable or optimizer variables info, #
    #  including name and size.                   #
    ###############################################
    log_trainable_or_optimizer_vars_info()

    ################
    # Restore ckpt #
    ################
    if FLAGS.model_dir and FLAGS.task_type == 'finetune':
        utils.load_checkpoint()

    #########################
    # Config training hooks #
    #########################
    params = dict()
    if FLAGS.log_loss_every_n_iters > 0:
        tensors_to_log = {
            'loss': loss if isinstance(loss, tf.Tensor) else loss.replicas[0],
            'accuracy': tf.reduce_mean(accuracy),
            'lrate': learning_rate
        }
        params['tensors_to_log'] = tensors_to_log
        params['samples_per_step'] = samples_per_step
    hooks = get_hooks(params=params)

    ###########################
    # Kicks off the training. #
    ###########################
    logging.info('training starts.')

    with tf.train.MonitoredTrainingSession(target,
                                           is_chief=(FLAGS.task_index == 0),
                                           hooks=hooks) as sess:
        try:
            while not sess.should_stop():
                sess.run(train_op)
        except tf.errors.OutOfRangeError:
            logging.info('All threads done.')
        except Exception as e:
            import sys
            import traceback
            logging.error(str(e))
            traceback.print_exc(file=sys.stdout)
    logging.info('training ends.')
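
# Minimal driver sketch (illustrative, not from the source): run_model expects a
# session target, a worker count, and a global_step tensor. FLAGS and the flag
# definitions come from the original script; an empty target runs a local,
# non-distributed MonitoredTrainingSession.
def main(_):
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        run_model(target='', num_workers=1, global_step=global_step)


if __name__ == '__main__':
    tf.app.run(main)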