# Shared imports and module-level state assumed by the snippets below.
# Standard-library and third-party imports are collected here once; the
# project-local helpers (logger, AverageMeter, program, Reader, init_model,
# save_model, save_load, create_feeds, create_fetchs, profiler, get_config,
# get_gpu_count) come from the PaddleClas codebase and are not repeated here.
import datetime
import os
import platform
import shutil
import sys
import time
from collections import OrderedDict

import numpy as np
import paddle.fluid as fluid
import paddleslim as slim
from visualdl import LogWriter

# global step counter shared by the static-graph run() variants below
total_step = 0


def log_info(trainer, batch_size, epoch_id, iter_id):
    lr_msg = "lr: {:.5f}".format(trainer.lr_sch.get_lr())
    metric_msg = ", ".join([
        "{}: {:.5f}".format(key, trainer.output_info[key].avg)
        for key in trainer.output_info
    ])
    time_msg = "s, ".join([
        "{}: {:.5f}".format(key, trainer.time_info[key].avg)
        for key in trainer.time_info
    ])
    ips_msg = "ips: {:.5f} images/sec".format(
        batch_size / trainer.time_info["batch_cost"].avg)
    eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
                ) * len(trainer.train_dataloader) - iter_id
               ) * trainer.time_info["batch_cost"].avg
    eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
    logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
        epoch_id, trainer.config["Global"]["epochs"], iter_id,
        len(trainer.train_dataloader), lr_msg, metric_msg, time_msg, ips_msg,
        eta_msg))

    logger.scaler(
        name="lr",
        value=trainer.lr_sch.get_lr(),
        step=trainer.global_step,
        writer=trainer.vdl_writer)
    for key in trainer.output_info:
        logger.scaler(
            name="train_{}".format(key),
            value=trainer.output_info[key].avg,
            step=trainer.global_step,
            writer=trainer.vdl_writer)

def run(dataloader, exe, program, fetchs, epoch=0, mode='train',
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader):
        exe():
        program():
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): log only

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.3f')
    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        tic = time.time()
        for i, m in enumerate(metrics):
            metric_list[i].update(m[0], len(batch[0]))
        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.value]) + 's'
        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'eval':
            logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)
            logger.info("{:s} {:s} {:s}".format(
                logger.coloring(epoch_str, "HEADER")
                if idx == 0 else epoch_str,
                logger.coloring(step_str, "PURPLE"),
                logger.coloring(fetchs_str, 'OKGREEN')))

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    if mode == 'eval':
        logger.info("END {:s} {:s}s".format(mode, end_str))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s}".format(
            logger.coloring(end_epoch_str, "RED"),
            logger.coloring(mode, "PURPLE"),
            logger.coloring(end_str, "OKGREEN")))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg

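# All of the snippets in this file rely on an AverageMeter helper from the
# PaddleClas utils package. The class below is only a minimal sketch of the
# interface they assume (update/reset, avg/sum/count, and the formatted
# value/mean/total strings); it is a simplified stand-in, not the real
# implementation, which is why it is given a distinct name.
class AverageMeterSketch(object):
    """Running average of a scalar (sketch, not the PaddleClas class)."""

    def __init__(self, name='', fmt='f', postfix="", need_avg=True):
        self.name = name
        self.fmt = fmt
        self.postfix = postfix
        self.need_avg = need_avg
        self.reset()

    def reset(self):
        # last value, running sum and sample count
        self.val, self.sum, self.count = 0.0, 0.0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    @property
    def value(self):
        # formatted most-recent value, e.g. "loss: 0.12345,"
        return '{}: {:{}}{}'.format(self.name, self.val, self.fmt,
                                    self.postfix)

    @property
    def mean(self):
        # formatted running average, e.g. "batch_cost: 0.21000 s,"
        return '{}: {:{}}{}'.format(self.name, self.avg, self.fmt,
                                    self.postfix) if self.need_avg else ''

    @property
    def total(self):
        # formatted running sum, e.g. "batch_cost_sum: 12.34567 s,"
        return '{}_sum: {:{}}{}'.format(self.name, self.sum, self.fmt,
                                        self.postfix)
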
def run(dataloader,
        config,
        net,
        optimizer=None,
        lr_scheduler=None,
        epoch=0,
        mode='train',
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle dataloader):
        config(dict): training config
        net(paddle.nn.Layer): model to train or evaluate
        optimizer(): optimizer, used when mode is 'train'
        lr_scheduler(): learning rate scheduler, used when mode is 'train'
        epoch(int): epoch of training or validation
        mode(str): log only

    Returns:
    """
    print_interval = config.get("print_interval", 10)
    use_mix = config.get("use_mix", False) and mode == "train"
    multilabel = config.get("multilabel", False)
    classes_num = config.get("classes_num")

    metric_list = [
        ("loss", AverageMeter(
            'loss', '7.5f', postfix=",")),
        ("lr", AverageMeter(
            'lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter(
            'batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter(
            'reader_cost', '.5f', postfix=" s,")),
    ]
    if not use_mix:
        if not multilabel:
            topk_name = 'top{}'.format(config.topk)
            metric_list.insert(
                0, (topk_name, AverageMeter(
                    topk_name, '.5f', postfix=",")))
            metric_list.insert(
                0, ("top1", AverageMeter(
                    "top1", '.5f', postfix=",")))
        else:
            metric_list.insert(
                0, ("multilabel_accuracy", AverageMeter(
                    "multilabel_accuracy", '.5f', postfix=",")))
            metric_list.insert(
                0, ("hamming_distance", AverageMeter(
                    "hamming_distance", '.5f', postfix=",")))

    metric_list = OrderedDict(metric_list)

    tic = time.time()
    for idx, batch in enumerate(dataloader()):
        # avoid statistics from warmup time
        if idx == 10:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()

        metric_list['reader_time'].update(time.time() - tic)
        batch_size = len(batch[0])
        feeds = create_feeds(batch, use_mix, classes_num, multilabel)
        fetchs = create_fetchs(feeds, net, config, mode)
        if mode == 'train':
            avg_loss = fetchs['loss']
            avg_loss.backward()

            optimizer.step()
            optimizer.clear_grad()
            lr_value = optimizer._global_learning_rate().numpy()[0]
            metric_list['lr'].update(lr_value, batch_size)

            if lr_scheduler is not None:
                if lr_scheduler.update_specified:
                    curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                    update = max(
                        0, curr_global_counter - lr_scheduler.update_start_step
                    ) % lr_scheduler.update_step_interval == 0
                    if update:
                        lr_scheduler.step()
                else:
                    lr_scheduler.step()

        for name, fetch in fetchs.items():
            metric_list[name].update(fetch.numpy()[0], batch_size)
        metric_list["batch_time"].update(time.time() - tic)
        tic = time.time()

        if vdl_writer and mode == "train":
            global total_step
            logger.scaler(
                name="lr", value=lr_value, step=total_step, writer=vdl_writer)
            for name, fetch in fetchs.items():
                logger.scaler(
                    name="train_{}".format(name),
                    value=fetch.numpy()[0],
                    step=total_step,
                    writer=vdl_writer)
            total_step += 1

        fetchs_str = ' '.join([
            str(metric_list[key].mean)
            if "time" in key else str(metric_list[key].value)
            for key in metric_list
        ])

        if idx % print_interval == 0:
            ips_info = "ips: {:.5f} images/sec".format(
                batch_size / metric_list["batch_time"].avg)

            if mode == "train":
                epoch_str = "epoch:{:<3d}".format(epoch)
                step_str = "{:s} step:{:<4d}".format(mode, idx)
                eta_sec = ((config["epochs"] - epoch) * len(dataloader) - idx
                           ) * metric_list["batch_time"].avg
                eta_str = "eta: {:s}".format(
                    str(datetime.timedelta(seconds=int(eta_sec))))
                logger.info("{:s}, {:s}, {:s} {:s}, {:s}".format(
                    epoch_str, step_str, fetchs_str, ips_info, eta_str))
            else:
                logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
                    mode, idx, fetchs_str, ips_info))

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list['batch_time'].total])
    ips_info = "ips: {:.5f} images/sec.".format(
        batch_size * metric_list["batch_time"].count /
        metric_list["batch_time"].sum)
    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        if multilabel:
            return metric_list['multilabel_accuracy'].avg
        else:
            return metric_list['top1'].avg

def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader):
        exe():
        program():
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): log only

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [
        ("lr", AverageMeter('lr', 'f', postfix=",", need_avg=False)),
        ("batch_time", AverageMeter('batch_cost', '.5f', postfix=" s,")),
        ("reader_time", AverageMeter('reader_cost', '.5f', postfix=" s,")),
    ]
    topk_name = 'top{}'.format(config.topk)
    metric_list.insert(0, ("loss", fetchs["loss"][1]))
    use_mix = config.get("use_mix", False) and mode == "train"
    if not use_mix:
        metric_list.insert(0, (topk_name, fetchs[topk_name][1]))
        metric_list.insert(0, ("top1", fetchs["top1"][1]))

    metric_list = OrderedDict(metric_list)

    for m in metric_list.values():
        m.reset()

    use_dali = config.get('use_dali', False)
    dataloader = dataloader if use_dali else dataloader()
    tic = time.time()

    idx = 0
    batch_size = None
    while True:
        # DALI may raise a RuntimeError for some particular images, such as
        # ImageNet1k/n04418357_26036.JPEG
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_list["batch_time"].reset()
            metric_list["reader_time"].reset()
        metric_list['reader_time'].update(time.time() - tic)

        if use_dali:
            batch_size = batch[0]["feed_image"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                key.name: batch[idx]
                for idx, key in enumerate(feeds.values())
            }
        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_list[name].update(np.mean(m), batch_size)
        metric_list["batch_time"].update(time.time() - tic)

        if mode == "train":
            metric_list['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_list[key].mean)
            if "time" in key else str(metric_list[key].value)
            for key in metric_list
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_list["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            if lr_scheduler.update_specified:
                curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                update = max(
                    0, curr_global_counter - lr_scheduler.update_start_step
                ) % lr_scheduler.update_step_interval == 0
                if update:
                    lr_scheduler.step()
            else:
                lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1

        if mode == 'valid':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN')))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
                       [metric_list["batch_time"].total])
    ips_info = "ips: {:.5f} images/sec.".format(
        batch_size * metric_list["batch_time"].count /
        metric_list["batch_time"].sum)
    if mode == 'valid':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg

def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None,
        profiler_options=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader):
        exe():
        program():
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or evaluation
        mode(str): log only

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_dict = OrderedDict([("lr", AverageMeter(
        'lr', 'f', postfix=",", need_avg=False))])

    for k in fetchs:
        metric_dict[k] = fetchs[k][1]

    metric_dict["batch_time"] = AverageMeter(
        'batch_cost', '.5f', postfix=" s,")
    metric_dict["reader_time"] = AverageMeter(
        'reader_cost', '.5f', postfix=" s,")

    for m in metric_dict.values():
        m.reset()

    use_dali = config["Global"].get('use_dali', False)
    tic = time.time()

    if not use_dali:
        dataloader = dataloader()

    idx = 0
    batch_size = None
    while True:
        # DALI may raise a RuntimeError for some particular images, such as
        # ImageNet1k/n04418357_26036.JPEG
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_dict["batch_time"].reset()
            metric_dict["reader_time"].reset()
        metric_dict['reader_time'].update(time.time() - tic)

        profiler.add_profiler_step(profiler_options)

        if use_dali:
            batch_size = batch[0]["data"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                key.name: batch[idx]
                for idx, key in enumerate(feeds.values())
            }

        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_dict[name].update(np.mean(m), batch_size)
        metric_dict["batch_time"].update(time.time() - tic)

        if mode == "train":
            metric_dict['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_dict[key].mean)
            if "time" in key else str(metric_dict[key].value)
            for key in metric_dict
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_dict["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1

        if mode == 'eval':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_dict.values()] +
                       [metric_dict["batch_time"].total])
    ips_info = "ips: {:.5f} images/sec.".format(
        batch_size / metric_dict["batch_time"].avg)

    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'eval':
        return fetchs["top1"][1].avg

def run(dataloader, exe, program, fetchs, epoch=0, mode='train', config=None,
        vdl_writer=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(fluid dataloader):
        exe():
        program():
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): log only

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.5f', need_avg=True)
    tic = time.time()
    dataloader = dataloader if config.get('use_dali') else dataloader()
    for idx, batch in enumerate(dataloader):
        # ignore the warmup iters
        if idx == 10:
            for m in metric_list:
                m.reset()
            batch_time.reset()
        batch_size = batch[0]["feed_image"].shape()[0]
        metrics = exe.run(program=program, feed=batch, fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        for i, m in enumerate(metrics):
            metric_list[i].update(np.mean(m), batch_size)
        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.mean]) + 's'
        ips_info = " ips: {:.5f} images/sec.".format(batch_size /
                                                     batch_time.avg)
        fetchs_str += ips_info
        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1
        if mode == 'eval':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))
        tic = time.time()

    if config.get('use_dali'):
        dataloader.reset()

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count /
                                                batch_time.sum)
    if mode == 'eval':
        logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg

def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader):
        exe():
        program():
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or validation
        mode(str): log only

    Returns:
    """
    fetch_list = [f[0] for f in fetchs.values()]
    metric_list = [f[1] for f in fetchs.values()]
    if mode == "train":
        metric_list.append(AverageMeter('lr', 'f', need_avg=False))
    for m in metric_list:
        m.reset()
    batch_time = AverageMeter('elapse', '.3f')
    use_dali = config.get('use_dali', False)
    dataloader = dataloader if use_dali else dataloader()
    tic = time.time()
    for idx, batch in enumerate(dataloader):
        # ignore the warmup iters
        if idx == 5:
            batch_time.reset()
        if use_dali:
            batch_size = batch[0]["feed_image"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                key.name: batch[idx]
                for idx, key in enumerate(feeds.values())
            }
        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)
        batch_time.update(time.time() - tic)
        for i, m in enumerate(metrics):
            metric_list[i].update(np.mean(m), batch_size)
        if mode == "train":
            metric_list[-1].update(lr_scheduler.get_lr())
        fetchs_str = ''.join([str(m.value) + ' '
                              for m in metric_list] + [batch_time.mean]) + 's'
        ips_info = " ips: {:.5f} images/sec.".format(batch_size /
                                                     batch_time.avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            if lr_scheduler.update_specified:
                curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
                update = max(
                    0, curr_global_counter - lr_scheduler.update_start_step
                ) % lr_scheduler.update_step_interval == 0
                if update:
                    lr_scheduler.step()
            else:
                lr_scheduler.step()

        if vdl_writer:
            global total_step
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1

        if mode == 'valid':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(
                    logger.coloring(epoch_str, "HEADER")
                    if idx == 0 else epoch_str,
                    logger.coloring(step_str, "PURPLE"),
                    logger.coloring(fetchs_str, 'OKGREEN')))

        tic = time.time()

    end_str = ''.join([str(m.mean) + ' '
                       for m in metric_list] + [batch_time.total]) + 's'
    ips_info = "ips: {:.5f} images/sec.".format(batch_size * batch_time.count /
                                                batch_time.sum)
    if mode == 'valid':
        logger.info("END {:s} {:s}s {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))

    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'valid':
        return fetchs["top1"][1].avg

def train(self):
    assert self.mode == "train"
    print_batch_step = self.config['Global']['print_batch_step']
    save_interval = self.config["Global"]["save_interval"]
    best_metric = {
        "metric": 0.0,
        "epoch": 0,
    }
    # key: metric name, val: AverageMeter holding that metric
    self.output_info = dict()
    self.time_info = {
        "batch_cost": AverageMeter(
            "batch_cost", '.5f', postfix=" s,"),
        "reader_cost": AverageMeter(
            "reader_cost", ".5f", postfix=" s,"),
    }
    # global iter counter
    self.global_step = 0

    if self.config["Global"]["checkpoints"] is not None:
        metric_info = init_model(self.config["Global"], self.model,
                                 self.optimizer)
        if metric_info is not None:
            best_metric.update(metric_info)

    self.max_iter = len(self.train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(self.train_dataloader)
    for epoch_id in range(best_metric["epoch"] + 1,
                          self.config["Global"]["epochs"] + 1):
        acc = 0.0
        # for one epoch train
        self.train_epoch_func(self, epoch_id, print_batch_step)

        if self.use_dali:
            self.train_dataloader.reset()
        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, self.output_info[key].avg)
            for key in self.output_info
        ])
        logger.info("[Train][Epoch {}/{}][Avg]{}".format(
            epoch_id, self.config["Global"]["epochs"], metric_msg))
        self.output_info.clear()

        # eval model and save model if possible
        if self.config["Global"][
                "eval_during_train"] and epoch_id % self.config["Global"][
                    "eval_interval"] == 0:
            acc = self.eval(epoch_id)
            if acc > best_metric["metric"]:
                best_metric["metric"] = acc
                best_metric["epoch"] = epoch_id
                save_load.save_model(
                    self.model,
                    self.optimizer,
                    best_metric,
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="best_model")
            logger.info("[Eval][Epoch {}][best metric: {}]".format(
                epoch_id, best_metric["metric"]))
            logger.scaler(
                name="eval_acc",
                value=acc,
                step=epoch_id,
                writer=self.vdl_writer)

            self.model.train()

        # save model
        if epoch_id % save_interval == 0:
            save_load.save_model(
                self.model,
                self.optimizer, {"metric": acc,
                                 "epoch": epoch_id},
                self.output_dir,
                model_name=self.config["Arch"]["name"],
                prefix="epoch_{}".format(epoch_id))
        # save the latest model
        save_load.save_model(
            self.model,
            self.optimizer, {"metric": acc,
                             "epoch": epoch_id},
            self.output_dir,
            model_name=self.config["Arch"]["name"],
            prefix="latest")

    if self.vdl_writer is not None:
        self.vdl_writer.close()

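# A hedged sketch of what a train_epoch_func compatible with train() above
# and with log_info() (first snippet) might look like. The forward and loss
# calls are placeholders, not the PaddleClas implementation; only the
# bookkeeping that log_info() expects (output_info, time_info, global_step)
# is shown.
def train_epoch_sketch(trainer, epoch_id, print_batch_step):
    tic = time.time()
    for iter_id, batch in enumerate(trainer.train_dataloader):
        trainer.time_info["reader_cost"].update(time.time() - tic)
        batch_size = batch[0].shape[0]

        out = trainer.model(batch[0])                   # forward (placeholder)
        loss = trainer.train_loss_func(out, batch[1])   # loss (placeholder)
        loss.backward()
        trainer.optimizer.step()
        trainer.optimizer.clear_grad()
        trainer.lr_sch.step()

        # bookkeeping consumed by log_info()
        if "loss" not in trainer.output_info:
            trainer.output_info["loss"] = AverageMeter("loss", '7.5f')
        trainer.output_info["loss"].update(float(loss), batch_size)
        trainer.time_info["batch_cost"].update(time.time() - tic)
        trainer.global_step += 1

        if iter_id % print_batch_step == 0:
            log_info(trainer, batch_size, epoch_id, iter_id)
        tic = time.time()
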
def main(args):
    config = get_config(args.config, overrides=args.override, show=True)
    # quantization-aware training requires validation to be enabled
    if not config.validate and args.use_quant:
        logger.error("=====>Train quant model must use validate!")
        sys.exit(1)
    if args.use_quant:
        config.epochs = config.epochs + 5
        gpu_count = get_gpu_count()
        if gpu_count != 1:
            logger.error(
                "=====>`Train quant model must use only one GPU. "
                "Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_TO_USE]` ."
            )
            sys.exit(1)
    # whether to use GPU
    use_gpu = config.get("use_gpu", True)
    places = fluid.cuda_places() if use_gpu else fluid.cpu_places()

    startup_prog = fluid.Program()
    train_prog = fluid.Program()

    best_top1_acc = 0.0

    # build the training dataloader and model outputs
    if not config.get('use_ema'):
        train_dataloader, train_fetchs, out, softmax_out = program.build(
            config,
            train_prog,
            startup_prog,
            is_train=True,
            is_distributed=False)
    else:
        train_dataloader, train_fetchs, ema, out, softmax_out = program.build(
            config,
            train_prog,
            startup_prog,
            is_train=True,
            is_distributed=False)

    # build the validation dataloader and model outputs
    if config.validate:
        valid_prog = fluid.Program()
        valid_dataloader, valid_fetchs, _, _ = program.build(
            config,
            valid_prog,
            startup_prog,
            is_train=False,
            is_distributed=False)
        # clone the validation program to strip ops irrelevant to evaluation
        valid_prog = valid_prog.clone(for_test=True)

    # create the executor
    exe = fluid.Executor(places[0])
    exe.run(startup_prog)

    # load the model, either pretrained weights or a checkpoint
    init_model(config, train_prog, exe)

    train_reader = Reader(config, 'train')()
    train_dataloader.set_sample_list_generator(train_reader, places)

    compiled_train_prog = program.compile(config, train_prog,
                                          train_fetchs['loss'][0].name)

    if config.validate:
        valid_reader = Reader(config, 'valid')()
        valid_dataloader.set_sample_list_generator(valid_reader, places)
        compiled_valid_prog = program.compile(config, valid_prog,
                                              share_prog=compiled_train_prog)

    vdl_writer = LogWriter(args.vdl_dir)

    for epoch_id in range(config.epochs - 5):
        # train for one epoch
        program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
                    epoch_id, 'train', config, vdl_writer)
        # run validation once
        if config.validate and epoch_id % config.valid_interval == 0:
            if config.get('use_ema'):
                logger.info(logger.coloring("EMA validate start..."))
                with ema.apply(exe):
                    _ = program.run(valid_dataloader, exe, compiled_valid_prog,
                                    valid_fetchs, epoch_id, 'valid', config)
                logger.info(logger.coloring("EMA validate over!"))

            top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog,
                                   valid_fetchs, epoch_id, 'valid', config)
            if vdl_writer:
                logger.scaler('valid_avg', top1_acc, epoch_id, vdl_writer)

            if top1_acc > best_top1_acc:
                best_top1_acc = top1_acc
                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
                    best_top1_acc, epoch_id)
                logger.info("{:s}".format(logger.coloring(message, "RED")))
                if epoch_id % config.save_interval == 0:
                    model_path = os.path.join(config.model_save_dir,
                                              config.ARCHITECTURE["name"])
                    save_model(train_prog, model_path, "best_model")

        # save the model
        if epoch_id % config.save_interval == 0:
            model_path = os.path.join(config.model_save_dir,
                                      config.ARCHITECTURE["name"])
            if epoch_id >= 3 and os.path.exists(
                    os.path.join(model_path, str(epoch_id - 3))):
                shutil.rmtree(
                    os.path.join(model_path, str(epoch_id - 3)),
                    ignore_errors=True)
            save_model(train_prog, model_path, epoch_id)

    # quantization-aware training
    if args.use_quant and config.validate:
        # insert fake-quant ops into the training program
        quant_program = slim.quant.quant_aware(
            train_prog, exe.place, for_test=False)
        # build the program for evaluating the quantized model
        val_quant_program = slim.quant.quant_aware(
            valid_prog, exe.place, for_test=True)

        fetch_list = [f[0] for f in train_fetchs.values()]
        metric_list = [f[1] for f in train_fetchs.values()]
        for i in range(5):
            for idx, batch in enumerate(train_dataloader()):
                metrics = exe.run(program=quant_program,
                                  feed=batch,
                                  fetch_list=fetch_list)
                for i, m in enumerate(metrics):
                    metric_list[i].update(np.mean(m), len(batch[0]))
                fetchs_str = ''.join(
                    [str(m.value) + ' ' for m in metric_list])
                if idx % 10 == 0:
                    logger.info("quant train : " + fetchs_str)

        fetch_list = [f[0] for f in valid_fetchs.values()]
        metric_list = [f[1] for f in valid_fetchs.values()]
        for idx, batch in enumerate(valid_dataloader()):
            metrics = exe.run(program=val_quant_program,
                              feed=batch,
                              fetch_list=fetch_list)
            for i, m in enumerate(metrics):
                metric_list[i].update(np.mean(m), len(batch[0]))
            fetchs_str = ''.join([str(m.value) + ' ' for m in metric_list])
            if idx % 10 == 0:
                logger.info("quant valid: " + fetchs_str)

        # save the quantization-aware trained model
        float_prog, int8_prog = slim.quant.convert(
            val_quant_program, exe.place, save_int8=True)
        fluid.io.save_inference_model(
            dirname=args.output_path,
            feeded_var_names=['feed_image'],
            target_vars=[softmax_out],
            executor=exe,
            main_program=float_prog,
            model_filename='__model__',
            params_filename='__params__')

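# A minimal sketch of the command-line arguments main() above expects
# (args.config, args.override, args.use_quant, args.vdl_dir,
# args.output_path). The flag spellings and defaults below are assumptions
# for illustration; the real PaddleClas launch script builds its parser
# somewhat differently.
import argparse


def parse_args_sketch():
    parser = argparse.ArgumentParser("quant-aware training sketch")
    parser.add_argument("-c", "--config", type=str, required=True,
                        help="path to the YAML training config")
    parser.add_argument("-o", "--override", action="append", default=[],
                        help="override config options, e.g. epochs=10")
    parser.add_argument("--use_quant", action="store_true",
                        help="enable quantization-aware training")
    parser.add_argument("--vdl_dir", type=str, default="logs",
                        help="directory for VisualDL logs")
    parser.add_argument("--output_path", type=str, default="output_quant",
                        help="where to save the quantized inference model")
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args_sketch())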