def train(self, image, epochs, enable_es=1):
    graph = tf.Graph()
    with tf.Session(graph=graph) as session:
        tf.set_random_seed(1234)

        self.__create_inputs()
        new_saver = self.__create_graph(self.meta_file)
        self.__create_loss_optimizer()
        # slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)

        early_stopping = EarlyStopping(patience=30, min_delta=1e-1)

        tf.global_variables_initializer().run()
        new_saver.restore(session, self.latest_checkpoint)

        recons_loss = list()
        print('Starting optimization...')
        for cur_epoch in range(epochs + 1):
            dict_loss = self.__train_epoch(session, image)
            list_loss = list(dict_loss.values())

            if np.isnan(list_loss[0]):
                print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                sys.exit()

            if cur_epoch % 20 == 0 or cur_epoch == 0:
                print('EPOCH: {} | dist: {}'.format(cur_epoch, list_loss[0]))
            recons_loss.append(list_loss[0])

            # Early stopping
            if cur_epoch > 50 and enable_es == 1 and early_stopping.stop(list_loss[0]):
                print('Early Stopping!')
                print('EPOCH: {} | dist: {}'.format(cur_epoch, list_loss[0]))
                break

        z_infer = session.run(self.z)
        x_recons = session.run(self.x_recons)
        return z_infer, x_recons, recons_loss
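# The snippet above relies on an EarlyStopping helper that is not defined in
# this file. Below is a minimal sketch consistent with the call sites
# EarlyStopping(patience=30, min_delta=1e-1) and .stop(loss); the class name
# matches the usage, but the implementation itself is an assumption, not the
# original.
class EarlyStopping:
    """Stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=30, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.num_bad_epochs = 0

    def stop(self, loss):
        # An improvement is a decrease of more than min_delta below the best.
        if self.best is None or self.best - loss > self.min_delta:
            self.best = loss
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience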
def fit(self, X, y=None):
    print('\nProcessing data...')
    self.data_train = data_utils.process_data(X, y, test_size=0)
    if self.config.plot:
        self.data_plot = self.data_train

    self.config.num_batches = self.data_train.num_batches(self.config.batch_size)

    if not self.config.isBuilt:
        self.config.restore = True
        self.build_model(self.data_train.height, self.data_train.width,
                         self.data_train.num_channels)
    else:
        assert (self.config.height == self.data_train.height) and \
               (self.config.width == self.data_train.width) and \
               (self.config.num_channels == self.data_train.num_channels), \
            'Wrong dimension of data. Expected shape {}, and got {}'.format(
                (self.config.height, self.config.width, self.config.num_channels),
                (self.data_train.height, self.data_train.width, self.data_train.num_channels))

    '''
    -------------------------------------------------------------------------------
                                    TRAIN THE MODEL
    -------------------------------------------------------------------------------
    '''
    print('\nTraining a model...')
    with tf.Session(graph=self.graph) as session:
        tf.set_random_seed(self.config.seeds)
        self.session = session
        logger = Logger(self.session, self.config.log_dir)
        saver = tf.train.Saver()
        early_stopper = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

        if self.config.restore and self.load(self.session, saver):
            load_config = file_utils.load_args(self.config.model_name, self.config.config_dir)
            self.config.update(load_config)
            num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
            print('EPOCHS trained: ', num_epochs_trained)
        else:
            print('Initializing Variables ...')
            tf.global_variables_initializer().run()

        for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(self.session),
                               self.config.epochs + 1, 1):
            print('EPOCH: ', cur_epoch)
            self.current_epoch = cur_epoch

            losses_tr = self._train(self.data_train, self.session, logger)

            if np.isnan(losses_tr[0]):
                print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                for lname, lval in zip(self.model_graph.losses, losses_tr):
                    print(lname, lval)
                sys.exit()

            train_msg = 'TRAIN: \n'
            for lname, lval in zip(self.model_graph.losses, losses_tr):
                train_msg += str(lname) + ': ' + str(lval) + ' | '
            print(train_msg)
            print()

            if (cur_epoch == 1) or (cur_epoch % self.config.save_epoch == 0 and cur_epoch != 0):
                gc.collect()
                self.save(self.session, saver,
                          self.model_graph.global_step_tensor.eval(self.session))
                if self.config.plot:
                    self.plot_latent(cur_epoch)

            self.session.run(self.model_graph.increment_cur_epoch_tensor)

            # Early stopping
            if self.config.early_stopping and early_stopper.stop(losses_tr[0]):
                print('Early Stopping!')
                break

            if cur_epoch % self.config.colab_save == 0:
                if self.config.colab:
                    self.push_colab()

        self.save(self.session, saver,
                  self.model_graph.global_step_tensor.eval(self.session))
        if self.config.plot:
            self.plot_latent(cur_epoch)
        if self.config.colab:
            self.push_colab()
    return
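# Several snippets here construct EarlyStopping(name='total loss',
# decay_fn=...). Its real semantics are not shown in this file; a plausible
# sketch follows, under the assumption that `decay_fn` is a hook (for
# example, a learning-rate decay) invoked each time the named loss plateaus,
# with training stopped only after `max_decays` plateaus. Every parameter
# name besides `name` and `decay_fn` is hypothetical.
class EarlyStopping:
    def __init__(self, name='total loss', patience=10, min_delta=0.0,
                 decay_fn=None, max_decays=3):
        self.name = name
        self.patience = patience
        self.min_delta = min_delta
        self.decay_fn = decay_fn      # assumed: called once per plateau
        self.max_decays = max_decays  # assumed stopping budget
        self.best = None
        self.num_bad_epochs = 0
        self.num_decays = 0

    def stop(self, loss):
        if self.best is None or self.best - loss > self.min_delta:
            self.best = loss
            self.num_bad_epochs = 0
            return False
        self.num_bad_epochs += 1
        if self.num_bad_epochs >= self.patience:
            self.num_bad_epochs = 0
            self.num_decays += 1
            if self.decay_fn is not None:
                self.decay_fn()       # e.g. shrink the learning rate
            return self.num_decays > self.max_decays
        return False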
def train(self, data_train, data_valid, enable_es=1):
    with tf.Session(graph=self.graph) as session:
        tf.set_random_seed(1234)
        logger = Logger(session, self.summary_dir)
        # Initialize the TensorFlow saver that will be used in saving the checkpoints.
        # max_to_keep: defaults to keeping the 5 most recent checkpoints of your model.
        saver = tf.train.Saver()
        self.session = session
        early_stopping = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

        if self.restore and self.load(session, saver):
            num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(session)
            print('EPOCHS trained: ', num_epochs_trained)
        else:
            print('Initializing Variables ...')
            tf.global_variables_initializer().run()

        if self.model_graph.cur_epoch_tensor.eval(session) == self.epochs:
            return

        for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(session),
                               self.epochs + 1, 1):
            print('EPOCH: ', cur_epoch)
            self.current_epoch = cur_epoch

            loss_tr, recons_tr, L2_loss = self.train_epoch(session, logger, data_train)
            if np.isnan(loss_tr):
                print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                print('Recons: ', recons_tr)
                sys.exit()

            loss_val, recons_val = self.valid_epoch(session, logger, data_valid)

            print('TRAIN | AE Loss: ', loss_tr, ' | Recons: ', recons_tr, ' | L2_loss: ', L2_loss)
            print('VALID | AE Loss: ', loss_val, ' | Recons: ', recons_val)

            if (cur_epoch == 1) or (cur_epoch % const.SAVE_EPOCH == 0 and cur_epoch != 0):
                self.save(session, saver, self.model_graph.global_step_tensor.eval(session))
                if self.plot:
                    self.generate_samples(data_train, session, cur_epoch)
                if self.clustering:
                    self.generate_clusters(logger, cur_epoch, data_train, data_valid)

            session.run(self.model_graph.increment_cur_epoch_tensor)

            # Early stopping
            if enable_es == 1 and early_stopping.stop(loss_val):
                print('Early Stopping!')
                break

            if cur_epoch % 50 == 0:
                if self.colab:
                    self.push_colab()

        self.save(session, saver, self.model_graph.global_step_tensor.eval(session))
        if self.plot:
            self.generate_samples(data_train, session, cur_epoch)
        if self.clustering:
            self.generate_clusters(logger, cur_epoch, data_train, data_valid)
        if self.colab:
            self.push_colab()
    return
def train(self, data_train, data_valid, enable_es=1):
    with tf.Session(graph=self.graph) as session:
        tf.set_random_seed(1234)
        logger = Logger(session, self.summary_dir)
        # Initialize the TensorFlow saver that will be used in saving the checkpoints.
        # max_to_keep: defaults to keeping the 5 most recent checkpoints of your model.
        saver = tf.train.Saver()
        early_stopping = EarlyStopping()

        if self.restore == 1 and self.load(session, saver):
            num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(session)
            print('EPOCHS trained: ', num_epochs_trained)
        else:
            print('Initializing Variables ...')
            tf.global_variables_initializer().run()

        if self.model_graph.cur_epoch_tensor.eval(session) == self.epochs:
            return

        for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(session),
                               self.epochs + 1, 1):
            print('EPOCH: ', cur_epoch)
            self.current_epoch = cur_epoch
            # beta = utils.sigmoid(cur_epoch - 50)
            beta = 1.

            losses, recons, cond_prior, KL_w, y_prior, L2_loss = self.train_epoch(
                session, logger, data_train, beta=beta)

            train_string = 'TRAIN | Loss: ' + str(losses) + \
                           ' | Recons: ' + str(recons) + \
                           ' | CP: ' + str(cond_prior) + \
                           ' | KL_w: ' + str(KL_w) + \
                           ' | KL_y: ' + str(y_prior) + \
                           ' | L2_loss: ' + str(L2_loss)
            # train_string = colored(train_string, 'red', attrs=['reverse', 'blink'])
            train_string = colored(train_string, 'red')

            if np.isnan(losses):
                print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                print('Recons: ', recons)
                print('CP: ', cond_prior)
                print('KL_w: ', KL_w)
                print('KL_y: ', y_prior)
                sys.exit()

            loss_val, recons, cond_prior, KL_w, y_prior, L2_loss = self.valid_epoch(
                session, logger, data_valid, beta=beta)

            valid_string = 'VALID | Loss: ' + str(loss_val) + \
                           ' | Recons: ' + str(recons) + \
                           ' | CP: ' + str(cond_prior) + \
                           ' | KL_w: ' + str(KL_w) + \
                           ' | KL_y: ' + str(y_prior) + \
                           ' | L2_loss: ' + str(L2_loss)

            print(train_string)
            print(valid_string)

            if cur_epoch > 0 and cur_epoch % 10 == 0:
                self.save(session, saver, self.model_graph.global_step_tensor.eval(session))

            session.run(self.model_graph.increment_cur_epoch_tensor)

            # Early stopping
            if enable_es == 1 and early_stopping.stop(loss_val):
                print('Early Stopping!')
                break

        self.save(session, saver, self.model_graph.global_step_tensor.eval(session))
    return
def fit(self, dataset):
    # Equivalent to the original string comparison against "tensorflow_datasets.image",
    # but expressed through the class's module path.
    assert dataset.__class__.__module__.split('.')[:2] == ['tensorflow_datasets', 'image'], \
        'The dataset type is not image tensorflow_datasets'

    self.data_train = dataset.as_dataset(split=tfds.Split.TRAIN, shuffle_files=True,
                                         batch_size=self.config.batch_size)
    try:
        self.data_test = dataset.as_dataset(split=tfds.Split.TEST, shuffle_files=True,
                                            batch_size=self.config.batch_size)
    except Exception:
        # Fall back to the training split when the dataset has no test split.
        self.data_test = dataset.as_dataset(split=tfds.Split.TRAIN, shuffle_files=True,
                                            batch_size=self.config.batch_size)

    width = dataset.info.features['image'].shape[0]
    height = dataset.info.features['image'].shape[1]
    num_channels = dataset.info.features['image'].shape[2]

    self.config.ntrain_batches = dataset.info.splits['train'].num_examples // self.config.batch_size
    self.config.ntest_batches = dataset.info.splits['test'].num_examples // self.config.batch_size

    if not self.config.isBuilt:
        self.config.restore = True
        self.build_model(height, width, num_channels)
    else:
        assert (self.config.height == height) and (self.config.width == width) and \
               (self.config.num_channels == num_channels), \
            'Wrong dimension of data. Expected shape {}, and got {}'.format(
                (self.config.height, self.config.width, self.config.num_channels),
                (height, width, num_channels))

    '''
    -------------------------------------------------------------------------------
                                    TRAIN THE MODEL
    -------------------------------------------------------------------------------
    '''
    print('\nTraining a model...')
    with tf.Session(graph=self.graph) as session:
        tf.set_random_seed(self.config.seeds)
        self.session = session
        logger = Logger(self.session, self.config.log_dir)
        self.saver = tf.train.Saver()
        early_stopper = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

        if self.config.restore and self.load(self.session, self.saver):
            load_config = file_utils.load_args(self.config.model_name, self.config.config_dir,
                                               ['latent_mean', 'latent_std', 'samples', 'y_uniqs'])
            self.config.update(load_config)
            num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
            print('EPOCHS trained: ', num_epochs_trained)
        else:
            print('Initializing Variables ...')
            tf.global_variables_initializer().run()

        if self.config.plot:
            if self.config.y_uniqs is None:
                print('\nFinding the unique categories...')
                y_uniqs = list()
                iterator = self.data_train.make_one_shot_iterator()
                for t in tqdm(range(self.config.ntrain_batches)):
                    batch = session.run(iterator.get_next())
                    y, _ = tf.unique(batch[self.config.y_index])
                    y_uniqs += y.eval().tolist()
                self.config.y_uniqs = np.unique(y_uniqs)

            if self.config.samples is None:
                # Cycle through the unique labels until there are ten of them.
                y_uniqs = self.config.y_uniqs[:10]
                y_uniqs = np.array(list(itertools.repeat(y_uniqs, 10))).flatten()[:10]

                print('\nSampling from the unique categories...')
                samples = {yi: [] for yi in y_uniqs}
                iterator = self.data_train.make_one_shot_iterator()
                for t in tqdm(range(self.config.ntrain_batches)):
                    batch = session.run(iterator.get_next())
                    for yi in y_uniqs:
                        if len(samples[yi]) <= 10:
                            samples[yi] = samples[yi] + da.from_array(
                                tf.boolean_mask(mask=batch[self.config.y_index] == yi,
                                                tensor=batch['image']).eval(),
                                chunks=10).compute().tolist()
                            samples[yi] = samples[yi][:10]
                self.config.samples = da.from_array(da.vstack(list(samples.values())),
                                                    chunks=10).compute()

        if not self.config.isTrained:
            for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(self.session),
                                   self.config.epochs + 1, 1):
                print('EPOCH: ', cur_epoch)
                self.current_epoch = cur_epoch

                losses_tr = self._train(self.data_train, self.session, logger,
                                        self.config.ntrain_batches)
                if np.isnan(losses_tr[0]):
                    print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                    for lname, lval in zip(self.model_graph.losses, losses_tr):
                        print(lname, lval)
                    sys.exit()

                losses_test = self._test(self.data_test, self.session, logger,
                                         self.config.ntest_batches)

                train_msg = 'TRAIN: \n'
                for lname, lval in zip(self.model_graph.losses, losses_tr):
                    train_msg += str(lname) + ': ' + str(lval) + ' | '
                eval_msg = 'TEST: \n'
                for lname, lval in zip(self.model_graph.losses, losses_test):
                    eval_msg += str(lname) + ': ' + str(lval) + ' | '
                print(train_msg)
                print(eval_msg)
                print()

                if (cur_epoch == 1) or (cur_epoch % self.config.save_epoch == 0 and cur_epoch != 0):
                    self.save_model()
                    if self.config.plot:
                        self.plot_latent(cur_epoch)
                        self.plot_reconst(cur_epoch)

                self.session.run(self.model_graph.increment_cur_epoch_tensor)

                # Early stopping
                if self.config.early_stopping and early_stopper.stop(losses_test[0]):
                    print('Early Stopping!')
                    break

                if cur_epoch % self.config.colab_save == 0:
                    if self.config.colab:
                        self.push_colab()

            self.config.isTrained = True
            self.save_model()
            if self.config.plot:
                self.plot_latent(cur_epoch)
                self.plot_reconst(cur_epoch)

        if self.config.colab:
            self.push_colab()
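# Hypothetical usage of the fit() above with a tensorflow_datasets builder.
# `Model` and `config` are assumed stand-ins for the surrounding class and its
# configuration; the dataset choice is arbitrary.
import tensorflow_datasets as tfds

builder = tfds.builder('mnist')   # any image dataset builder would do
builder.download_and_prepare()
model = Model(config)             # assumed constructor of the class above
model.fit(builder)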
def sent_clf(dataset, config, opts, transfer=False):
    from logger.experiment import Experiment

    opts.name = config["name"]
    X_train, y_train, X_val, y_val = dataset

    vocab = None
    if transfer:
        opts.transfer = config["pretrained_lm"]
        checkpoint = load_checkpoint(opts.transfer)
        config["vocab"].update(checkpoint["config"]["vocab"])
        dict_pattern_rename(checkpoint["config"]["model"], {"rnn_": "bottom_rnn_"})
        config["model"].update(checkpoint["config"]["model"])
        vocab = checkpoint["vocab"]

    ####################################################################
    # Load Preprocessed Datasets
    ####################################################################
    if config["preprocessor"] == "twitter":
        preprocessor = twitter_preprocessor()
    else:
        preprocessor = None

    print("Building training dataset...")
    train_set = ClfDataset(X_train, y_train,
                           vocab=vocab,
                           preprocess=preprocessor,
                           vocab_size=config["vocab"]["size"],
                           seq_len=config["data"]["seq_len"])

    print("Building validation dataset...")
    val_set = ClfDataset(X_val, y_val,
                         seq_len=train_set.seq_len,
                         preprocess=preprocessor,
                         vocab=train_set.vocab)

    src_lengths = [len(x) for x in train_set.data]
    val_lengths = [len(x) for x in val_set.data]

    # select sampler & dataloader
    train_sampler = BucketBatchSampler(src_lengths, config["batch_size"], True)
    val_sampler = SortedSampler(val_lengths)
    val_sampler_train = SortedSampler(src_lengths)

    train_loader = DataLoader(train_set, batch_sampler=train_sampler,
                              num_workers=opts.cores, collate_fn=ClfCollate())
    val_loader = DataLoader(val_set, sampler=val_sampler,
                            batch_size=config["batch_size"],
                            num_workers=opts.cores, collate_fn=ClfCollate())
    val_loader_train_dataset = DataLoader(train_set, sampler=val_sampler_train,
                                          batch_size=config["batch_size"],
                                          num_workers=opts.cores,
                                          collate_fn=ClfCollate())

    ####################################################################
    # Model
    ####################################################################
    ntokens = len(train_set.vocab)
    model = Classifier(ntokens, len(set(train_set.labels)), **config["model"])
    model.to(opts.device)

    clf_criterion = nn.CrossEntropyLoss()
    lm_criterion = nn.CrossEntropyLoss(ignore_index=0)

    embed_parameters = filter(lambda p: p.requires_grad, model.embed.parameters())
    bottom_parameters = filter(lambda p: p.requires_grad,
                               chain(model.bottom_rnn.parameters(),
                                     model.vocab.parameters()))
    if config["model"]["has_att"]:
        top_parameters = filter(lambda p: p.requires_grad,
                                chain(model.top_rnn.parameters(),
                                      model.attention.parameters(),
                                      model.classes.parameters()))
    else:
        top_parameters = filter(lambda p: p.requires_grad,
                                chain(model.top_rnn.parameters(),
                                      model.classes.parameters()))

    embed_optimizer = optim.ASGD(embed_parameters, lr=0.0001)
    rnn_optimizer = optim.ASGD(bottom_parameters)
    top_optimizer = Adam(top_parameters, lr=config["top_lr"])

    ####################################################################
    # Training Pipeline
    ####################################################################

    # Trainer: responsible for managing the training process
    trainer = SentClfTrainer(model, train_loader, val_loader,
                             (lm_criterion, clf_criterion),
                             [embed_optimizer, rnn_optimizer, top_optimizer],
                             config, opts.device,
                             valid_loader_train_set=val_loader_train_dataset,
                             unfreeze_embed=config["unfreeze_embed"],
                             unfreeze_rnn=config["unfreeze_rnn"])

    ####################################################################
    # Experiment: logging and visualizing the training process
    ####################################################################
    exp = Experiment(opts.name, config, src_dirs=opts.source, output_dir=EXP_DIR)
    exp.add_metric("ep_loss_lm", "line", "epoch loss lm", ["TRAIN", "VAL"])
    exp.add_metric("ep_loss_cls", "line", "epoch loss class", ["TRAIN", "VAL"])
    exp.add_metric("ep_f1", "line", "epoch f1", ["TRAIN", "VAL"])
    exp.add_metric("ep_acc", "line", "epoch accuracy", ["TRAIN", "VAL"])
    exp.add_value("epoch", title="epoch summary")
    exp.add_value("progress", title="training progress")

    ####################################################################
    # Resume Training from a previous checkpoint
    ####################################################################
    if transfer:
        print("Transferring Encoder weights ...")
        dict_pattern_rename(checkpoint["model"],
                            {"encoder": "bottom_rnn", "decoder": "vocab"})
        load_state_dict_subset(model, checkpoint["model"])
    print(model)

    ####################################################################
    # Training Loop
    ####################################################################
    best_loss = None
    early_stopping = EarlyStopping("min", config["patience"])

    for epoch in range(0, config["epochs"]):
        train_loss = trainer.train_epoch()
        val_loss, y, y_pred = trainer.eval_epoch(val_set=True)
        _, y_train, y_pred_train = trainer.eval_epoch(train_set=True)

        exp.update_metric("ep_loss_lm", train_loss[0], "TRAIN")
        exp.update_metric("ep_loss_lm", val_loss[0], "VAL")
        exp.update_metric("ep_loss_cls", train_loss[1], "TRAIN")
        exp.update_metric("ep_loss_cls", val_loss[1], "VAL")
        exp.update_metric("ep_f1", f1_macro(y_train, y_pred_train), "TRAIN")
        exp.update_metric("ep_f1", f1_macro(y, y_pred), "VAL")
        exp.update_metric("ep_acc", acc(y_train, y_pred_train), "TRAIN")
        exp.update_metric("ep_acc", acc(y, y_pred), "VAL")

        print()
        epoch_log = exp.log_metrics(["ep_loss_lm", "ep_loss_cls", "ep_f1", "ep_acc"])
        print(epoch_log)
        exp.update_value("epoch", epoch_log)

        # Save the model if the val loss is the best we've seen so far.
        if not best_loss or val_loss[1] < best_loss:
            best_loss = val_loss[1]
            trainer.best_acc = acc(y, y_pred)
            trainer.best_f1 = f1_macro(y, y_pred)
            trainer.checkpoint(name=opts.name, timestamp=True)

        if early_stopping.stop(val_loss[1]):
            print("Early Stopping (according to classification loss)...")
            break

        print("\n" * 2)

    return best_loss, trainer.best_acc, trainer.best_f1
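# The snippet above constructs EarlyStopping("min", patience), a mode-based
# variant distinct from the ones used elsewhere in this file. A minimal sketch
# consistent with that call follows; the implementation is an assumption, not
# the original. `mode` selects whether a lower ("min") or higher ("max")
# metric counts as an improvement.
class EarlyStopping:
    def __init__(self, mode, patience):
        self.mode = mode
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0

    def _improved(self, value):
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def stop(self, value):
        # Reset the counter on improvement, otherwise count a bad epoch.
        if self._improved(value):
            self.best = value
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience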
def fit(self, X, y=None):
    print('\nProcessing data...')
    self.data_train, self.data_eval = utils.process_data(X, y)

    self.config['num_batches'] = self.data_train.num_batches(self.config.batch_size)

    if not self.isBuild:
        self.config.restore = True
        self.build_model(self.data_train.height, self.data_train.width,
                         self.data_train.num_channels)
    else:
        assert (self.config.height == self.data_train.height) and \
               (self.config.width == self.data_train.width) and \
               (self.config.num_channels == self.data_train.num_channels), \
            'Wrong dimension of data. Expected shape {}, and got {}'.format(
                (self.config.height, self.config.width, self.config.num_channels),
                (self.data_train.height, self.data_train.width, self.data_train.num_channels))

    '''
    -------------------------------------------------------------------------------
                                    TRAIN THE MODEL
    -------------------------------------------------------------------------------
    '''
    print('\nTraining a model...')
    with tf.Session(graph=self.graph) as session:
        tf.set_random_seed(222222)
        self.session = session
        logger = Logger(self.session, self.config.summary_dir)
        saver = tf.train.Saver()
        early_stopper = EarlyStopping(name='total loss', decay_fn=self.decay_fn)

        if self.config.restore and self.load(self.session, saver):
            num_epochs_trained = self.model_graph.cur_epoch_tensor.eval(self.session)
            print('EPOCHS trained: ', num_epochs_trained)
        else:
            print('Initializing Variables ...')
            tf.global_variables_initializer().run()

        for cur_epoch in range(self.model_graph.cur_epoch_tensor.eval(self.session),
                               self.config.epochs + 1, 1):
            print('EPOCH: ', cur_epoch)
            self.current_epoch = cur_epoch

            losses_tr = self._train(self.data_train, self.session, logger)

            if np.isnan(losses_tr[0]):
                print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                for lname, lval in zip(self.model_graph.losses, losses_tr):
                    print(lname, lval)
                sys.exit()

            losses_eval = self._evaluate(self.data_eval, self.session, logger)

            train_msg = 'TRAIN: \n'
            for lname, lval in zip(self.model_graph.losses, losses_tr):
                train_msg += str(lname) + ': ' + str(lval) + ' | '
            eval_msg = 'EVALUATE: \n'
            for lname, lval in zip(self.model_graph.losses, losses_eval):
                eval_msg += str(lname) + ': ' + str(lval) + ' | '
            print(train_msg)
            print(eval_msg)
            print()

            if (cur_epoch == 1) or (cur_epoch % const.SAVE_EPOCH == 0 and cur_epoch != 0):
                self.save(self.session, saver,
                          self.model_graph.global_step_tensor.eval(self.session))
                if self.config.plot:
                    self.reconst_samples_from_data(self.data_train, self.session, cur_epoch)

            self.session.run(self.model_graph.increment_cur_epoch_tensor)

            # Early stopping
            if self.config.early_stopping and early_stopper.stop(losses_eval[0]):
                print('Early Stopping!')
                break

            if cur_epoch % const.COLAB_SAVE == 0:
                if self.config.colab:
                    self.push_colab()

        self.save(self.session, saver,
                  self.model_graph.global_step_tensor.eval(self.session))
        if self.config.plot:
            self.reconst_samples_from_data(self.data_train, self.session, cur_epoch)
        if self.config.colab:
            self.push_colab()

        z = self.encode(self.data_train.x)
        self.config.latent_max = z.max().compute()
        self.config.latent_min = z.min().compute()
        self.config.latent_std = z.std().compute()
        del z
    return
def train(classifier, config, train_batcher, valid_batcher, test_batcher):
    """
    Open a session and run the actual training steps.
    After each training stretch, validate the model on the validation set and
    write summaries. Once the configured number of epochs (config.train_epoch)
    has been trained, evaluate the model on the test set.
    :param classifier: the classifier
    :param config: the configuration
    :param train_batcher: batch generator for the training set.
                          Batch shape: [config.batch_size, config.max_len]
    :param valid_batcher: batch generator for the validation set.
    :param test_batcher: batch generator for the test set.
    :return: None. Training progress is saved under
             config.log_root/config.model_name/train
    """
    # Create eval_config; its only difference from config is that dropout is
    # set to 1.0. It is used when validating and testing the model.
    eval_config = copy.deepcopy(config)
    eval_config.keep_prob = 1.0

    # Directory train_dir holds checkpoints and summaries.
    train_dir = os.path.join(config.log_root, config.model_name, "train")

    # Summary writers for the training and validation sets.
    summary_writer_train = tf.summary.FileWriter(
        os.path.join(train_dir, "summaries", "summaries_train"))
    summary_writer_valid = tf.summary.FileWriter(
        os.path.join(train_dir, "summaries", "summaries_valid"))

    saver = tf.train.Saver(max_to_keep=3)
    # Set up early stopping; the monitored quantity is the validation loss.
    early_stop = EarlyStopping(config.patience, mode='min')

    # Session configuration: allow GPU memory to grow on demand.
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True

    # Create a MonitoredSession.
    with tf.train.MonitoredTrainingSession(
            is_chief=True,
            checkpoint_dir=train_dir,
            scaffold=None,
            hooks=None,
            chief_only_hooks=None,
            save_checkpoint_secs=None,
            save_summaries_steps=None,
            save_summaries_secs=None,
            config=config_proto,
    ) as sess:
        # Compute the number of steps to run.
        if config.train_epoch:
            step_num = int(config.train_epoch * train_batcher.data_size /
                           train_batcher.batch_size)
        else:
            step_num = config.train_step

        start = time.time()
        for step in range(step_num):
            if sess.should_stop():  # an unexpected condition may require stopping
                break

            # Run a stretch of training steps.
            t0 = time.time()
            global_step = classifier.global_step.eval(sess)
            print("\nglobal_step : {}".format(global_step))
            return_dict = run_some_steps(classifier, sess, config, run_train_step,
                                         train_batcher, summary_writer_train,
                                         step_num=config.once_train_step)
            print("Train time: %.3f s" % (time.time() - t0))
            print("Train accuracy: %.3f %%, loss: %.4f" %
                  (return_dict['acc'] * 100, return_dict['loss']))

            # Evaluate on the validation data.
            tf.logging.set_verbosity(tf.logging.ERROR)
            return_dict = run_one_epoch(classifier, sess, config, run_eval_step,
                                        valid_batcher, summary_writer_valid)
            tf.logging.set_verbosity(tf.logging.WARN)
            print("Valid accuracy: %.3f %%, loss: %.4f" %
                  (return_dict['acc'] * 100, return_dict['loss']))

            # Early stopping.
            if early_stop.add_monitor(return_dict['loss']):
                # The monitored loss improved, so save a checkpoint.
                print('Found better model parameters, saving...')
                # This odd construct is needed to reach the raw session inside
                # a MonitoredTrainingSession; see:
                # https://github.com/tensorflow/tensorflow/issues/8425
                saver.save(sess._sess._sess._sess._sess,
                           os.path.join(train_dir, 'model.ckpt'),
                           global_step=global_step)
            if early_stop.stop():
                print('Early stopping triggered! The last {} monitored values were {}'.format(
                    early_stop.patience, early_stop.pre_monitor))
                break

        print("Training finished, time consumed : %.3f s" % (time.time() - start))

        # Evaluate on the test data.
        print("\nStart evaluating:")
        tf.logging.set_verbosity(tf.logging.ERROR)
        return_dict = run_one_epoch(classifier, sess, config, run_eval_step, test_batcher)
        tf.logging.set_verbosity(tf.logging.WARN)
        print("Test accuracy: %.3f %%, loss: %.4f" %
              (return_dict['acc'] * 100, return_dict['loss']))

        # Compute the evaluation metrics.
        metrics_model(return_dict['real_label'], return_dict['predict'])
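# The snippet above uses a two-method EarlyStopping: add_monitor() records a
# value and returns True when it improves (so the caller can checkpoint), and
# stop() reports whether the last `patience` values failed to improve. A
# minimal sketch consistent with the call sites (including the `.patience` and
# `.pre_monitor` attributes read in the log message) follows; the
# implementation is an assumption, not the original.
class EarlyStopping:
    def __init__(self, patience, mode='min'):
        self.patience = patience
        self.mode = mode
        self.best = None
        self.pre_monitor = []   # the most recent monitored values
        self.num_bad = 0

    def add_monitor(self, value):
        # Keep a sliding window of the last `patience` values for reporting.
        self.pre_monitor = (self.pre_monitor + [value])[-self.patience:]
        improved = (self.best is None or
                    (value < self.best if self.mode == 'min' else value > self.best))
        if improved:
            self.best = value
            self.num_bad = 0
        else:
            self.num_bad += 1
        return improved

    def stop(self):
        return self.num_bad >= self.patience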
def train(self, data_train, data_valid, enable_es=1):
    with tf.Session(graph=self.graph) as session:
        tf.set_random_seed(1234)
        logger = Logger(session, self.summary_dir)
        # Initialize the TensorFlow saver that will be used in saving the checkpoints.
        # max_to_keep: defaults to keeping the 5 most recent checkpoints of your model.
        saver = tf.train.Saver()
        early_stopping = EarlyStopping()

        if self.restore == 1 and self.load(session, saver):
            num_epochs_trained = self.vae_graph.cur_epoch_tensor.eval(session)
            print('EPOCHS trained: ', num_epochs_trained)
        else:
            print('Initializing Variables ...')
            tf.global_variables_initializer().run()

        if self.vae_graph.cur_epoch_tensor.eval(session) == self.epochs:
            return

        for cur_epoch in range(self.vae_graph.cur_epoch_tensor.eval(session),
                               self.epochs + 1, 1):
            print('EPOCH: ', cur_epoch)
            self.current_epoch = cur_epoch
            # beta = utils.sigmoid(cur_epoch - 50)
            beta = 1.

            loss_tr, recons_tr, cond_prior_tr, L2_loss = self.train_epoch(
                session, logger, data_train, beta=beta)
            if np.isnan(loss_tr):
                print('Encountered NaN, stopping training. Please check the learning_rate settings and the momentum.')
                print('Recons: ', recons_tr)
                print('KL: ', cond_prior_tr)
                sys.exit()

            loss_val, recons_val, cond_prior_val = self.valid_epoch(
                session, logger, data_valid, beta=beta)

            print('TRAIN | Loss: ', loss_tr, ' | Recons: ', recons_tr,
                  ' | KL: ', cond_prior_tr, ' | L2_loss: ', L2_loss)
            print('VALID | Loss: ', loss_val, ' | Recons: ', recons_val,
                  ' | KL: ', cond_prior_val)

            if cur_epoch > 0 and cur_epoch % 10 == 0:
                self.save(session, saver, self.vae_graph.global_step_tensor.eval(session))
                z_matrix = self.vae_graph.get_z_matrix(
                    session, data_valid.random_batch(self.batch_size))
                np.savez(self.z_file, z_matrix)

            session.run(self.vae_graph.increment_cur_epoch_tensor)

            # Early stopping
            if enable_es == 1 and early_stopping.stop(loss_val):
                print('Early Stopping!')
                break

        self.save(session, saver, self.vae_graph.global_step_tensor.eval(session))
        z_matrix = self.vae_graph.get_z_matrix(
            session, data_valid.random_batch(self.batch_size))
        np.savez(self.z_file, z_matrix)
    return