# Standard-library / third-party imports used below; project modules
# (get_arguments, TrainConfig, ImageReader, ICNet_BN, create_losses,
# read_labeled_image_list, ...) are assumed to be imported elsewhere in the file.
import time

import numpy as np
import tensorflow as tf


# NOTE: this original training entry point is redefined by the extended main()
# further below (command-line overrides, image-mean computation, early stopping).
def main():
    """Create the model and start the training."""
    args = get_arguments()

    # Get configurations here. We pass some arguments from the command line to
    # init the configuration; training hyperparameters can be set in the
    # TrainConfig class.
    #
    # Note: filter_scale is 1 for the pruned model and 2 for the non-pruned
    # model. The non-pruned model has twice as many filters as the pruned
    # model, e.g. [h, w, 64] <-> [h, w, 32].
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)
    cfg.display()

    # Set up the training network and training samples.
    train_reader = ImageReader(cfg=cfg, mode='train')
    train_net = ICNet_BN(image_reader=train_reader, cfg=cfg, mode='train')

    loss_sub4, loss_sub24, loss_sub124, reduced_loss = create_losses(
        train_net, train_net.labels, cfg)

    # Set up the validation network and validation samples.
    with tf.variable_scope('', reuse=True):
        val_reader = ImageReader(cfg, mode='eval')
        val_net = ICNet_BN(image_reader=val_reader, cfg=cfg, mode='train')

        val_loss_sub4, val_loss_sub24, val_loss_sub124, val_reduced_loss = create_losses(
            val_net, val_net.labels, cfg)

    # Poly learning-rate policy.
    base_lr = tf.constant(cfg.LEARNING_RATE)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / cfg.TRAINING_STEPS), cfg.POWER))

    # Variables to restore.
    restore_var = tf.global_variables()
    all_trainable = [
        v for v in tf.trainable_variables()
        if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma
    ]

    # Get moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS.
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, cfg.MOMENTUM)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Create the session & restore weights (we only need train_net to create
    # the session since variables are reused).
    train_net.create_session()
    # train_net.restore(cfg.model_weight, restore_var)
    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=5)

    # Iterate over training steps.
    for step in range(cfg.TRAINING_STEPS):
        start_time = time.time()
        feed_dict = {step_ph: step}

        if step % cfg.SAVE_PRED_EVERY == 0:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                 val_reduced_loss, train_op],
                feed_dict=feed_dict)
            train_net.save(saver, cfg.SNAPSHOT_DIR, step)
        else:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                 val_reduced_loss, train_op],
                feed_dict=feed_dict)

        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, '
              'sub124 = {:.3f}, val_loss: {:.3f} ({:.3f} sec/step)'.format(
                  step, loss_value, loss1, loss2, loss3, val_loss_value, duration))
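# The extended main() below parallelizes the dataset-mean computation with a
# helper calc_mean(image_file, cv2), which is not shown in this section. The
# version here is only a sketch of what such a helper plausibly does (the
# per-channel BGR mean of one image); it is an assumption, not necessarily the
# repository's actual implementation.
def calc_mean(image_file, cv2):
    """Return the per-channel mean of a single image (assumed BGR order)."""
    img = cv2.imread(image_file).astype(np.float64)
    return img.reshape(-1, img.shape[-1]).mean(axis=0)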
def main():
    """Create the model and start the training."""
    args = get_arguments()

    # Get configurations here. We pass some arguments from the command line to
    # init the configuration; training hyperparameters can be set in the
    # TrainConfig class.
    #
    # Note: filter_scale is 1 for the pruned model and 2 for the non-pruned
    # model. The non-pruned model has twice as many filters as the pruned
    # model, e.g. [h, w, 64] <-> [h, w, 32].
    cfg = TrainConfig(dataset=args.dataset,
                      is_training=True,
                      random_scale=args.random_scale,
                      random_mirror=args.random_mirror,
                      filter_scale=args.filter_scale)

    # Command-line overrides for the default configuration.
    if args.num_classes is not None:
        cfg.param["num_classes"] = args.num_classes
    if args.data_dir is not None:
        cfg.param["data_dir"] = args.data_dir
    if args.val_list is not None:
        cfg.param["eval_list"] = args.val_list
    if args.train_list is not None:
        cfg.param["train_list"] = args.train_list
    if args.ignore_label is not None:
        cfg.param["ignore_label"] = args.ignore_label
    if args.eval_size is not None:
        # "WxH" on the command line -> [H, W] internally.
        cfg.param["eval_size"] = [
            int(x.strip()) for x in args.eval_size.split("x")[::-1]
        ]
    if args.training_size is not None:
        cfg.TRAINING_SIZE = [
            int(x.strip()) for x in args.training_size.split("x")[::-1]
        ]
    if args.batch_size is not None:
        cfg.BATCH_SIZE = args.batch_size
    if args.learning_rate is not None:
        cfg.LEARNING_RATE = args.learning_rate
    if args.restore_from is not None:
        cfg.model_weight = args.restore_from
    if args.snapshot_dir is not None:
        cfg.SNAPSHOT_DIR = args.snapshot_dir

    # When training from scratch, the image mean must come from the dataset
    # itself (computed in parallel) unless it is given on the command line.
    if args.restore_from == "scratch":
        from tqdm import tqdm
        import cv2
        import joblib

        if not args.img_mean:
            print("Calculating img mean for custom dataset. "
                  "To prevent this, specify it with --img-mean next time")
            image_files, annotation_files = read_labeled_image_list(
                cfg.param["data_dir"], cfg.param["train_list"])
            means = joblib.Parallel(n_jobs=6)(
                joblib.delayed(calc_mean)(image_file, cv2)
                for image_file in tqdm(image_files, desc="calc img mean"))
            cfg.IMG_MEAN = np.mean(means, axis=0).tolist()
        else:
            cfg.IMG_MEAN = [float(x.strip()) for x in args.img_mean.split(",")]

    cfg.display()

    # Set up the training network and training samples.
    train_reader = ImageReader(cfg=cfg, mode='train')
    train_net = ICNet_BN(image_reader=train_reader, cfg=cfg, mode='train')

    loss_sub4, loss_sub24, loss_sub124, reduced_loss = create_losses(
        train_net, train_net.labels, cfg)

    # Set up the validation network and validation samples.
    with tf.variable_scope('', reuse=True):
        val_reader = ImageReader(cfg, mode='eval')
        val_net = ICNet_BN(image_reader=val_reader, cfg=cfg, mode='train')

        val_loss_sub4, val_loss_sub24, val_loss_sub124, val_reduced_loss = create_losses(
            val_net, val_net.labels, cfg)

    # Poly learning-rate policy.
    base_lr = tf.constant(cfg.LEARNING_RATE)
    step_ph = tf.placeholder(dtype=tf.float32, shape=())
    learning_rate = tf.scalar_mul(
        base_lr, tf.pow((1 - step_ph / cfg.TRAINING_STEPS), cfg.POWER))

    # Variables to restore.
    restore_var = tf.global_variables()
    all_trainable = [
        v for v in tf.trainable_variables()
        if ('beta' not in v.name and 'gamma' not in v.name) or args.train_beta_gamma
    ]

    # Get moving_mean and moving_variance update operations from tf.GraphKeys.UPDATE_OPS.
    if not args.update_mean_var:
        update_ops = None
    else:
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.control_dependencies(update_ops):
        opt_conv = tf.train.MomentumOptimizer(learning_rate, cfg.MOMENTUM)
        grads = tf.gradients(reduced_loss, all_trainable)
        train_op = opt_conv.apply_gradients(zip(grads, all_trainable))

    # Create the session & restore weights (we only need train_net to create
    # the session since variables are reused).
    train_net.create_session()
    if args.initializer:
        train_net.set_initializer(initializer_algorithm=args.initializer)
    train_net.initialize_variables()
    if not args.restore_from or args.restore_from != "scratch":
        train_net.restore(cfg.model_weight, restore_var)

    saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=20)

    # Count trainable parameters.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Total trainable parameters: " + str(total_parameters))

    # Iterate over training steps.
    val_loss_value = 10.0  # last evaluated validation loss (placeholder before the first evaluation)
    min_val_loss = float("inf")
    stagnation = 0
    # Early stopping budget: roughly `early_stopping_patience` epochs' worth of
    # evaluations during which the validation loss may fail to improve.
    max_non_decreasing_val_loss = int(
        np.ceil(args.early_stopping_patience * len(train_reader.image_list) /
                (cfg.BATCH_SIZE * cfg.EVAL_EVERY)))
    print("Maximum times that val loss can stagnate before early stopping is applied: "
          + str(max_non_decreasing_val_loss))

    for step in range(cfg.TRAINING_STEPS):
        start_time = time.time()
        feed_dict = {step_ph: step}

        if step % cfg.EVAL_EVERY == 0:
            loss_value, loss1, loss2, loss3, val_loss_value, _ = train_net.sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124,
                 val_reduced_loss, train_op],
                feed_dict=feed_dict)

            if val_loss_value < min_val_loss:
                print("New best val loss {:.3f}. Saving weights...".format(val_loss_value))
                train_net.save(
                    saver,
                    cfg.SNAPSHOT_DIR,
                    step,
                    model_name="val{:.3f}model.ckpt".format(val_loss_value))
                min_val_loss = val_loss_value
                stagnation = 0
            else:
                stagnation += 1
        else:
            loss_value, loss1, loss2, loss3, _ = train_net.sess.run(
                [reduced_loss, loss_sub4, loss_sub24, loss_sub124, train_op],
                feed_dict=feed_dict)

        duration = time.time() - start_time
        print('step {:d} \t total loss = {:.3f}, sub4 = {:.3f}, sub24 = {:.3f}, '
              'sub124 = {:.3f}, val_loss: {:.3f} ({:.3f} sec/step)'.format(
                  step, loss_value, loss1, loss2, loss3, val_loss_value, duration))

        if stagnation > max_non_decreasing_val_loss:
            print("Early stopping")
            break
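# Standard script entry point, assumed to close the original training script.
if __name__ == '__main__':
    main()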