def train(self, train_data=None, valid_data=None):
    # if valid_data is None:  # no validation set specified.
    #     valid_data = train_data  # using training set as validation set.

    stop_batch = self.stop_batch
    self._init_session()

    # Before data shard is enabled, only the chief does evaluation and records it
    # self.print_head()
    fp = None
    if self.run_opt.is_chief:
        fp = open(self.disp_file, "a")

    cur_batch = run_sess(self.sess, self.global_step)
    is_first_step = True
    self.cur_batch = cur_batch
    log.info(
        "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
        % (run_sess(self.sess, self.learning_rate),
           self.lr.value(cur_batch),
           self.lr.decay_steps_,
           self.lr.decay_rate_,
           self.lr.value(stop_batch)))

    prf_options = None
    prf_run_metadata = None
    if self.profiling:
        prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        prf_run_metadata = tf.RunMetadata()

    # set tensorboard execution environment
    if self.tensorboard:
        summary_merged_op = tf.summary.merge_all()
        # remove the old tensorboard logging directory from the previous run
        try:
            shutil.rmtree(self.tensorboard_log_dir)
        except FileNotFoundError:
            pass  # directory does not exist, this is OK
        except Exception as e:
            # general error when removing directory, warn user
            log.exception(
                f"Could not remove old tensorboard logging directory: "
                f"{self.tensorboard_log_dir}. Error: {e}")
        else:
            log.debug("Removing old tensorboard log directory.")
        tb_train_writer = tf.summary.FileWriter(self.tensorboard_log_dir + '/train', self.sess.graph)
        tb_valid_writer = tf.summary.FileWriter(self.tensorboard_log_dir + '/test')
    else:
        tb_train_writer = None
        tb_valid_writer = None

    train_time = 0
    while cur_batch < stop_batch:

        # first round validation:
        train_batch = train_data.get_batch()
        if self.display_in_training and is_first_step:
            if self.run_opt.is_chief:
                valid_batches = (
                    [valid_data.get_batch() for ii in range(self.valid_numb_batch)]
                    if valid_data is not None else None
                )
                self.valid_on_the_fly(fp, [train_batch], valid_batches, print_header=True)
            is_first_step = False

        if self.timing_in_training:
            tic = time.time()
        train_feed_dict = self.get_feed_dict(train_batch, is_training=True)
        # use tensorboard to visualize the training of deepmd-kit
        # it takes some extra execution time to generate the tensorboard data
        if self.tensorboard and (cur_batch % self.tensorboard_freq == 0):
            summary, _ = run_sess(self.sess, [summary_merged_op, self.train_op],
                                  feed_dict=train_feed_dict,
                                  options=prf_options,
                                  run_metadata=prf_run_metadata)
            tb_train_writer.add_summary(summary, cur_batch)
        else:
            run_sess(self.sess, [self.train_op],
                     feed_dict=train_feed_dict,
                     options=prf_options,
                     run_metadata=prf_run_metadata)
        if self.timing_in_training:
            toc = time.time()
            train_time += toc - tic
        cur_batch = run_sess(self.sess, self.global_step)
        self.cur_batch = cur_batch

        # on-the-fly validation
        if self.display_in_training and (cur_batch % self.disp_freq == 0):
            if self.timing_in_training:
                tic = time.time()
            if self.run_opt.is_chief:
                valid_batches = (
                    [valid_data.get_batch() for ii in range(self.valid_numb_batch)]
                    if valid_data is not None else None
                )
                self.valid_on_the_fly(fp, [train_batch], valid_batches)
            if self.timing_in_training:
                toc = time.time()
                test_time = toc - tic
                log.info("batch %7d training time %.2f s, testing time %.2f s"
                         % (cur_batch, train_time, test_time))
                train_time = 0

        if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.saver is not None:
            try:
                ckpt_prefix = self.saver.save(self.sess,
                                              os.path.join(os.getcwd(), self.save_ckpt),
                                              global_step=cur_batch)
            except google.protobuf.message.DecodeError as e:
                raise GraphTooLargeError(
                    "The graph size exceeds 2 GB, the hard limitation of protobuf."
                    " Then a DecodeError was raised by protobuf. You should "
                    "reduce the size of your model."
                ) from e
            # make symlinks from the prefix with step to that without step, so nothing breaks
            # get all checkpoint files
            original_files = glob.glob(ckpt_prefix + ".*")
            for ori_ff in original_files:
                new_ff = self.save_ckpt + ori_ff[len(ckpt_prefix):]
                try:
                    # remove the old one
                    os.remove(new_ff)
                except OSError:
                    pass
                os.symlink(ori_ff, new_ff)
            log.info("saved checkpoint %s" % self.save_ckpt)

    if self.run_opt.is_chief:
        fp.close()
    if self.profiling and self.run_opt.is_chief:
        fetched_timeline = timeline.Timeline(prf_run_metadata.step_stats)
        chrome_trace = fetched_timeline.generate_chrome_trace_format()
        with open(self.profiling_file, 'w') as f:
            f.write(chrome_trace)
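# A minimal, self-contained sketch (not deepmd-kit API) of the checkpoint/symlink
# pattern used in the save block above: Saver.save(..., global_step=N) returns a
# prefix such as "model.ckpt-1000", and each file it produced is symlinked back to
# the step-less name so that "model.ckpt.*" always points at the latest checkpoint.
# The helper name and arguments below are illustrative assumptions.
import glob
import os

def link_latest_checkpoint(ckpt_prefix: str, save_ckpt: str) -> None:
    """Symlink e.g. model.ckpt-1000.index -> model.ckpt.index for every checkpoint file."""
    for ori_ff in glob.glob(ckpt_prefix + ".*"):
        new_ff = save_ckpt + ori_ff[len(ckpt_prefix):]
        try:
            os.remove(new_ff)  # drop a stale link or file if one is present
        except OSError:
            pass
        os.symlink(ori_ff, new_ff)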
def train(self, data):
    stop_batch = self.stop_batch
    if self.run_opt.is_distrib:
        self._init_sess_distrib()
    else:
        self._init_sess_serial()

    self.print_head()
    fp = None
    if self.run_opt.is_chief:
        fp = open(self.disp_file, "a")

    cur_batch = self.sess.run(self.global_step)
    is_first_step = True
    self.cur_batch = cur_batch
    self.run_opt.message(
        "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
        % (self.sess.run(self.learning_rate),
           self.lr.value(cur_batch),
           self.lr.decay_steps_,
           self.lr.decay_rate_,
           self.lr.value(stop_batch)))

    prf_options = None
    prf_run_metadata = None
    if self.profiling:
        prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        prf_run_metadata = tf.RunMetadata()

    train_time = 0
    while cur_batch < stop_batch:
        batch_data = data.get_batch(sys_probs=self.sys_probs,
                                    auto_prob_style=self.auto_prob_style)
        feed_dict_batch = {}
        for kk in batch_data.keys():
            if kk == 'find_type' or kk == 'type':
                continue
            if 'find_' in kk:
                feed_dict_batch[self.place_holders[kk]] = batch_data[kk]
            else:
                feed_dict_batch[self.place_holders[kk]] = np.reshape(batch_data[kk], [-1])
        for ii in ['type']:
            feed_dict_batch[self.place_holders[ii]] = np.reshape(batch_data[ii], [-1])
        for ii in ['natoms_vec', 'default_mesh']:
            feed_dict_batch[self.place_holders[ii]] = batch_data[ii]
        feed_dict_batch[self.place_holders['is_training']] = True

        if self.display_in_training and is_first_step:
            self.test_on_the_fly(fp, data, feed_dict_batch)
            is_first_step = False
        if self.timing_in_training:
            tic = time.time()
        self.sess.run([self.train_op],
                      feed_dict=feed_dict_batch,
                      options=prf_options,
                      run_metadata=prf_run_metadata)
        if self.timing_in_training:
            toc = time.time()
            train_time += toc - tic
        cur_batch = self.sess.run(self.global_step)
        self.cur_batch = cur_batch

        if self.display_in_training and (cur_batch % self.disp_freq == 0):
            tic = time.time()
            self.test_on_the_fly(fp, data, feed_dict_batch)
            toc = time.time()
            test_time = toc - tic
            if self.timing_in_training:
                self._message("batch %7d training time %.2f s, testing time %.2f s"
                              % (cur_batch, train_time, test_time))
                train_time = 0

        if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.run_opt.is_chief:
            if self.saver is not None:
                self.saver.save(self.sess, os.getcwd() + "/" + self.save_ckpt)
                self._message("saved checkpoint %s" % self.save_ckpt)

    if self.run_opt.is_chief:
        fp.close()
    if self.profiling and self.run_opt.is_chief:
        fetched_timeline = timeline.Timeline(prf_run_metadata.step_stats)
        chrome_trace = fetched_timeline.generate_chrome_trace_format()
        with open(self.profiling_file, 'w') as f:
            f.write(chrome_trace)
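# A minimal sketch of the feed-dict convention used by the older train() above,
# written against plain dictionaries rather than the trainer's placeholders:
# per-frame arrays are flattened to 1-D, "find_*" availability flags and
# 'natoms_vec'/'default_mesh' are passed through unchanged, and 'type' is
# flattened as well. The function name and batch layout are illustrative
# assumptions, not the deepmd-kit API.
import numpy as np

def build_feed_values(batch_data: dict) -> dict:
    """Return {key: value} shaped the way the placeholders above expect."""
    feed = {}
    for kk, vv in batch_data.items():
        if kk == 'find_type' or kk == 'type':
            continue
        if 'find_' in kk:
            feed[kk] = vv  # scalar flag telling whether the quantity is present
        else:
            feed[kk] = np.reshape(vv, [-1])  # flatten the per-frame array
    feed['type'] = np.reshape(batch_data['type'], [-1])
    for kk in ('natoms_vec', 'default_mesh'):
        feed[kk] = batch_data[kk]  # fed with their native shape
    feed['is_training'] = True
    return feed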
def train(self, train_data=None, valid_data=None):
    # if valid_data is None:  # no validation set specified.
    #     valid_data = train_data  # using training set as validation set.

    stop_batch = self.stop_batch
    self._init_session()

    # Before data shard is enabled, only the chief does evaluation and records it
    # self.print_head()
    fp = None
    if self.run_opt.is_chief:
        fp = open(self.disp_file, "a")

    cur_batch = run_sess(self.sess, self.global_step)
    is_first_step = True
    self.cur_batch = cur_batch
    log.info(
        "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e"
        % (run_sess(self.sess, self.learning_rate),
           self.lr.value(cur_batch),
           self.lr.decay_steps_,
           self.lr.decay_rate_,
           self.lr.value(stop_batch)))

    prf_options = None
    prf_run_metadata = None
    if self.profiling:
        prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        prf_run_metadata = tf.RunMetadata()

    # set tensorboard execution environment
    if self.tensorboard:
        summary_merged_op = tf.summary.merge_all()
        # remove the old tensorboard logging directory from the previous run
        try:
            shutil.rmtree(self.tensorboard_log_dir)
        except FileNotFoundError:
            pass  # directory does not exist, this is OK
        except Exception as e:
            # general error when removing directory, warn user
            log.exception(
                f"Could not remove old tensorboard logging directory: "
                f"{self.tensorboard_log_dir}. Error: {e}")
        else:
            log.debug("Removing old tensorboard log directory.")
        tb_train_writer = tf.summary.FileWriter(
            self.tensorboard_log_dir + '/train', self.sess.graph)
        tb_valid_writer = tf.summary.FileWriter(self.tensorboard_log_dir + '/test')
    else:
        tb_train_writer = None
        tb_valid_writer = None
    if self.enable_profiler:
        # https://www.tensorflow.org/guide/profiler
        tfv2.profiler.experimental.start(self.tensorboard_log_dir)

    train_time = 0
    while cur_batch < stop_batch:

        # first round validation:
        train_batch = train_data.get_batch()
        if self.display_in_training and is_first_step:
            if self.run_opt.is_chief:
                valid_batches = (
                    [valid_data.get_batch() for ii in range(self.valid_numb_batch)]
                    if valid_data is not None else None
                )
                self.valid_on_the_fly(fp, [train_batch], valid_batches, print_header=True)
            is_first_step = False

        if self.timing_in_training:
            tic = time.time()
        train_feed_dict = self.get_feed_dict(train_batch, is_training=True)
        # use tensorboard to visualize the training of deepmd-kit
        # it takes some extra execution time to generate the tensorboard data
        if self.tensorboard and (cur_batch % self.tensorboard_freq == 0):
            summary, _ = run_sess(self.sess, [summary_merged_op, self.train_op],
                                  feed_dict=train_feed_dict,
                                  options=prf_options,
                                  run_metadata=prf_run_metadata)
            tb_train_writer.add_summary(summary, cur_batch)
        else:
            run_sess(self.sess, [self.train_op],
                     feed_dict=train_feed_dict,
                     options=prf_options,
                     run_metadata=prf_run_metadata)
        if self.timing_in_training:
            toc = time.time()
            train_time += toc - tic
        cur_batch = run_sess(self.sess, self.global_step)
        self.cur_batch = cur_batch

        # on-the-fly validation
        if self.display_in_training and (cur_batch % self.disp_freq == 0):
            if self.timing_in_training:
                tic = time.time()
            if self.run_opt.is_chief:
                valid_batches = (
                    [valid_data.get_batch() for ii in range(self.valid_numb_batch)]
                    if valid_data is not None else None
                )
                self.valid_on_the_fly(fp, [train_batch], valid_batches)
            if self.timing_in_training:
                toc = time.time()
                test_time = toc - tic
                log.info("batch %7d training time %.2f s, testing time %.2f s"
                         % (cur_batch, train_time, test_time))
                train_time = 0

        if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.saver is not None:
            self.save_checkpoint(cur_batch)

    # save a final checkpoint if the loop did not end exactly on a save step
    if (self.save_freq == 0 or cur_batch == 0 or cur_batch % self.save_freq != 0) \
            and self.saver is not None:
        self.save_checkpoint(cur_batch)
    if self.run_opt.is_chief:
        fp.close()
    if self.profiling and self.run_opt.is_chief:
        fetched_timeline = timeline.Timeline(prf_run_metadata.step_stats)
        chrome_trace = fetched_timeline.generate_chrome_trace_format()
        with open(self.profiling_file, 'w') as f:
            f.write(chrome_trace)
    if self.enable_profiler and self.run_opt.is_chief:
        tfv2.profiler.experimental.stop()
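# A minimal sketch of the "timeline" profiling path used above (self.profiling),
# outside the trainer class: run one step with full tracing and dump a Chrome
# trace that can be opened at chrome://tracing. The function name and the
# trace_file default are illustrative assumptions; the newer self.enable_profiler
# path instead brackets the whole loop with tf.profiler.experimental.start()/stop()
# and is inspected through TensorBoard.
import tensorflow.compat.v1 as tf
from tensorflow.python.client import timeline

def profile_one_step(sess, train_op, feed_dict, trace_file="timeline.json"):
    """Run a single training step with FULL_TRACE and write a Chrome trace file."""
    options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()
    sess.run(train_op, feed_dict=feed_dict, options=options, run_metadata=run_metadata)
    chrome_trace = timeline.Timeline(run_metadata.step_stats).generate_chrome_trace_format()
    with open(trace_file, "w") as f:
        f.write(chrome_trace)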