def dnn_train_mnist_cmb(self, mnist, train_configs):
    """Train the network on MNIST with combined MSE + cross-entropy losses.

    Each epoch runs two sequential updates per training batch — first the
    MSE (reconstruction) optimizer, then the cross-entropy (classification)
    optimizer — and then evaluates both losses and accuracy on the
    validation split.

    Parameters
    ----------
    mnist : object
        MNIST dataset with ``train`` / ``validation`` splits exposing
        ``images`` and ``next_batch(batch_size=...)`` (tf MNIST reader API).
    train_configs : object
        Hyper-parameter container read here: ``epochs``, ``batchsize``,
        ``lr_init``, ``decay_rate``, ``keep_prob``.

    Side effects
    ------------
    Builds loss/accuracy/optimizer graph nodes via the ``get_*`` helpers,
    creates ``self.sess`` (InteractiveSession), and stores the per-epoch
    training curves in ``self.train_dict``.
    """
    self.get_learning_rate()
    self.get_loss()
    self.get_accuracy()
    self.get_opt_ce()
    self.get_opt_mse()
    # Init session and all graph variables
    self.sess = tf.InteractiveSession()
    self.sess.run(tf.global_variables_initializer())
    numbatch_trn = mnist.train.images.shape[0] // train_configs.batchsize
    numbatch_val = mnist.validation.images.shape[0] // train_configs.batchsize
    # Per-epoch curves (MSE loss, CE loss, accuracy; train and validation)
    x_epoch = np.arange(0, train_configs.epochs, 1)
    y_loss_trn_m = np.zeros(x_epoch.shape)
    y_loss_val_m = np.zeros(x_epoch.shape)
    y_loss_trn = np.zeros(x_epoch.shape)
    y_loss_val = np.zeros(x_epoch.shape)
    y_acc_trn = np.zeros(x_epoch.shape)
    y_acc_val = np.zeros(x_epoch.shape)
    timestamp = utils.get_timestamp()
    print("[%s]: Epochs Trn_loss_mse Val_loss_mse Trn_loss_ce Val_loss_ce Trn_acc Val_acc" % (timestamp))
    for i in range(train_configs.epochs):
        # NOTE(review): decay_step is fed with the batch size — looks like
        # it should be a dedicated decay-step hyper-parameter; confirm.
        lr_dict = {self.init_lr: train_configs.lr_init,
                   self.global_step: i,
                   self.decay_step: train_configs.batchsize,
                   self.decay_rate: train_configs.decay_rate}
        loss_trn_all_m = 0.0
        loss_val_all_m = 0.0
        loss_trn_all = 0.0
        loss_val_all = 0.0
        acc_trn_all = 0.0
        acc_val_all = 0.0
        # training
        for i_trn in range(numbatch_trn):
            data_trn, label_trn = mnist.train.next_batch(
                batch_size=train_configs.batchsize)
            train_dict = {
                self.inputs: data_trn,
                self.labels: label_trn,
                self.outputs: data_trn,
                self.is_training: True,
                self.keep_prob: train_configs.keep_prob}
            train_dict.update(lr_dict)
            # Two sequential updates on the same batch: the CE step runs on
            # weights already modified by the MSE step, so the two
            # sess.run calls must stay separate.
            _, loss_trn_m = self.sess.run(
                [self.train_op_mse, self.cost_mse], feed_dict=train_dict)
            _, loss_trn, acc_trn = self.sess.run(
                [self.train_op_ce, self.cost_ce, self.accuracy],
                feed_dict=train_dict)
            loss_trn_all_m += loss_trn_m
            acc_trn_all += acc_trn
            loss_trn_all += loss_trn
        y_loss_trn_m[i] = loss_trn_all_m / numbatch_trn
        y_acc_trn[i] = acc_trn_all / numbatch_trn
        y_loss_trn[i] = loss_trn_all / numbatch_trn
        # validation (renamed loop var i_trn -> i_val for clarity,
        # consistent with dnn_train_cmb)
        for i_val in range(numbatch_val):
            data_val, label_val = mnist.validation.next_batch(
                batch_size=train_configs.batchsize)
            val_dict = {
                self.inputs: data_val,
                self.labels: label_val,
                self.outputs: data_val,
                self.is_training: False,
                self.keep_prob: 1.0}
            # No train ops fetched, so all three metrics can be evaluated
            # in a single run — avoids a redundant forward pass per batch.
            loss_val_m, loss_val, acc_val = self.sess.run(
                [self.cost_mse, self.cost_ce, self.accuracy],
                feed_dict=val_dict)
            loss_val_all_m += loss_val_m
            acc_val_all += acc_val
            loss_val_all += loss_val
        y_loss_val_m[i] = loss_val_all_m / numbatch_val
        y_acc_val[i] = acc_val_all / numbatch_val
        y_loss_val[i] = loss_val_all / numbatch_val
        # print results every 5 epochs
        if i % 5 == 0:
            timestamp = utils.get_timestamp()
            print('[%s]: %d %.8f %.8f %.8f %.8f %.4f %.4f' % (
                timestamp, i,
                y_loss_trn_m[i], y_loss_val_m[i],
                y_loss_trn[i], y_loss_val[i],
                y_acc_trn[i], y_acc_val[i]))
    self.train_dict = {
        "epochs": x_epoch,
        "trn_loss_mse": y_loss_trn_m,
        "val_loss_mse": y_loss_val_m,
        "trn_loss_ce": y_loss_trn,
        "val_loss_ce": y_loss_val,
        "trn_acc_ce": y_acc_trn,
        "val_acc_ce": y_acc_val}
