import time

import torch
from torch.nn.utils import clip_grad_norm_

# NOTE: project helpers used below (dist_tqdm, dist_print, inference, calc_loss,
# resolve_val_data, reset_metrics, update_metrics, img_detrans,
# decode_seg_color_map, decode_cls_color_map) come from this repo's utils /
# training modules; their import paths are fork-specific and omitted here.


# Earlier variant of train(): single-process loop with per-step add_scalar
# logging. Superseded by the cfg-based train() below; kept for reference.
def train(net, data_loader, loss_dict, optimizer, scheduler, logger, epoch, metric_dict, use_aux):
    net.train()
    progress_bar = dist_tqdm(data_loader)  # bug fix: iterated the global `train_loader` instead of the argument
    t_data_0 = time.time()
    for b_idx, data_label in enumerate(progress_bar):
        t_data_1 = time.time()
        reset_metrics(metric_dict)
        global_step = epoch * len(data_loader) + b_idx

        t_net_0 = time.time()
        results = inference(net, data_label, use_aux)
        loss = calc_loss(loss_dict, results, logger, global_step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(global_step)
        t_net_1 = time.time()

        results = resolve_val_data(results, use_aux)
        update_metrics(metric_dict, results)
        if global_step % 20 == 0:
            for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
                logger.add_scalar('metric/' + me_name, me_op.get(), global_step=global_step)
        logger.add_scalar('meta/lr', optimizer.param_groups[0]['lr'], global_step=global_step)

        if hasattr(progress_bar, 'set_postfix'):
            kwargs = {
                me_name: '%.3f' % me_op.get()
                for me_name, me_op in zip(metric_dict['name'], metric_dict['op'])
            }
            progress_bar.set_postfix(loss='%.3f' % float(loss),
                                     data_time='%.3f' % float(t_data_1 - t_data_0),
                                     net_time='%.3f' % float(t_net_1 - t_net_0),
                                     **kwargs)
        t_data_0 = time.time()
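
# The loops in this file only assume a small contract from `metric_dict`:
# parallel 'name'/'op' lists whose ops expose reset()/update()/get(). A minimal
# sketch of one such op, assuming a plain running average (the real metric
# classes live in the project's utils and may compute accuracy/IoU instead):

class AvgMeterSketch:
    """Hypothetical metric op illustrating the reset()/update()/get() contract."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, value):
        self.total += float(value)
        self.count += 1

    def get(self):
        return self.total / max(self.count, 1)

# Example of the shape consumed by reset_metrics()/update_metrics():
# metric_dict = {'name': ['top1'], 'op': [AvgMeterSketch()], 'best_metric': {'top1': 0.0}}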

# DDP-aware variant: re-seeds the DistributedSampler each epoch and logs via
# the fork's scalar_summary() API. Also superseded by the cfg-based train() below.
def train(net, data_loader, loss_dict, optimizer, scheduler, logger, epoch, metric_dict, use_aux, local_rank):
    net.train()
    if local_rank != -1:
        # Distributed training: reshuffle this rank's shard every epoch.
        data_loader.sampler.set_epoch(epoch)
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    for b_idx, data_label in enumerate(progress_bar):
        t_data_1 = time.time()
        reset_metrics(metric_dict)
        global_step = epoch * len(data_loader) + b_idx

        t_net_0 = time.time()
        results = inference(net, data_label, use_aux)
        loss = calc_loss(loss_dict, results, logger, global_step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(global_step)
        t_net_1 = time.time()

        results = resolve_val_data(results, use_aux)
        update_metrics(metric_dict, results)
        if global_step % 20 == 0:
            for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
                logger.scalar_summary('metric/' + me_name, 'train', me_op.get(), global_step)
        logger.scalar_summary('meta/lr', 'train', optimizer.param_groups[0]['lr'], global_step)

        if hasattr(progress_bar, 'set_postfix'):
            kwargs = {
                me_name: '%.3f' % me_op.get()
                for me_name, me_op in zip(metric_dict['name'], metric_dict['op'])
            }
            log_msg = 'Epoch{}/{}|Iter{}'.format(epoch, scheduler.total_epoch, b_idx)
            # log_msg = 'Epoch{}/{}|Iter{} '.format(epoch, scheduler.total_epoch,
            #     global_step, b_idx, len(data_loader), optimizer.param_groups[0]['lr'])
            progress_bar.set_description(log_msg)
            progress_bar.set_postfix(loss='%.3f' % float(loss), **kwargs)
        t_data_0 = time.time()
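
# The `local_rank != -1` branch above assumes the loader was built around a
# DistributedSampler; set_epoch() is what makes each epoch reshuffle per rank.
# A minimal sketch of that setup (build_ddp_loader is a hypothetical helper,
# not part of this repo):

def build_ddp_loader(dataset, batch_size, num_workers=4):
    """Build a loader whose .sampler supports set_epoch(), as train() expects."""
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    sampler = DistributedSampler(dataset, shuffle=True)  # shards the dataset across ranks
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)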

# Current train(): cfg-driven, with gradient clipping, an epoch-average loss,
# and image/segmentation visualizations logged to TensorBoard.
def train(net, data_loader, loss_dict, optimizer, scheduler, logger, epoch, metric_dict, cfg):
    net.train()
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    # Pyten-20201019-FixBug: reset once per epoch, not once per batch.
    reset_metrics(metric_dict)
    total_loss = 0
    for b_idx, data_label in enumerate(progress_bar):
        t_data_1 = time.time()
        global_step = epoch * len(data_loader) + b_idx

        t_net_0 = time.time()
        results = inference(net, data_label, cfg.use_aux)
        loss = calc_loss(loss_dict, results, logger, global_step, "train")
        optimizer.zero_grad()
        loss.backward()
        # Pyten-20210201-ClipGrad: guard against exploding gradients.
        clip_grad_norm_(net.parameters(), max_norm=10.0)
        optimizer.step()
        total_loss = total_loss + loss.detach()
        scheduler.step(global_step)
        t_net_1 = time.time()

        results = resolve_val_data(results, cfg.use_aux)
        update_metrics(metric_dict, results)
        if global_step % 20 == 0:
            # Pyten-20210201-TransformImg: undo normalization before logging.
            img = img_detrans(data_label[0][0])
            logger.add_image("train_image/org", img, global_step=global_step)
            logger.add_image("train_image/std", data_label[0][0], global_step=global_step)
            if cfg.use_aux:
                seg_color_out = decode_seg_color_map(results["seg_out"][0])
                seg_color_label = decode_seg_color_map(data_label[2][0])
                logger.add_image("train_seg/predict", seg_color_out, global_step=global_step, dataformats='HWC')
                logger.add_image("train_seg/label", seg_color_label, global_step=global_step, dataformats='HWC')
            cls_color_out = decode_cls_color_map(data_label[0][0], results["cls_out"][0], cfg)
            cls_color_label = decode_cls_color_map(data_label[0][0], data_label[1][0], cfg)
            logger.add_image("train_cls/predict", cls_color_out, global_step=global_step, dataformats='HWC')
            logger.add_image("train_cls/label", cls_color_label, global_step=global_step, dataformats='HWC')
            for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
                logger.add_scalar('train_metric/' + me_name, me_op.get(), global_step=global_step)
        logger.add_scalar('train/meta/lr', optimizer.param_groups[0]['lr'], global_step=global_step)

        if hasattr(progress_bar, 'set_postfix'):
            kwargs = {
                me_name: '%.4f' % me_op.get()
                for me_name, me_op in zip(metric_dict['name'], metric_dict['op'])
            }
            progress_bar.set_postfix(
                loss='%.3f' % float(loss),
                avg_loss='%.3f' % float(total_loss / (b_idx + 1)),
                # data_time='%.3f' % float(t_data_1 - t_data_0),
                net_time='%.3f' % float(t_net_1 - t_net_0),
                **kwargs)
        t_data_0 = time.time()
    dist_print("avg_loss_over_epoch", total_loss / len(data_loader))
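
# `img_detrans` above inverts the dataset normalization so the logged image is
# human-viewable in TensorBoard. A minimal sketch, assuming the standard
# torchvision ImageNet statistics (the fork's actual transform may differ):

def img_detrans_sketch(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Hypothetical de-normalizer: CHW float tensor back to [0, 1] for add_image()."""
    mean_t = torch.tensor(mean, device=img.device).view(-1, 1, 1)
    std_t = torch.tensor(std, device=img.device).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0.0, 1.0)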

def val(net, data_loader, loss_dict, scheduler, logger, epoch, metric_dict, cfg):
    net.eval()
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    reset_metrics(metric_dict)
    total_loss = 0
    with torch.no_grad():
        for b_idx, data_label in enumerate(progress_bar):
            t_data_1 = time.time()
            # reset_metrics(metric_dict)
            global_step = epoch * len(data_loader) + b_idx

            t_net_0 = time.time()
            # pdb.set_trace()
            results = inference(net, data_label, cfg.use_aux)
            loss = calc_loss(loss_dict, results, logger, global_step, "val")
            total_loss = total_loss + loss.detach()
            t_net_1 = time.time()

            results = resolve_val_data(results, cfg.use_aux)
            update_metrics(metric_dict, results)
            if global_step % 20 == 0:
                # Pyten-20210201-TransformImg: undo normalization before logging.
                img = img_detrans(data_label[0][0])
                logger.add_image("val_image/org", img, global_step=global_step)
                logger.add_image("val_image/std", data_label[0][0], global_step=global_step)
                if cfg.use_aux:
                    # import pdb; pdb.set_trace()
                    seg_color_out = decode_seg_color_map(results["seg_out"][0])
                    seg_color_label = decode_seg_color_map(data_label[2][0])
                    logger.add_image("val_seg/predict", seg_color_out, global_step=global_step, dataformats='HWC')
                    logger.add_image("val_seg/label", seg_color_label, global_step=global_step, dataformats='HWC')
                cls_color_out = decode_cls_color_map(data_label[0][0], results["cls_out"][0], cfg)
                cls_color_label = decode_cls_color_map(data_label[0][0], data_label[1][0], cfg)
                logger.add_image("val_cls/predict", cls_color_out, global_step=global_step, dataformats='HWC')
                logger.add_image("val_cls/label", cls_color_label, global_step=global_step, dataformats='HWC')

            if hasattr(progress_bar, 'set_postfix'):
                kwargs = {
                    me_name: '%.4f' % me_op.get()
                    for me_name, me_op in zip(metric_dict['name'], metric_dict['op'])
                }
                progress_bar.set_postfix(
                    loss='%.3f' % float(loss),
                    avg_loss='%.3f' % float(total_loss / (b_idx + 1)),
                    # data_time='%.3f' % float(t_data_1 - t_data_0),
                    net_time='%.3f' % float(t_net_1 - t_net_0),
                    **kwargs)
            t_data_0 = time.time()

    dist_print("avg_loss_over_epoch", total_loss / len(data_loader))
    for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
        logger.add_scalar('val_metric/' + me_name, me_op.get(), global_step=epoch)

    # Pyten-20201019-SaveBestMetric: the best epoch is updated only when every
    # metric except "iou" matches or improves on its stored best value.
    update_best_metric = True
    for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
        if me_name == "iou":
            continue
        cur_metric = me_op.get()
        if cur_metric < metric_dict["best_metric"][me_name]:
            update_best_metric = False
    if update_best_metric:
        for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
            metric_dict["best_metric"][me_name] = me_op.get()
        cfg.best_epoch = epoch
        dist_print("best metric updated!(epoch%d)" % epoch)
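
# Usage sketch tying the pieces together: one epoch of training followed by
# validation, with val() maintaining metric_dict["best_metric"] and
# cfg.best_epoch. run_training_sketch is hypothetical; the loaders, net,
# optimizer, scheduler, loss_dict, and logger are assumed to come from the
# repo's factory code.

def run_training_sketch(net, train_loader, val_loader, loss_dict, optimizer,
                        scheduler, logger, metric_dict, cfg):
    # Seed best_metric so the first validation pass can only improve on it.
    metric_dict.setdefault("best_metric", {name: 0.0 for name in metric_dict['name']})
    for epoch in range(cfg.epoch):  # cfg.epoch assumed: total number of epochs
        train(net, train_loader, loss_dict, optimizer, scheduler, logger, epoch, metric_dict, cfg)
        val(net, val_loader, loss_dict, scheduler, logger, epoch, metric_dict, cfg)
        if getattr(cfg, 'best_epoch', None) == epoch:
            dist_print("checkpoint this epoch's weights here")  # saving hook left to the caller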