Example #1
def main():
    """Create the model and start the training."""
    args = get_arguments()

    """
    Get configurations here. We pass some arguments from command line to init configurations, for training hyperparameters,
    you can set them in TrainConfig Class.

    Note: we set filter scale to 1 for pruned model, 2 for non-pruned model. The filters numbers of non-pruned
          model is two times larger than prunde model, e.g., [h, w, 64] <-> [h, w, 32].
    """
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)
    cfg.display()

    # Setup training network and training samples
    train_reader = ImageReader(cfg=cfg, mode='train')
    train_net = ICNet_BN(image_reader=train_reader,
                         cfg=cfg, mode='train')

    loss_sub4, loss_sub24, loss_sub124, reduced_loss = create_losses(train_net, train_net.labels, cfg)

    # Setup validation network and validation samples
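    # reuse=True shares every variable with the training network, so the
    # validation graph evaluates the same weights without creating new ones.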
    with tf.variable_scope('', reuse=True):
        val_reader = ImageReader(cfg, mode='eval')
        val_net = ICNet_BN(image_reader=val_reader,
                           cfg=cfg, mode='train')

        val_loss_sub4, val_loss_sub24, val_loss_sub124, val_reduced_loss = create_losses(val_net, val_net.labels, cfg)

    # Using Poly learning rate policy
    base_lr = tf.constant(cfg.LEARNING_RATE)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(base_lr, tf.pow((1 - step_ph / cfg.TRAINING_STEPS), cfg.POWER))
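    # lr(step) = LEARNING_RATE * (1 - step / TRAINING_STEPS) ** POWER: the rate
    # starts at LEARNING_RATE and decays smoothly to zero at the final step.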

    # Set restore variable
    restore_var = tf.global_variables()
    all_trainable = [v for v in tf.trainable_variables() if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma]
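    # 'beta' / 'gamma' are the batch-norm offset and scale variables; they are
    # excluded from the trainable set unless args.train_beta_gamma is set.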

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
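    # UPDATE_OPS collects the assign ops that refresh the batch-norm moving_mean
    # and moving_variance; the control_dependencies block below forces them to
    # run once per training step.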

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, cfg.MOMENTUM)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Create the session & restore weights (only train_net is needed to create
    # the session, since the validation network reuses its variables)
    train_net.create_session()
    # train_net.restore(cfg.model_weight, restore_var)
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)
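    # max_to_keep=5 keeps only the five most recent checkpoints on disk.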

    # Iterate over training steps.
    for step in range(cfg.TRAINING_STEPS):
        start_time = time.time()

        feed_dict = {step_ph: step}
        # Run one training step; note the validation loss is fetched every step here.
        loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run(
            [reduced_loss, loss_sub4, loss_sub24, loss_sub124, val_reduced_loss, train_op],
            feed_dict=feed_dict)
        if step % cfg.SAVE_PRED_EVERY == 0:
            train_net.save(saver, cfg.SNAPSHOT_DIR, step)

        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f}, val_loss: {:.3f} ({:.3f} sec/step)'
              .format(step, loss_value, loss1, loss2, loss3, val_loss_value, duration))
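The 'poly' policy above decays the learning rate from LEARNING_RATE at step 0 down to zero at TRAINING_STEPS. A minimal standalone sketch of the same schedule in plain Python, with illustrative constants rather than the values from TrainConfig:

def poly_lr(step, base_lr=0.01, power=0.9, training_steps=60000):
    # Same formula as the tf.scalar_mul / tf.pow graph op above.
    return base_lr * (1.0 - step / training_steps) ** power

print(poly_lr(0))      # 0.01 (full base rate at the first step)
print(poly_lr(30000))  # ~0.00536 (0.5 ** 0.9 of the base rate, halfway through)
print(poly_lr(59999))  # ~5e-07 (effectively zero as training ends)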
Example #2
def main():
    """Create the model and start the training."""
    args = get_arguments()
    """
    Get configurations here. We pass some arguments from command line to init configurations, for training hyperparameters, 
    you can set them in TrainConfig Class.

    Note: we set filter scale to 1 for pruned model, 2 for non-pruned model. The filters numbers of non-pruned
          model is two times larger than prunde model, e.g., [h, w, 64] <-> [h, w, 32].
    """
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)
    if args.num_classes is not None:
        cfg.param["num_classes"] = args.num_classes
    if args.data_dir is not None:
        cfg.param["data_dir"] = args.data_dir
    if args.val_list is not None:
        cfg.param["eval_list"] = args.val_list
    if args.train_list is not None:
        cfg.param["train_list"] = args.train_list
    if args.ignore_label is not None:
        cfg.param["ignore_label"] = args.ignore_label
    if args.eval_size is not None:
        cfg.param["eval_size"] = [
            int(x.strip()) for x in args.eval_size.split("x")[::-1]
        ]
    if args.training_size is not None:
        cfg.TRAINING_SIZE = [
            int(x.strip()) for x in args.training_size.split("x")[::-1]
        ]
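    # e.g., --training-size "2048x1024" yields [1024, 2048]: splitting on "x"
    # and reversing puts height first, matching the [h, w] convention above.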
    if args.batch_size is not None:
        cfg.BATCH_SIZE = args.batch_size
    if args.learning_rate is not None:
        cfg.LEARNING_RATE = args.learning_rate
    if args.restore_from is not None:
        cfg.model_weight = args.restore_from
    if args.snapshot_dir is not None:
        cfg.SNAPSHOT_DIR = args.snapshot_dir
    if args.restore_from == "scratch":
        from tqdm import tqdm
        import cv2
        import joblib
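        # tqdm, cv2 and joblib are imported lazily here: they are needed only
        # when computing the dataset image mean for training from scratch.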
        if not args.img_mean:
            print(
                "Calculating img mean for custom dataset. To prevent this, specify it with --img-mean next time"
            )
            image_files, annotation_files = read_labeled_image_list(
                cfg.param["data_dir"], cfg.param["train_list"])
            means = joblib.Parallel(n_jobs=6)(
                joblib.delayed(calc_mean)(image_file, cv2)
                for image_file in tqdm(image_files, desc="calc img mean"))
            cfg.IMG_MEAN = np.mean(means, axis=0).tolist()
        else:
            cfg.IMG_MEAN = [float(x.strip()) for x in args.img_mean.split(",")]

    cfg.display()

    # Setup training network and training samples
    train_reader = ImageReader(cfg=cfg, mode='train')
    train_net = ICNet_BN(image_reader=train_reader, cfg=cfg, mode='train')

    loss_sub4, loss_sub24, loss_sub124, reduced_loss = create_losses(
        train_net, train_net.labels, cfg)

    # Setup validation network and validation samples
    with tf.variable_scope('', reuse=True):
        val_reader = ImageReader(cfg, mode='eval')
        val_net = ICNet_BN(image_reader=val_reader, cfg=cfg, mode='train')

        val_loss_sub4, val_loss_sub24, val_loss_sub124, val_reduced_loss = create_losses(
            val_net, val_net.labels, cfg)

    # Using Poly learning rate policy
    base_lr = tf.constant(cfg.LEARNING_RATE)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / cfg.TRAINING_STEPS), cfg.POWER))

    # Set restore variable
    restore_var = tf.global_variables()
    all_trainable = [
        v for v in tf.trainable_variables()
        if ('beta' not in v.name and 'gamma' not in v.name)
        or args.train_beta_gamma
    ]

    # Gets moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, cfg.MOMENTUM)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Create the session & restore weights (only train_net is needed to create
    # the session, since the validation network reuses its variables)
    train_net.create_session()
    if args.initializer:
        train_net.set_initializer(initializer_algorithm=args.initializer)
    train_net.initialize_variables()
    if args.restore_from != "scratch":
        train_net.restore(cfg.model_weight, restore_var)
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=20)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
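        # In TF1 each entry is a tf.Dimension; dim.value extracts the plain int.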
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Total trainable parameters: " + str(total_parameters))

    # Iterate over training steps.
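    # Placeholder so the progress print below has a val_loss value even before
    # the first evaluation step assigns a real one.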
    val_loss_value = 10.0
    min_val_loss = float("inf")
    stagnation = 0
    max_non_decreasing_val_loss = int(
        np.ceil(args.early_stopping_patience * len(train_reader.image_list) /
                (cfg.BATCH_SIZE * cfg.EVAL_EVERY)))
    print(
        "Maximum times that val loss can stagnate before early stopping is applied: "
        + str(max_non_decreasing_val_loss))
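    # Hypothetical illustration: with early_stopping_patience=5, 2975 training
    # images, BATCH_SIZE=16 and EVAL_EVERY=50, this allows
    # ceil(5 * 2975 / (16 * 50)) = 19 stagnant evaluations before stopping.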
    for step in range(cfg.TRAINING_STEPS):
        start_time = time.time()

        feed_dict = {step_ph: step}
        if step % cfg.EVAL_EVERY == 0:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run(
                [
                    reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                    val_reduced_loss, train_op
                ],
                feed_dict=feed_dict)
            if val_loss_value < min_val_loss:
                print("New best val loss {:.3f}. Saving weights...".format(
                    val_loss_value))
                train_net.save(
                    saver,
                    cfg.SNAPSHOT_DIR,
                    step,
                    model_name="val{:.3f}model.ckpt".format(val_loss_value))
                min_val_loss = val_loss_value
                stagnation = 0
            else:
                stagnation += 1
        else:
            loss_value, loss1, loss2, loss3, _ = train_net.sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op],
                feed_dict=feed_dict)

        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, sub124 = {:.3f}, val_loss: {:.3f} ({:.3f} sec/step)'
              .format(step, loss_value, loss1, loss2, loss3, val_loss_value, duration))

        if stagnation > max_non_decreasing_val_loss:
            print("Early stopping")
            break
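calc_mean is defined elsewhere in the repository; only its call site is visible above. A minimal sketch of what it might look like, assuming images load as BGR uint8 arrays via cv2 (the body is an assumption, not the repository's implementation):

def calc_mean(image_file, cv2):
    # Hypothetical sketch: flatten the spatial dimensions of the [h, w, 3]
    # image and average each channel separately.
    img = cv2.imread(image_file)
    return img.reshape(-1, 3).mean(axis=0)

The np.mean(means, axis=0) call in the scratch branch then averages these per-image channel means into the 3-element cfg.IMG_MEAN.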