def _build_model_use_tfdata(test_set_tfrecords_dir, ckpt_dir):
  '''
  test_set_tfrecords_dir: '/xxx/xxx/*.tfrecords'
  '''
  g = tf.Graph()
  with g.as_default():
    # region TFRecord+DataSet
    with tf.device('/cpu:0'):
      with tf.name_scope('input'):
        x_batch, y_batch, Xtheta_batch, Ytheta_batch, lengths_batch, iter_test = get_batch_use_tfdata(
            test_set_tfrecords_dir, get_theta=True)
    with tf.name_scope('model'):
      test_model = PARAM.SE_MODEL(x_batch,
                                  lengths_batch,
                                  y_batch,
                                  Xtheta_batch,
                                  Ytheta_batch,
                                  behavior=PARAM.SE_MODEL.infer)
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(init)
    ckpt = tf.train.get_checkpoint_state(
        os.path.join(PARAM.SAVE_DIR, ckpt_dir))
    test_model.saver.restore(sess, ckpt.model_checkpoint_path)
    g.finalize()
    return sess, test_model, iter_test

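# A minimal decoding sketch for the tf.data pipeline above (an illustration,
# not part of the original file). It assumes the model exposes its enhanced
# spectrum through a `cleaned` property backed by the internal `_cleaned`
# tensor, and that PARAM.CHECK_POINT names the checkpoint directory.
def _example_decode_test_set(test_set_tfrecords_dir):
  sess, test_model, iter_test = _build_model_use_tfdata(test_set_tfrecords_dir,
                                                        PARAM.CHECK_POINT)
  sess.run(iter_test.initializer)
  enhanced_batches = []
  while True:
    try:
      # Each sess.run pulls the next batch from the dataset and enhances it.
      enhanced_batches.append(sess.run(test_model.cleaned))
    except tf.errors.OutOfRangeError:
      break
  sess.close()
  return enhanced_batches
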
def build_session(ckpt_dir, batch_size, finalizeG=True):
  batch_size = None  # None is OK
  g = tf.Graph()
  with g.as_default():
    with tf.device('/cpu:0'):
      with tf.name_scope('input'):
        x_batch = tf.placeholder(tf.float32,
                                 shape=[batch_size, None, PARAM.FFT_DOT],
                                 name='x_batch')
        lengths_batch = tf.placeholder(tf.int32, shape=[batch_size],
                                       name='lengths_batch')
        y_batch = tf.placeholder(tf.float32,
                                 shape=[batch_size, None, PARAM.FFT_DOT],
                                 name='y_batch')
        x_theta = tf.placeholder(tf.float32,
                                 shape=[batch_size, None, PARAM.FFT_DOT],
                                 name='x_theta')
        y_theta = tf.placeholder(tf.float32,
                                 shape=[batch_size, None, PARAM.FFT_DOT],
                                 name='y_theta')
    with tf.name_scope('model'):
      model = PARAM.SE_MODEL(x_batch,
                             lengths_batch,
                             y_batch,
                             x_theta,
                             y_theta,
                             behavior=PARAM.SE_MODEL.infer)
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    sess.run(init)
    ckpt = tf.train.get_checkpoint_state(
        os.path.join(PARAM.SAVE_DIR, ckpt_dir))
    if ckpt and ckpt.model_checkpoint_path:
      tf.logging.info("Restore from " + ckpt.model_checkpoint_path)
      model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      tf.logging.fatal("checkpoint not found.")
      sys.exit(-1)
    if finalizeG:
      g.finalize()
      return sess, model
    return sess, model, g

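# A hedged placeholder-feeding sketch for build_session (illustration only,
# not from the original file). It assumes `mag_spec` and `phase_spec` are
# numpy arrays of shape [1, T, PARAM.FFT_DOT] prepared by the caller, that the
# model exposes its enhanced output through a `cleaned` property, and that
# 'nnet_C001' is a made-up checkpoint directory name.
def _example_enhance_one_utterance(mag_spec, phase_spec):
  sess, model = build_session('nnet_C001', batch_size=1)
  lengths = np.array([mag_spec.shape[1]], dtype=np.int32)
  enhanced_mag = sess.run(
      model.cleaned,
      feed_dict={
          'input/x_batch:0': mag_spec,       # noisy magnitude spectrogram
          'input/lengths_batch:0': lengths,  # frame count per utterance
          'input/x_theta:0': phase_spec,     # noisy phase (unused at pure-magnitude decoding)
      })
  sess.close()
  return enhanced_mag
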
def train():
  g = tf.Graph()
  with g.as_default():
    # region TFRecord+DataSet
    with tf.device('/cpu:0'):
      with tf.name_scope('input'):
        train_tfrecords, val_tfrecords, _, _ = generate_tfrecord(
            gen=PARAM.GENERATE_TFRECORD)
        if PARAM.GENERATE_TFRECORD:
          print("TFRecords preparation over.")
          # exit(0)  # set gen=True and exit to generate tfrecords
        PSIRM = True if PARAM.PIPLINE_GET_THETA else False
        x_batch_tr, y_batch_tr, Xtheta_batch_tr, Ytheta_batch_tr, lengths_batch_tr, iter_train = get_batch_use_tfdata(
            train_tfrecords, get_theta=PSIRM)
        x_batch_val, y_batch_val, Xtheta_batch_val, Ytheta_batch_val, lengths_batch_val, iter_val = get_batch_use_tfdata(
            val_tfrecords, get_theta=PSIRM)
    # endregion

    # build model
    with tf.name_scope('model'):
      tr_model = PARAM.SE_MODEL(x_batch_tr,
                                lengths_batch_tr,
                                y_batch_tr,
                                Xtheta_batch_tr,
                                Ytheta_batch_tr,
                                PARAM.SE_MODEL.train)
      tf.get_variable_scope().reuse_variables()
      val_model = PARAM.SE_MODEL(x_batch_val,
                                 lengths_batch_val,
                                 y_batch_val,
                                 Xtheta_batch_val,
                                 Ytheta_batch_val,
                                 PARAM.SE_MODEL.validation)

    utils.tf_tool.show_all_variables()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = PARAM.GPU_RAM_ALLOW_GROWTH
    config.allow_soft_placement = False
    sess = tf.Session(config=config)
    sess.run(init)

    # region resume training
    if PARAM.resume_training.lower() == 'true':
      ckpt = tf.train.get_checkpoint_state(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT))
      if ckpt and ckpt.model_checkpoint_path:
        # tf.logging.info("restore from" + ckpt.model_checkpoint_path)
        tr_model.saver.restore(sess, ckpt.model_checkpoint_path)
        best_path = ckpt.model_checkpoint_path
      else:
        tf.logging.fatal("checkpoint not found")
      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines('Training resumed.\n')
    else:
      if os.path.exists(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log')):
        os.remove(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'))
    # endregion

    # region validation before training.
    valstart_time = time.time()
    sess.run(iter_val.initializer)
    mag_loss_prev, logmag_loss_prev = eval_one_epoch(sess, val_model)
    cross_val_msg = "CROSSVAL PRERUN LOGBIASNET_LOSS_MASKNET_LOSS (%.4f,%.4f), costime %ds" % (
        mag_loss_prev, logmag_loss_prev, time.time()-valstart_time)
    tf.logging.info(cross_val_msg)
    with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
      f.writelines(cross_val_msg+'\n')

    tr_model.assign_lr(sess, PARAM.learning_rate_logbiasnet, PARAM.learning_rate_masknet)
    g.finalize()
    # endregion

    # epochs training
    reject_num = 0
    for epoch in range(PARAM.start_epoch, PARAM.max_epochs):
      sess.run([iter_train.initializer, iter_val.initializer])
      start_time = time.time()

      # train one epoch
      tr_mag_loss, tr_logmag_loss, logbiasnet_lr, masknet_lr, log_bias = train_one_epoch(sess, tr_model)

      # Validation
      val_mag_loss, val_logmag_loss = eval_one_epoch(sess, val_model)

      end_time = time.time()

      # Determine checkpoint path
      ckpt_name = "nnet_iter%d_lrate(%e,%e)_trloss(%.4f,%.4f)_cvloss(%.4f,%.4f)_avglogbias%f_duration%ds" % (
          epoch + 1, logbiasnet_lr, masknet_lr, tr_mag_loss, tr_logmag_loss,
          val_mag_loss, val_logmag_loss, np.mean(log_bias), end_time - start_time)
      ckpt_dir = os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT)
      if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
      ckpt_path = os.path.join(ckpt_dir, ckpt_name)

      # Relative loss between previous and current val_loss
      rel_impr_magloss = np.abs(mag_loss_prev - val_mag_loss) / mag_loss_prev
      rel_impr_logmagloss = np.abs(logmag_loss_prev - val_logmag_loss) / logmag_loss_prev

      # Accept or reject new parameters
      msg = ""
      if (PARAM.TRAIN_TYPE == 'BOTH' and (val_mag_loss < mag_loss_prev) and (val_logmag_loss < logmag_loss_prev)) or (
              PARAM.TRAIN_TYPE == 'LOGBIASNET' and (val_mag_loss < mag_loss_prev)) or (
              PARAM.TRAIN_TYPE == 'MASKNET' and (val_logmag_loss < logmag_loss_prev)):
        reject_num = 0
        tr_model.saver.save(sess, ckpt_path)
        # Logging train loss along with validation loss
        mag_loss_prev = val_mag_loss
        logmag_loss_prev = val_logmag_loss
        best_path = ckpt_path
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS (%.4f,%.4f), lrate(%e,%e), Val.LOSS (%.4f,%.4f), avglog_bias %f,\n"
               "    %s, ckpt(%s) saved,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1, tr_mag_loss, tr_logmag_loss, logbiasnet_lr, masknet_lr,
            val_mag_loss, val_logmag_loss, np.mean(log_bias),
            "NNET Accepted", ckpt_name, end_time - start_time)
        tf.logging.info(msg)
      else:
        reject_num += 1
        tr_model.saver.restore(sess, best_path)
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS (%.4f,%.4f), lrate(%e,%e), Val.LOSS (%.4f,%.4f), avglog_bias %f,\n"
               "    %s, ckpt(%s) abandoned,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1, tr_mag_loss, tr_logmag_loss, logbiasnet_lr, masknet_lr,
            val_mag_loss, val_logmag_loss, np.mean(log_bias),
            "NNET Rejected", ckpt_name, end_time - start_time)
        tf.logging.info(msg)

      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines(msg+'\n')

      # Start halving when improvement is lower than start_halving_impr
      if PARAM.TRAIN_TYPE == 'BOTH':
        if (rel_impr_magloss < PARAM.start_halving_impr) or (rel_impr_logmagloss < PARAM.start_halving_impr) or (reject_num >= 2):
          reject_num = 0
          logbiasnet_lr *= PARAM.halving_factor
          masknet_lr *= PARAM.halving_factor
          tr_model.assign_lr(sess, logbiasnet_lr, masknet_lr)
      elif PARAM.TRAIN_TYPE == 'LOGBIASNET':
        if (rel_impr_magloss < PARAM.start_halving_impr) or (reject_num >= 2):
          reject_num = 0
          logbiasnet_lr *= PARAM.halving_factor
          tr_model.assign_lr(sess, logbiasnet_lr, masknet_lr)
      elif PARAM.TRAIN_TYPE == 'MASKNET':
        if (rel_impr_logmagloss < PARAM.start_halving_impr) or (reject_num >= 2):
          reject_num = 0
          masknet_lr *= PARAM.halving_factor
          tr_model.assign_lr(sess, logbiasnet_lr, masknet_lr)
      else:
        print('Train type error.')
        exit(-1)

      # Stopping criterion
      if rel_impr_magloss < PARAM.end_halving_impr and rel_impr_logmagloss < PARAM.end_halving_impr:
        if epoch < PARAM.min_epochs:
          tf.logging.info(
              "we were supposed to finish, but we continue as "
              "min_epochs : %s" % PARAM.min_epochs)
          continue
        else:
          tf.logging.info(
              "finished, too small rel. improvement (%g,%g)" % (rel_impr_magloss, rel_impr_logmagloss))
          break

    sess.close()
    tf.logging.info("Done training")

def __init__(self,
             x_mag_spec_batch,
             lengths_batch,
             y_mag_spec_batch=None,
             theta_x_batch=None,
             theta_y_batch=None,
             infer=False):
  # Trainable log-bias used by the log-magnitude normalization.
  self._log_bias = tf.get_variable('logbias', [1],
                                   trainable=PARAM.LOG_BIAS_TRAINABEL,
                                   initializer=tf.constant_initializer(PARAM.INIT_LOG_BIAS))
  self._real_logbias = self._log_bias + DEFAULT_LOG_BIAS
  self._inputs = x_mag_spec_batch
  self._x_mag_spec = self.inputs
  self._norm_x_mag_spec = norm_mag_spec(self._x_mag_spec)
  self._norm_x_logmag_spec = norm_logmag_spec(self._x_mag_spec, self._log_bias)

  if not infer:
    self._y_mag_spec = y_mag_spec_batch
    self._norm_y_mag_spec = norm_mag_spec(self._y_mag_spec)
    self._norm_y_logmag_spec = norm_logmag_spec(self._y_mag_spec, self._log_bias)

  self._lengths = lengths_batch
  self.batch_size = tf.shape(self._lengths)[0]
  self._model_type = PARAM.MODEL_TYPE

  if PARAM.INPUT_TYPE == 'mag':
    self.net_input = self._norm_x_mag_spec
  elif PARAM.INPUT_TYPE == 'logmag':
    self.net_input = self._norm_x_logmag_spec
  if not infer:
    if PARAM.LABEL_TYPE == 'mag':
      self._labels = self._norm_y_mag_spec
    elif PARAM.LABEL_TYPE == 'logmag':
      self._labels = self._norm_y_logmag_spec

  outputs = self.net_input

  def lstm_cell():
    return tf.contrib.rnn.LSTMCell(
        PARAM.RNN_SIZE, forget_bias=1.0, use_peepholes=True,
        num_proj=PARAM.LSTM_num_proj,
        initializer=tf.contrib.layers.xavier_initializer(),
        state_is_tuple=True,
        activation=PARAM.LSTM_ACTIVATION)
  lstm_attn_cell = lstm_cell
  if not infer and PARAM.KEEP_PROB < 1.0:
    def lstm_attn_cell():
      return tf.contrib.rnn.DropoutWrapper(lstm_cell(),
                                           output_keep_prob=PARAM.KEEP_PROB)

  def GRU_cell():
    return tf.contrib.rnn.GRUCell(
        PARAM.RNN_SIZE,
        # kernel_initializer=tf.contrib.layers.xavier_initializer(),
        activation=PARAM.LSTM_ACTIVATION)
  GRU_attn_cell = GRU_cell
  if not infer and PARAM.KEEP_PROB < 1.0:
    def GRU_attn_cell():
      return tf.contrib.rnn.DropoutWrapper(GRU_cell(),
                                           output_keep_prob=PARAM.KEEP_PROB)

  if PARAM.MODEL_TYPE.upper() == 'BLSTM':
    with tf.variable_scope('BLSTM'):
      lstm_fw_cell = tf.contrib.rnn.MultiRNNCell(
          [lstm_attn_cell() for _ in range(PARAM.RNN_LAYER)], state_is_tuple=True)
      lstm_bw_cell = tf.contrib.rnn.MultiRNNCell(
          [lstm_attn_cell() for _ in range(PARAM.RNN_LAYER)], state_is_tuple=True)
      lstm_fw_cell = lstm_fw_cell._cells
      lstm_bw_cell = lstm_bw_cell._cells
      result = rnn.stack_bidirectional_dynamic_rnn(
          cells_fw=lstm_fw_cell,
          cells_bw=lstm_bw_cell,
          inputs=outputs,
          dtype=tf.float32,
          sequence_length=self._lengths)
      outputs, fw_final_states, bw_final_states = result

  if PARAM.MODEL_TYPE.upper() == 'BGRU':
    with tf.variable_scope('BGRU'):
      gru_fw_cell = tf.contrib.rnn.MultiRNNCell(
          [GRU_attn_cell() for _ in range(PARAM.RNN_LAYER)], state_is_tuple=True)
      gru_bw_cell = tf.contrib.rnn.MultiRNNCell(
          [GRU_attn_cell() for _ in range(PARAM.RNN_LAYER)], state_is_tuple=True)
      gru_fw_cell = gru_fw_cell._cells
      gru_bw_cell = gru_bw_cell._cells
      result = rnn.stack_bidirectional_dynamic_rnn(
          cells_fw=gru_fw_cell,
          cells_bw=gru_bw_cell,
          inputs=outputs,
          dtype=tf.float32,
          sequence_length=self._lengths)
      outputs, fw_final_states, bw_final_states = result

  with tf.variable_scope('fullconnectOut'):
    if self._model_type.upper()[0] == 'B':  # bidirection
      outputs = tf.reshape(outputs, [-1, 2 * PARAM.LSTM_num_proj])
      in_size = 2 * PARAM.LSTM_num_proj
    out_size = PARAM.OUTPUT_SIZE
    weights = tf.get_variable('weights1', [in_size, out_size],
                              initializer=tf.random_normal_initializer(stddev=0.01))
    biases = tf.get_variable('biases1', [out_size],
                             initializer=tf.constant_initializer(0.0))
    mask = tf.nn.relu(tf.matmul(outputs, weights) + biases)
    self._mask = tf.reshape(mask, [self.batch_size, -1, PARAM.OUTPUT_SIZE])

  self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=30)

  if infer:
    if PARAM.DECODING_MASK_POSITION == 'mag':
      self._cleaned = rm_norm_mag_spec(self._mask * self._norm_x_mag_spec)
    elif PARAM.DECODING_MASK_POSITION == 'logmag':
      self._cleaned = rm_norm_logmag_spec(self._mask * self._norm_x_logmag_spec,
                                          self._log_bias)
    return

  if PARAM.TRAINING_MASK_POSITION == 'mag':
    self._cleaned = self._mask * self._norm_x_mag_spec
  elif PARAM.TRAINING_MASK_POSITION == 'logmag':
    self._cleaned = self._mask * self._norm_x_logmag_spec
  if PARAM.MASK_TYPE == 'PSM':
    # Phase-sensitive mask: weight the label by the cosine of the phase difference.
    self._labels *= tf.cos(theta_x_batch - theta_y_batch)
  elif PARAM.MASK_TYPE == 'IRM':
    pass
  else:
    tf.logging.error('Mask type error.')
    exit(-1)
  if PARAM.TRAINING_MASK_POSITION != PARAM.LABEL_TYPE:
    if PARAM.LABEL_TYPE == 'mag':
      self._cleaned = normedLogmag2normedMag(self._cleaned, self._log_bias)
    elif PARAM.LABEL_TYPE == 'logmag':
      self._cleaned = normedMag2normedLogmag(self._cleaned, self._log_bias)
  self._loss = PARAM.LOSS_FUNC(self._cleaned, self._labels)

  if tf.get_variable_scope().reuse:
    return

  self._lr = tf.Variable(0.0, trainable=False)
  tvars = tf.trainable_variables()
  grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars),
                                    PARAM.CLIP_NORM)
  optimizer = tf.train.AdamOptimizer(self.lr)
  # optimizer = tf.train.GradientDescentOptimizer(self.lr)
  self._train_op = optimizer.apply_gradients(zip(grads, tvars))

  self._new_lr = tf.placeholder(tf.float32, shape=[], name='new_learning_rate')
  self._lr_update = tf.assign(self._lr, self._new_lr)

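# The train() functions below call assign_lr(...) to push a new learning rate
# into the graph, and read the model's loss/lr through properties. Those
# members are not included in this excerpt; the sketch below is a hedged
# reconstruction consistent with the _new_lr placeholder and _lr_update assign
# op created at the end of __init__ (they would sit in the same class; the
# actual repo code, and the dual-rate LOGBIASNET/MASKNET variant, may differ).
def assign_lr(self, session, lr_value):
  # Feed the Python-side learning rate into self._lr through the assign op.
  session.run(self._lr_update, feed_dict={self._new_lr: lr_value})

@property
def lr(self):
  return self._lr

@property
def loss(self):
  return self._loss
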
def train():
  g = tf.Graph()
  with g.as_default():
    # region TFRecord+DataSet
    with tf.device('/cpu:0'):
      with tf.name_scope('input'):
        train_tfrecords, val_tfrecords, _, _ = generate_tfrecord(
            gen=PARAM.GENERATE_TFRECORD)
        if PARAM.GENERATE_TFRECORD:
          tf.logging.info("TFRecords preparation over.")
          # exit(0)  # set gen=True and exit to generate tfrecords
        PSIRM = True if PARAM.PIPLINE_GET_THETA else False
        x_batch_tr, y_batch_tr, Xtheta_batch_tr, Ytheta_batch_tr, lengths_batch_tr, iter_train = get_batch_use_tfdata(
            train_tfrecords, get_theta=PSIRM)
        x_batch_val, y_batch_val, Xtheta_batch_val, Ytheta_batch_val, lengths_batch_val, iter_val = get_batch_use_tfdata(
            val_tfrecords, get_theta=PSIRM)
    # endregion

    # build model
    with tf.name_scope('model'):
      tr_model = PARAM.SE_MODEL(x_batch_tr,
                                lengths_batch_tr,
                                y_batch_tr,
                                Xtheta_batch_tr,
                                Ytheta_batch_tr,
                                PARAM.SE_MODEL.train)
      tf.get_variable_scope().reuse_variables()
      val_model = PARAM.SE_MODEL(x_batch_val,
                                 lengths_batch_val,
                                 y_batch_val,
                                 Xtheta_batch_val,
                                 Ytheta_batch_val,
                                 PARAM.SE_MODEL.validation)

    utils.tf_tool.show_all_variables()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = PARAM.GPU_RAM_ALLOW_GROWTH
    config.allow_soft_placement = False
    # config.gpu_options.per_process_gpu_memory_fraction = 0.55
    sess = tf.Session(config=config)
    sess.run(init)

    # region resume training
    if PARAM.resume_training.lower() == 'true':
      ckpt = tf.train.get_checkpoint_state(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT))
      if ckpt and ckpt.model_checkpoint_path:
        # tf.logging.info("restore from" + ckpt.model_checkpoint_path)
        tr_model.saver.restore(sess, ckpt.model_checkpoint_path)
        best_path = ckpt.model_checkpoint_path
      else:
        tf.logging.fatal("checkpoint not found")
      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines('Training resumed.\n')
    else:
      if os.path.exists(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log')):
        os.remove(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'))
    # endregion

    # region validation before training.
    valstart_time = time.time()
    sess.run(iter_val.initializer)
    loss_prev = eval_one_epoch(sess, val_model)
    cross_val_msg = "\n\nCROSSVAL PRERUN AVG.LOSS %.4f costime %ds\n" % (
        loss_prev, time.time()-valstart_time)
    tf.logging.info(cross_val_msg)
    with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
      f.writelines(cross_val_msg+'\n')

    g.finalize()
    # endregion

    # epochs training
    # TODO resume lr from model_name
    lr = PARAM.learning_rate
    tr_model.assign_lr(sess, lr)
    # if PARAM.resume_training.lower() == 'true':
    #   lr = sess.run(tr_model.lr)
    # else:
    #   lr = PARAM.learning_rate
    #   tr_model.assign_lr(sess, lr)
    lr_halving_time = 0
    for epoch in range(PARAM.start_epoch, PARAM.max_epochs):
      sess.run([iter_train.initializer, iter_val.initializer])
      start_time = time.time()

      # train one epoch
      tr_loss, model_lr, log_bias = train_one_epoch(sess, tr_model)

      # Validation
      val_loss = eval_one_epoch(sess, val_model)

      end_time = time.time()

      # Determine checkpoint path
      ckpt_name = "nnet_iter%d_lrate%e_trloss%.4f_cvloss%.4f_avglogbias%f_duration%ds" % (
          epoch + 1, model_lr, tr_loss, val_loss, np.mean(log_bias), end_time - start_time)
      ckpt_dir = os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT)
      if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
      ckpt_path = os.path.join(ckpt_dir, ckpt_name)

      # Relative loss between previous and current val_loss
      rel_impr = (loss_prev - val_loss) / loss_prev

      # Accept or reject new parameters
      msg = ""
      if val_loss < loss_prev:
        tr_model.saver.save(sess, ckpt_path)
        # Logging train loss along with validation loss
        loss_prev = val_loss
        best_path = ckpt_path
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS %.4f, lrate %e, Val.LOSS %.4f, avglog_bias %f,\n"
               "    %s, ckpt(%s) saved,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1, tr_loss, model_lr, val_loss, np.mean(log_bias),
            "NNET Accepted", ckpt_name, end_time - start_time)
        tf.logging.info(msg)
      else:
        tr_model.saver.restore(sess, best_path)
        msg = ("Train Iteration %03d: \n"
               "    Train.LOSS %.4f, lrate %e, Val.LOSS %.4f, avglog_bias %f,\n"
               "    %s, ckpt(%s) abandoned,\n"
               "    EPOCH DURATION: %.2fs\n") % (
            epoch + 1, tr_loss, model_lr, val_loss, np.mean(log_bias),
            "NNET Rejected", ckpt_name, end_time - start_time)
        tf.logging.info(msg)

      with open(os.path.join(PARAM.SAVE_DIR, PARAM.CHECK_POINT+'_train.log'), 'a+') as f:
        f.writelines(msg+'\n')

      # Start halving when improvement is lower than start_halving_impr
      if rel_impr < PARAM.start_halving_impr:
        lr *= PARAM.halving_factor
        lr_halving_time += 1
        tr_model.assign_lr(sess, lr)

      # Stopping criterion
      if rel_impr < PARAM.end_halving_impr:
        if (epoch < PARAM.min_epochs) or (lr_halving_time <= PARAM.max_lr_halving_time):
          tf.logging.info(
              "we were supposed to finish, but we continue as "
              "now_epoch<=min_epochs(%s<=%s),"
              " or lr_halving_time<=max_lr_halving_time(%s<=%s)"
              "." % (epoch+1, PARAM.min_epochs, lr_halving_time, PARAM.max_lr_halving_time))
          continue
        else:
          tf.logging.info(
              "finished, too small relative improvement %g." % rel_impr)
          break

    sess.close()
    tf.logging.info("Done training")
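# Hedged entry-point sketch (not part of the excerpt): assuming this module is
# run directly as the training script, a typical invocation would be:
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  train()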