def train(opts):
    # --------------- OPTIONS ---------------------
    total_samples = data_loader.get_dataset_files_count(opts, is_training=True)
    opts["dataset_repeat"] = math.ceil(
        (opts["num_train_steps"] * opts["global_batch_size"]) / total_samples)

    total_samples_per_epoch = total_samples / opts["duplicate_factor"]
    logger.info(f"Total samples for each epoch {total_samples_per_epoch}")
    logger.info(f"Global batch size {opts['global_batch_size']}")
    steps_per_epoch = total_samples_per_epoch // opts["global_batch_size"]
    logger.info(f"Total steps for each epoch {steps_per_epoch}")

    steps_per_logs = math.ceil(
        opts["steps_per_logs"] / opts['batches_per_step']) * opts['batches_per_step']
    steps_per_tensorboard = math.ceil(
        opts["steps_per_tensorboard"] / opts['batches_per_step']) * opts['batches_per_step']
    steps_per_ckpts = math.ceil(
        opts["steps_per_ckpts"] / opts['batches_per_step']) * opts['batches_per_step']
    logger.info(f"Checkpoint will be saved every {steps_per_ckpts} steps.")

    total_steps = (opts["num_train_steps"] // opts['batches_per_step']) * opts['batches_per_step']
    logger.info(
        f"The IPU synchronises with the host every {opts['batches_per_step']} steps, "
        f"so num_train_steps must be a multiple of batches_per_step and is capped at {total_steps}.")

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    learning_rate_schedule = make_lr_schedule(lr_schedule_name, opts, total_steps)

    # variable loss scaling
    loss_scaling_schedule = LossScalingScheduler(opts['loss_scaling'],
                                                 opts['loss_scaling_by_step'])

    # -------------- BUILD TRAINING GRAPH ----------------
    train = build_graph(opts, is_training=True)
    train.session.run(train.init)
    train.session.run(train.iterator.initializer)

    is_main_worker = opts['distributed_worker_index'] == 0

    step = 0
    # -------------- SAVE AND RESTORE --------------
    if opts["restore_dir"]:
        restore_path = opts['restore_dir']
        if os.path.isfile(restore_path):
            latest_checkpoint = os.path.splitext(restore_path)[0]
        else:
            latest_checkpoint = tf.train.latest_checkpoint(restore_path)
        logger.info(
            f"Restoring training from latest checkpoint: {latest_checkpoint}")
        step_pattern = re.compile(".*ckpt-([0-9]+)$")
        step = int(step_pattern.match(latest_checkpoint).groups()[0])
        train.saver.restore(train.session, latest_checkpoint)
        epoch = step / steps_per_epoch

        # restore event files
        source_path = os.path.join(opts["restore_dir"], 'event')
        target_path = os.path.join(opts["save_path"], 'event')
        if os.path.isdir(source_path):
            copytree(source_path, target_path)
    else:
        if opts["init_checkpoint"]:
            train.saver.restore(train.session, opts["init_checkpoint"])
            logger.info(
                f'Initialising model from checkpoint {opts["init_checkpoint"]}')

    if opts['save_path']:
        file_path = train.saver.save(train.session,
                                     opts["checkpoint_path"],
                                     global_step=0)
        logger.info(f"Saved checkpoint to {file_path}")

    # Initialise Weights & Biases if available
    if opts['wandb'] and is_main_worker:
        import wandb
        wandb.init(project="tf-bert",
                   sync_tensorboard=True,
                   name=opts['wandb_name'])
        wandb.config.update(opts)

    # Tensorboard logs path
    log_path = os.path.join(opts["logs_path"], 'event')
    logger.info("Tensorboard event file path {}".format(log_path))
    summary_writer = tf.summary.FileWriter(log_path,
                                           train.graph,
                                           session=train.session)

    # Exit before any training if running in compile-only mode
    if opts['compile_only']:
        # A single warm-up step without weight update or training;
        # the graph gets compiled here.
        compilation_time, _, _, _, _ = training_step(train, 0, 0)
        print("Training graph successfully compiled. " +
              "Exiting as --compile-only was passed.")

        # Copying these from below, adding compile time to the summary
        poplar_summary = tf.Summary()
        poplar_summary.value.add(tag='poplar/compile_time',
                                 simple_value=compilation_time)
        summary_writer.add_summary(poplar_summary)
        summary_writer.flush()

        logger.info("Compile time: {}".format(compilation_time))
        sys.exit(0)

    # ------------- TRAINING LOOP ----------------
    print_format = (
        "step: {step:6d}, epoch: {epoch:6.2f}, lr: {lr:6.7f}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
        "mlm_acc: {mlm_acc:6.5f}, nsp_acc: {nsp_acc:6.5f}, samples/sec: {samples_per_sec:6.2f}, "
        "time: {iter_time:8.6f}, total_time: {total_time:8.1f}")

    learning_rate = mlm_loss = nsp_loss = 0
    start_all = time.time()

    try:
        while step < total_steps:
            learning_rate = learning_rate_schedule.get_at_step(step)
            loss_scaling = loss_scaling_schedule.get_at_step(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate, loss_scaling)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            batch_time /= opts['batches_per_step']

            is_log_step = (step % steps_per_logs == 0)
            is_save_tensorboard_step = (steps_per_tensorboard > 0
                                        and (step % steps_per_tensorboard == 0))
            is_save_ckpt_step = (step
                                 and (step % steps_per_ckpts == 0
                                      or step == total_steps - opts['batches_per_step']))

            if step == 1 and (is_main_worker or opts['log_all_workers']):
                poplar_compile_time = time.time() - start_all
                logger.info(f"Poplar compile time: {poplar_compile_time:.2f}s")
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            if is_log_step:
                total_time = time.time() - start_all
                epoch = step / steps_per_epoch
                stats = OrderedDict([
                    ('step', step),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('loss_scaling', loss_scaling),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', batch_time),
                    ('samples_per_sec', opts['global_batch_size'] / batch_time),
                    ('total_time', total_time),
                ])
                logger.info(print_format.format(**stats))

                # Log training statistics
                train_summary = tf.Summary()
                train_summary.value.add(tag='epoch', simple_value=epoch)
                train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
                train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
                train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
                train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
                train_summary.value.add(tag='learning_rate',
                                        simple_value=learning_rate)
                train_summary.value.add(tag='loss_scaling',
                                        simple_value=loss_scaling)
                train_summary.value.add(tag='samples_per_sec',
                                        simple_value=opts['global_batch_size'] / batch_time)
                train_summary.value.add(
                    tag='samples',
                    simple_value=step * opts['batches_per_step'] * opts['global_batch_size'])
                summary_writer.add_summary(train_summary, step)
                summary_writer.flush()

            if is_save_ckpt_step or is_save_tensorboard_step:
                if is_main_worker:
                    file_path = train.saver.save(train.session,
                                                 opts["checkpoint_path"],
                                                 global_step=step)
                    logger.info(f"Saved checkpoint to {file_path}")

                    if is_save_tensorboard_step:
                        log.save_model_statistics(file_path, summary_writer, step)

                if opts['use_popdist']:
                    ipu_utils.barrier()

            step += opts['batches_per_step']
    finally:
        train.session.close()
def train(bert_config, opts):
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("Total samples with duplications {}".format(total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // opts["total_batch_size"]  # total iterations
    if opts['steps']:
        logger.warning("Ignoring the epochs flag and using the steps flag instead")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "Total training steps equal to {}, total number of samples being analyzed equal to {}"
        .format(steps,
                steps * opts['batches_per_step'] * opts['total_batch_size']))
    iterations_per_step = opts['batches_per_step']
    ckpt_per_step = opts['steps_per_ckpts']

    # Avoid a NaN issue caused by a queue length of zero.
    queue_len = iterations_per_epoch // iterations_per_step
    if queue_len == 0:
        queue_len = 1
    batch_times = deque(maxlen=queue_len)

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    LR = make_lr_schedule(lr_schedule_name, opts, steps)

    if opts['do_train']:
        # -------------- BUILD TRAINING GRAPH ----------------
        train = build_graph(bert_config,
                            opts,
                            iterations_per_step,
                            is_training=True,
                            feed_name="trainfeed")
        train.session.run(train.init)
        train.session.run(train.iterator.initializer)

        step = 0
        i = 0

        if opts['restore_path'] is not None:
            if os.path.isdir(opts['restore_path']):
                ckpt_file_path = tf.train.latest_checkpoint(
                    opts['restore_path'])
                logger.info("Restoring training from latest checkpoint")
            else:
                # Assume it's a checkpoint file prefix
                ckpt_file_path = opts['restore_path']
                logger.info(
                    f"Restoring training from checkpoint: {ckpt_file_path}")
            train.restore.restore(train.session, ckpt_file_path)
            ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
            i = int(ckpt_pattern.match(ckpt_file_path).groups()[0])
            step = int(i // iterations_per_step)

        if opts['start_from_ckpt']:
            # We use a checkpoint to initialise our model
            train.restore.restore(train.session, opts['start_from_ckpt'])
            logger.info("Starting the training from the checkpoint {}".format(
                opts['start_from_ckpt']))

        # Initialise Weights & Biases if available
        if opts['wandb']:
            import wandb
            wandb.init(project="tf-bert", sync_tensorboard=True)
            wandb.config.update(opts)

        # Tensorboard logs path
        log_path = os.path.join(opts["logs_path"], 'event')
        logger.info("Tensorboard event file path {}".format(log_path))
        summary_writer = tf.summary.FileWriter(log_path,
                                               train.graph,
                                               session=train.session)

        # ------------- TRAINING LOOP ----------------
        logger.info(
            "################################################################################"
        )
        logger.info("Start training......")
        print_format = (
            "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, "
            "mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
            "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, "
            "mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}")

        start_all = time.time()

        train_saver = train.saver["train_saver"]
        best_saver = train.saver["best_saver"]
        # Initialise the best loss to a very large value
        best_total_loss = 1e10
        best_step = 0

        while step < steps:
            # Run training
            learning_rate = LR.feed_dict_lr(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            epoch = float(
                opts["total_batch_size"] * i) / total_independent_samples
            batch_time /= iterations_per_step

            if step != 0:
                batch_times.append([batch_time])

            if step == 1:
                poplar_compile_time = time.time() - start_all
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            # Print loss
            if step % opts['steps_per_logs'] == 0:
                if len(batch_times) != 0:
                    avg_batch_time = np.mean(batch_times)
                else:
                    avg_batch_time = batch_time

                samples_per_sec = opts['total_batch_size'] / avg_batch_time

                # Flush the recorded times every time they are reported
                batch_times.clear()

                total_time = time.time() - start_all

                stats = OrderedDict([
                    ('step', step),
                    ('iteration', i),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', avg_batch_time),
                    ('samples_per_sec', samples_per_sec),
                    ('total_time', total_time),
                ])
                logger.info(print_format.format(**stats))

                bert_logging.write_to_csv(stats, i == 0, True,
                                          opts['logs_path'])

                sys_summary = tf.Summary()
                sys_summary.value.add(tag='perf/throughput_samples_per_second',
                                      simple_value=samples_per_sec)
                sys_summary.value.add(tag='perf/average_batch_time',
                                      simple_value=avg_batch_time)
                summary_writer.add_summary(sys_summary, step)

                # Log training statistics
                train_summary = tf.Summary()
                train_summary.value.add(tag='epoch', simple_value=epoch)
                train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
                train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
                train_summary.value.add(tag='accuracy/MLM', simple_value=mlm_acc)
                train_summary.value.add(tag='accuracy/NSP', simple_value=nsp_acc)
                train_summary.value.add(tag='defaultLearningRate',
                                        simple_value=learning_rate)
                train_summary.value.add(
                    tag='samples',
                    simple_value=step * opts['batches_per_step'] * opts['total_batch_size'])
                summary_writer.add_summary(train_summary, step)
                summary_writer.flush()

            if step % ckpt_per_step == 0 and step:
                filepath = train_saver.save(train.session,
                                            save_path=opts["checkpoint_path"],
                                            global_step=step)
                logger.info("Saved checkpoint to {}".format(filepath))
                if not opts['wandb']:
                    bert_logging.save_model_statistics(filepath,
                                                       summary_writer, step)

            # Mechanism to checkpoint the best model.
            # Set opts["best_ckpt_min_steps"] to 0 to disable.
            if best_total_loss > mlm_loss + nsp_loss and \
                    step - best_step > opts["best_ckpt_min_steps"] and \
                    opts["best_ckpt_min_steps"]:
                best_total_loss = mlm_loss + nsp_loss
                best_step = step
                filepath = best_saver.save(train.session,
                                           save_path=opts["checkpoint_path"] + '_best',
                                           global_step=step)
                logger.info("Saved best checkpoint to {}".format(filepath))

            i += iterations_per_step
            step += 1

        # --------------- LAST CHECKPOINT ----------------
        filepath = train_saver.save(train.session,
                                    save_path=opts["checkpoint_path"] + '_last',
                                    global_step=step)
        logger.info("Final model saved to {}".format(filepath))

        # --------------- CLEANUP ----------------
        train.session.close()
def train(self):
    # Configure the IPU options.
    ipu_options = ipu_utils.get_ipu_config(
        ipu_id=self.opts["select_ipu"],
        num_ipus_required=len(self.opts["train"]["device_mapping"]) *
        self.opts["train"]["replicas"],
        fp_exceptions=False,
        stochastic_rounding=True,
        xla_recompute=True,
        available_memory_proportion=0.2,
        max_cross_replica_buffer_size=16 * 1024 * 1024,
        scheduler_selection="Clustering",
        compile_only=False,
        partials_type="half")

    # Configure the replication strategy.
    if self.opts["use_popdist"]:
        strategy = create_popdist_strategy()
        ipu_options = strategy.update_ipu_config(ipu_options)
        ipu_options = popdist.tensorflow.set_ipu_config(
            ipu_options,
            len(self.opts["train"]["device_mapping"]),
            configure_device=False)

    ipu_options.configure_ipu_system()

    self.sess = tf.Session(config=tf.ConfigProto())

    stop_flag = []
    data_threads = []
    ds = self.get_dataset_on_the_fly(stop_flag, data_threads)

    global_step_holder = tf.placeholder(dtype=tf.int32, shape=())

    # This wrapper is needed because self.model_func takes "self" as a
    # parameter, which would cause an error when calling ipu_compiler.compile
    # on it directly.
    def model_wrapper():
        self.model_func(self.model, self.opts, global_step_holder,
                        self.infeed_queue, self.outfeed_queue)

    with ExitStack() as stack:
        if self.opts["use_popdist"]:
            stack.enter_context(strategy.scope())

        self.infeed_queue = ipu_infeed_queue.IPUInfeedQueue(ds)
        self.outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue()

        with ipu.scopes.ipu_scope("/device:IPU:0"):
            if self.opts["use_popdist"]:
                def distributed_per_replica_func():
                    return ipu_compiler.compile(model_wrapper, inputs=[])

                compiled_model = strategy.experimental_run_v2(
                    distributed_per_replica_func, args=[])
            else:
                compiled_model = ipu_compiler.compile(model_wrapper, inputs=[])

        # The outfeed dequeue has to happen after the outfeed enqueue
        # (i.e. after calling compile).
        dequeue_outfeed = self.outfeed_queue.dequeue()

        if self.opts["use_popdist"]:
            # Take the mean of all the outputs across the distributed workers
            dequeue_outfeed = [
                strategy.reduce(tf.distribute.ReduceOp.MEAN, v)
                for v in dequeue_outfeed
            ]

        with tf.name_scope("loader_and_saver"):
            self.loader, self.saver = self.get_loader_and_saver()

        self.sess.run(self.infeed_queue.initializer)
        self.sess.run(tf.global_variables_initializer())

        begin_epoch = 0
        if self.opts["train"]["load_type"] == "resume":
            # Resume a half-trained run
            ckpts = []
            if os.path.exists("./checkpoint"):
                ckpts = sorted([
                    path for path in os.listdir("./checkpoint")
                    if "meta" in path
                ])
            if len(ckpts) == 0:
                logger.info("Failed to resume: could not find any checkpoint")
                return
            ckpt_path = "./checkpoint/" + ckpts[-1].replace(".meta", "")
            logger.info("=> Resume training from: %s ... " % ckpt_path)
            self.loader.restore(self.sess, ckpt_path)
            begin_epoch = int(
                re.search("epoch=([0-9]+)", ckpt_path).groups()[0])
        elif self.opts["train"]["load_type"] in [
                "yolov3", "darknet53", "phase1"
        ]:
            # Load a pretrained checkpoint
            if self.initial_weight and os.path.exists(self.initial_weight +
                                                      ".meta"):
                logger.info("=> Restoring weights from: %s ... " %
                            self.initial_weight)
                self.loader.restore(self.sess, self.initial_weight)
            else:
                raise Exception("Can't find a checkpoint to load")
        elif self.opts["train"]["load_type"] == "empty":
            logger.info("=> No checkpoint to load!")
            logger.info("=> Starting to train YOLOv3 from scratch ...")
        else:
            raise Exception(
                "'load_type' is not one of the expected values: "
                "yolov3, darknet53, phase1, resume, empty")

        total_epochs = self.epochs
        total_batch_size = self.opts["train"]["pipeline_depth"] * \
            self.batch_size * \
            self.opts["train"]["replicas"] * \
            self.opts["distributed_worker_count"]
        samples_per_interaction = total_batch_size * self.repeat_count
        samples_per_epoch = len(self.trainset) * self.batch_size
        interactions_per_epoch = samples_per_epoch // samples_per_interaction
        if self.for_speed_test:
            interactions_per_epoch = 30
            total_epochs = 1
        steps_per_epoch = interactions_per_epoch * self.repeat_count
        logger.info("total epochs: {}".format(total_epochs))
        logger.info("steps_per_epoch: {}".format(steps_per_epoch))
        moving_loss = deque(maxlen=30)

        if self.opts["distributed_worker_index"] == 0:
            # We only write logs to tensorboard on the main worker
            summary_writer = tf.summary.FileWriter(
                "./tf_log/" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                session=self.sess)

        train_begin_time = time.time()
        for epoch in range(begin_epoch, total_epochs):
            logger.info("epoch {}:".format(epoch + 1))
            start_time = time.time()
            for interaction_count in range(interactions_per_epoch):
                global_step = epoch * steps_per_epoch + \
                    interaction_count * self.repeat_count
                self.sess.run(compiled_model,
                              feed_dict={global_step_holder: global_step})
                result = self.sess.run(dequeue_outfeed)
                if self.opts["distributed_worker_index"] == 0:
                    giou_loss = np.mean(result[0])
                    conf_loss = np.mean(result[1])
                    prob_loss = np.mean(result[2])
                    lr = np.mean(result[3])
                    total_loss = giou_loss + conf_loss + prob_loss
                    moving_loss.append(total_loss)
                    end_time = time.time()
                    duration = end_time - start_time
                    start_time = time.time()
                    total_samples = global_step * total_batch_size
                    logger.info(
                        "epoch:{}, global_steps:{}, total_samples:{}, lr:{:.3e}, "
                        "moving_total_loss:{:.2f}, duration:{:.2f}, samples/s:{:.2f}, "
                        "total_time:{:.2f}".format(
                            epoch + 1, global_step, total_samples, lr,
                            np.mean(moving_loss), duration,
                            samples_per_interaction / duration,
                            time.time() - train_begin_time))

                    train_summary = tf.Summary()
                    train_summary.value.add(tag="giou_loss",
                                            simple_value=giou_loss)
                    train_summary.value.add(tag="conf_loss",
                                            simple_value=conf_loss)
                    train_summary.value.add(tag="prob_loss",
                                            simple_value=prob_loss)
                    train_summary.value.add(tag="total_loss",
                                            simple_value=total_loss)
                    train_summary.value.add(tag="lr", simple_value=lr)
                    train_summary.value.add(
                        tag="samples_per_sec",
                        simple_value=samples_per_interaction / duration)
                    summary_writer.add_summary(train_summary, total_samples)
                    summary_writer.flush()

            if (not self.for_speed_test) and (
                    epoch % self.opts["train"]["epochs_per_ckpt"] == 0
                    or epoch == total_epochs - 1):
                if self.opts["distributed_worker_index"] == 0:
                    ckpt_loss = np.mean(moving_loss)
                else:
                    # If save is not called on all instances there is an
                    # all-reduce error, but saving on every worker is
                    # pointless, so only the checkpoint saved by worker 0
                    # gets a name containing the loss value.
                    ckpt_loss = 0.0
                ckpt_file = "./checkpoint/yolov3-{}-epoch={}-moving_total_loss={:.4f}.ckpt".format(
                    datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), epoch + 1,
                    ckpt_loss)
                logger.info("saving to: " + ckpt_file)
                model_path = self.saver.save(self.sess,
                                             ckpt_file,
                                             global_step=global_step)
                if self.opts["distributed_worker_index"] == 0:
                    log.save_model_statistics(model_path, summary_writer,
                                              global_step * total_batch_size)

        # Tell the data-loading threads to stop
        stop_flag.append(0)
        for data_thread in data_threads:
            data_thread.join()

        self.sess.close()