def dnn_train_cmb(self, data, train_configs, labels=None):
    """Train the network on an in-memory dataset with combined losses.

    Splits ``data`` into train/validation via ``utils.gen_validation``,
    then each epoch applies the MSE (reconstruction) update followed by
    the cross-entropy (classification) update per batch, and evaluates
    both losses and accuracy on the validation split.

    Parameters
    ----------
    data : array-like
        Samples; fed as both ``inputs`` and reconstruction ``outputs``.
    train_configs : object
        Hyper-parameter container read here: ``epochs``, ``batchsize``,
        ``valrate``, ``lr_init``, ``decay_rate``, ``keep_prob``.
    labels : array-like, optional
        Class labels passed through to the validation split helper.

    Side effects
    ------------
    Builds loss/accuracy/optimizer graph nodes via the ``get_*`` helpers,
    creates ``self.sess`` (InteractiveSession), and stores the per-epoch
    training curves in ``self.train_dict``.
    """
    self.get_learning_rate()
    self.get_loss()
    self.get_accuracy()
    self.get_opt_ce()
    self.get_opt_mse()
    # get validation split (dicts with 'data' and 'label' keys)
    data_trn, data_val = utils.gen_validation(
        data, valrate=train_configs.valrate, label=labels)
    # Init session and all graph variables
    self.sess = tf.InteractiveSession()
    self.sess.run(tf.global_variables_initializer())
    numbatch_trn = len(data_trn["data"]) // train_configs.batchsize
    numbatch_val = len(data_val["data"]) // train_configs.batchsize
    # Per-epoch curves (MSE loss, CE loss, accuracy; train and validation)
    x_epoch = np.arange(0, train_configs.epochs, 1)
    y_loss_trn_m = np.zeros(x_epoch.shape)
    y_loss_val_m = np.zeros(x_epoch.shape)
    y_loss_trn = np.zeros(x_epoch.shape)
    y_loss_val = np.zeros(x_epoch.shape)
    y_acc_trn = np.zeros(x_epoch.shape)
    y_acc_val = np.zeros(x_epoch.shape)
    timestamp = utils.get_timestamp()
    print("[%s]: Epochs Trn_loss_mse Val_loss_mse Trn_loss_ce Val_loss_ce Trn_acc Val_acc" % (timestamp))
    for i in range(train_configs.epochs):
        # NOTE(review): decay_step is fed with the batch size — looks like
        # it should be a dedicated decay-step hyper-parameter; confirm.
        lr_dict = {self.init_lr: train_configs.lr_init,
                   self.global_step: i,
                   self.decay_step: train_configs.batchsize,
                   self.decay_rate: train_configs.decay_rate}
        loss_trn_all_m = 0.0
        loss_val_all_m = 0.0
        loss_trn_all = 0.0
        loss_val_all = 0.0
        acc_trn_all = 0.0
        acc_val_all = 0.0
        # training: fresh shuffled index order each epoch
        indices_trn = utils.gen_BatchIterator_label(
            data_trn['data'], data_trn['label'],
            batch_size=train_configs.batchsize,
            shuffle=True)
        for i_trn in range(numbatch_trn):
            idx_trn = indices_trn[i_trn*train_configs.batchsize:
                                  (i_trn+1)*train_configs.batchsize]
            train_dict = {
                self.inputs: data_trn['data'][idx_trn],
                self.labels: data_trn['label'][idx_trn],
                self.outputs: data_trn['data'][idx_trn],
                self.is_training: True,
                self.keep_prob: train_configs.keep_prob}
            train_dict.update(lr_dict)
            # Two sequential updates on the same batch: the CE step runs on
            # weights already modified by the MSE step, so the two
            # sess.run calls must stay separate.
            _, loss_trn_m = self.sess.run(
                [self.train_op_mse, self.cost_mse], feed_dict=train_dict)
            _, loss_trn, acc_trn = self.sess.run(
                [self.train_op_ce, self.cost_ce, self.accuracy],
                feed_dict=train_dict)
            loss_trn_all_m += loss_trn_m
            acc_trn_all += acc_trn
            loss_trn_all += loss_trn
        y_loss_trn_m[i] = loss_trn_all_m / numbatch_trn
        y_acc_trn[i] = acc_trn_all / numbatch_trn
        y_loss_trn[i] = loss_trn_all / numbatch_trn
        # validation (shuffle order is irrelevant to the averaged metrics)
        indices_val = utils.gen_BatchIterator_label(
            data_val['data'], data_val['label'],
            batch_size=train_configs.batchsize,
            shuffle=True)
        for i_val in range(numbatch_val):
            idx_val = indices_val[i_val*train_configs.batchsize:
                                  (i_val+1)*train_configs.batchsize]
            val_dict = {
                self.inputs: data_val['data'][idx_val],
                self.labels: data_val['label'][idx_val],
                self.outputs: data_val['data'][idx_val],
                self.is_training: False,
                self.keep_prob: 1.0}
            # No train ops fetched, so all three metrics can be evaluated
            # in a single run — avoids a redundant forward pass per batch.
            loss_val_m, loss_val, acc_val = self.sess.run(
                [self.cost_mse, self.cost_ce, self.accuracy],
                feed_dict=val_dict)
            loss_val_all_m += loss_val_m
            acc_val_all += acc_val
            loss_val_all += loss_val
        y_loss_val_m[i] = loss_val_all_m / numbatch_val
        y_acc_val[i] = acc_val_all / numbatch_val
        y_loss_val[i] = loss_val_all / numbatch_val
        # print results every 5 epochs
        if i % 5 == 0:
            timestamp = utils.get_timestamp()
            print('[%s]: %d %.8f %.8f %.8f %.8f %.4f %.4f' % (
                timestamp, i,
                y_loss_trn_m[i], y_loss_val_m[i],
                y_loss_trn[i], y_loss_val[i],
                y_acc_trn[i], y_acc_val[i]))
    self.train_dict = {
        "epochs": x_epoch,
        "trn_loss_mse": y_loss_trn_m,
        "val_loss_mse": y_loss_val_m,
        "trn_loss_ce": y_loss_trn,
        "val_loss_ce": y_loss_val,
        "trn_acc_ce": y_acc_trn,
        "val_acc_ce": y_acc_val}