def train_model(self, sess, max_iters, restore=False): img_b, lb_b, lb_len_b, t_s_b = self.get_data( self.imgdb.path, batch_size=cfg.TRAIN.BATCH_SIZE, num_epochs=cfg.TRAIN.NUM_EPOCHS) val_img_b, val_lb_b, val_lb_len_b, val_t_s_b = self.get_data( self.imgdb.val_path, batch_size=cfg.VAL.BATCH_SIZE, num_epochs=cfg.VAL.NUM_EPOCHS) print('get_data****************') loss, dense_decoded = self.net.build_loss() tf.summary.scalar('loss', loss) summary_op = tf.summary.merge_all() # optimizer if cfg.TRAIN.SOLVER == 'Adam': opt = tf.train.AdamOptimizer(cfg.TRAIN.LEARNING_RATE) lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) elif cfg.TRAIN.SOLVER == 'RMS': opt = tf.train.RMSPropOptimizer(cfg.TRAIN.LEARNING_RATE) lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) elif cfg.TRAIN.SOLVER == 'SGD': opt = tf.train.GradientDescentOptimizer(cfg.TRAIN.LEARNING_RATE) lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) else: lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) momentum = cfg.TRAIN.MOMENTUM opt = tf.train.MomentumOptimizer(lr, momentum) global_step = tf.Variable(0, trainable=False) with_clip = True if with_clip: tvars = tf.trainable_variables() grads, norm = tf.clip_by_global_norm(tf.gradients(loss, tvars), 10.0) train_op = opt.apply_gradients(list(zip(grads, tvars)), global_step=global_step) else: train_op = opt.minimize(loss, global_step=global_step) # intialize variables local_vars_init_op = tf.local_variables_initializer() global_vars_init_op = tf.global_variables_initializer() combined_op = tf.group(local_vars_init_op, global_vars_init_op) sess.run(combined_op) restore_iter = 1 # resuming a trainer if restore: try: ckpt = tf.train.get_checkpoint_state(self.output_dir) print('Restoring from {}...'.format( ckpt.model_checkpoint_path), end=' ') self.saver.restore(sess, tf.train.latest_checkpoint(self.output_dir)) stem = os.path.splitext( os.path.basename(ckpt.model_checkpoint_path))[0] restore_iter = int(stem.split('_')[-1]) sess.run(global_step.assign(restore_iter)) print('done') except: raise Exception('Check your pretrained {:s}'.format( ckpt.model_checkpoint_path)) timer = Timer() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) loss_min = 0.02 try: while not coord.should_stop(): for iter in range(restore_iter, max_iters): timer.tic() # learning rate if iter != 0 and iter % cfg.TRAIN.STEPSIZE == 0: sess.run(tf.assign(lr, lr.eval() * cfg.TRAIN.GAMMA)) # get one batch img_Batch,labels_Batch, label_len_Batch,time_step_Batch = \ sess.run([img_b,lb_b,lb_len_b,t_s_b]) label_Batch = self.mergeLabel(labels_Batch, ignore=0) # Subtract the mean pixel value from each pixel feed_dict = { self.net.data: img_Batch, self.net.labels: label_Batch, self.net.time_step_len: np.array(time_step_Batch), self.net.labels_len: np.array(label_len_Batch), self.net.keep_prob: 0.5 } fetch_list = [loss, summary_op, train_op] ctc_loss, summary_str, _ = sess.run(fetches=fetch_list, feed_dict=feed_dict) self.writer.add_summary(summary=summary_str, global_step=global_step.eval()) _diff_time = timer.toc(average=False) if (iter) % (cfg.TRAIN.DISPLAY) == 0: print('iter: %d / %d, total loss: %.7f, lr: %.7f'%\ (iter, max_iters, ctc_loss ,lr.eval()),end=' ') print('speed: {:.3f}s / iter'.format(_diff_time)) if ( iter + 1 ) % cfg.TRAIN.SNAPSHOT_ITERS == 0 or ctc_loss < loss_min: if (ctc_loss < loss_min): print('loss: ', ctc_loss, end=' ') self.snapshot(sess, 1) loss_min = ctc_loss else: self.snapshot(sess, iter) if (iter + 1) % cfg.VAL.VAL_STEP == 0 or loss_min == ctc_loss: val_img_Batch,val_labels_Batch, val_label_len_Batch,val_time_step_Batch = \ sess.run([val_img_b,val_lb_b,val_lb_len_b,val_t_s_b]) val_label_Batch = self.mergeLabel(val_labels_Batch, ignore=0) feed_dict = { self.net.data: val_img_Batch, self.net.labels: val_label_Batch, self.net.time_step_len: np.array(val_time_step_Batch), self.net.labels_len: np.array(val_label_len_Batch), self.net.keep_prob: 1.0 } # fetch_list = [dense_decoded] org = val_labels_Batch res = sess.run(fetches=dense_decoded, feed_dict=feed_dict) acc = accuracy_calculation(org, res, ignore_value=0) print('accuracy: {:.5f}'.format(acc)) iter = max_iters - 1 self.snapshot(sess, iter) coord.request_stop() except tf.errors.OutOfRangeError: print('finish') finally: coord.request_stop() coord.join(threads)
def train_model(self, sess, max_iters, restore=False): train_gen = get_batch(num_workers=12, batch_size=cfg.TRAIN.BATCH_SIZE, vis=False) val_gen = get_batch(num_workers=1, batch_size=cfg.VAL.BATCH_SIZE, vis=False) loss, dense_decoded = self.net.build_loss() tf.summary.scalar('loss', loss) summary_op = tf.summary.merge_all() # optimizer lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) if cfg.TRAIN.SOLVER == 'Adam': opt = tf.train.AdamOptimizer(lr) elif cfg.TRAIN.SOLVER == 'RMS': opt = tf.train.RMSPropOptimizer(lr) else: opt = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM) global_step = tf.Variable(0, trainable=False) with_clip = True if with_clip: tvars = tf.trainable_variables() grads, norm = tf.clip_by_global_norm(tf.gradients(loss, tvars), 10.0) train_op = opt.apply_gradients(list(zip(grads, tvars)), global_step=global_step) else: train_op = opt.minimize(loss, global_step=global_step) # intialize variables local_vars_init_op = tf.local_variables_initializer() global_vars_init_op = tf.global_variables_initializer() combined_op = tf.group(local_vars_init_op, global_vars_init_op) sess.run(combined_op) restore_iter = 1 # resuming a trainer if restore: try: ckpt = tf.train.get_checkpoint_state(self.output_dir) print('Restoring from {}...'.format( ckpt.model_checkpoint_path), end=' ') self.saver.restore(sess, tf.train.latest_checkpoint(self.output_dir)) stem = os.path.splitext( os.path.basename(ckpt.model_checkpoint_path))[0] restore_iter = int(stem.split('_')[-1]) sess.run(global_step.assign(restore_iter)) print('done') except: raise Exception('Check your pretrained {:s}'.format( ckpt.model_checkpoint_path)) timer = Timer() loss_min = 0.015 first_val = True for iter in range(restore_iter, max_iters): timer.tic() # learning rate if iter != 0 and iter % cfg.TRAIN.STEPSIZE == 0: sess.run(tf.assign(lr, lr.eval() * cfg.TRAIN.GAMMA)) # get one batch img_Batch, label_Batch, label_len_Batch, time_step_Batch = next( train_gen) img_Batch = np.array(img_Batch) # Subtract the mean pixel value from each pixel feed_dict = { self.net.data: np.array(img_Batch), self.net.labels: np.array(label_Batch), self.net.time_step_len: np.array(time_step_Batch), self.net.labels_len: np.array(label_len_Batch), self.net.keep_prob: 0.5 } fetch_list = [loss, summary_op, train_op] ctc_loss, summary_str, _ = sess.run(fetches=fetch_list, feed_dict=feed_dict) self.writer.add_summary(summary=summary_str, global_step=global_step.eval()) _diff_time = timer.toc(average=False) if (iter) % (cfg.TRAIN.DISPLAY) == 0: print('iter: %d / %d, total loss: %.7f, lr: %.7f'%\ (iter, max_iters, ctc_loss ,lr.eval()),end=' ') print('speed: {:.3f}s / iter'.format(_diff_time)) if (iter + 1) % cfg.TRAIN.SNAPSHOT_ITERS == 0 or ctc_loss < loss_min: if (ctc_loss < loss_min): print('loss: ', ctc_loss, end=' ') self.snapshot(sess, 1) loss_min = ctc_loss else: self.snapshot(sess, iter) if (iter + 1) % cfg.VAL.VAL_STEP == 0 or loss_min == ctc_loss: if first_val: val_img_Batch, val_label_Batch, val_label_len_Batch, val_time_step_Batch = next( val_gen) org = self.restoreLabel(val_label_Batch, val_label_len_Batch) first_val = False feed_dict = { self.net.data: np.array(val_img_Batch), self.net.labels: np.array(val_label_Batch), self.net.time_step_len: np.array(val_time_step_Batch), self.net.labels_len: np.array(val_label_len_Batch), self.net.keep_prob: 1.0 } # fetch_list = [dense_decoded] res = sess.run(fetches=dense_decoded, feed_dict=feed_dict) acc = accuracy_calculation(org, res, ignore_value=0) print('accuracy: {:.5f}'.format(acc))