def train(net, data_loader, loss_dict, optimizer, scheduler, logger, epoch, metric_dict, use_aux):
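    """Train for one epoch: forward pass, loss, backward pass, optimizer and
    per-iteration scheduler steps, with metrics logged every 20 steps."""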
    net.train()
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    for b_idx, data_label in enumerate(progress_bar):
        t_data_1 = time.time()
        reset_metrics(metric_dict)
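        # Absolute batch index across epochs; used as the x-axis for all logging.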
        global_step = epoch * len(data_loader) + b_idx

        t_net_0 = time.time()
        results = inference(net, data_label, use_aux)

        loss = calc_loss(loss_dict, results, logger, global_step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
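        # The LR scheduler is stepped per iteration, keyed by the absolute step.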
        scheduler.step(global_step)
        t_net_1 = time.time()

        results = resolve_val_data(results, use_aux)

        update_metrics(metric_dict, results)
        if global_step % 20 == 0:
            for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
                logger.add_scalar('metric/' + me_name, me_op.get(), global_step=global_step)
        logger.add_scalar('meta/lr', optimizer.param_groups[0]['lr'], global_step=global_step)

        if hasattr(progress_bar, 'set_postfix'):
            kwargs = {me_name: '%.3f' % me_op.get() for me_name, me_op in zip(metric_dict['name'], metric_dict['op'])}
            progress_bar.set_postfix(loss='%.3f' % float(loss),
                                     data_time='%.3f' % float(t_data_1 - t_data_0),
                                     net_time='%.3f' % float(t_net_1 - t_net_0),
                                     **kwargs)
        t_data_0 = time.time()
def train(net, data_loader, loss_dict, optimizer, scheduler, logger, epoch,
          metric_dict, use_aux, local_rank):
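    """Distributed-aware variant of the training loop: additionally takes
    `local_rank` and re-shuffles the distributed sampler each epoch."""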
    net.train()
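    # Under DDP (local_rank != -1), re-seed the distributed sampler so each
    # epoch sees a fresh shuffle.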
    if local_rank != -1:
        data_loader.sampler.set_epoch(epoch)
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    for b_idx, data_label in enumerate(progress_bar):
        t_data_1 = time.time()
        reset_metrics(metric_dict)
        global_step = epoch * len(data_loader) + b_idx

        t_net_0 = time.time()
        results = inference(net, data_label, use_aux)

        loss = calc_loss(loss_dict, results, logger, global_step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(global_step)
        t_net_1 = time.time()

        results = resolve_val_data(results, use_aux)

        update_metrics(metric_dict, results)
        if global_step % 20 == 0:
            for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
                logger.scalar_summary('metric/' + me_name, 'train',
                                      me_op.get(), global_step)
        logger.scalar_summary('meta/lr', 'train',
                              optimizer.param_groups[0]['lr'], global_step)

        if hasattr(progress_bar, 'set_postfix'):
            kwargs = {
                me_name: '%.3f' % me_op.get()
                for me_name, me_op in zip(metric_dict['name'],
                                          metric_dict['op'])
            }
            log_msg = 'Epoch{}/{}|Iter{}'.format(epoch, scheduler.total_epoch,
                                                 b_idx)
            progress_bar.set_description(log_msg)
            progress_bar.set_postfix(loss='%.3f' % float(loss), **kwargs)
        t_data_0 = time.time()
def train(net, data_loader, loss_dict, optimizer, scheduler, logger, epoch,
          metric_dict, cfg):
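    """Training loop revision: metrics are reset once per epoch, gradients are
    clipped to a maximum norm, and sample images (input, segmentation and
    classification color maps) are periodically written to TensorBoard."""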
    net.train()
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    # Pyten-20201019-FixBug: reset metrics once per epoch (not per batch) so
    # they accumulate over the whole epoch.
    reset_metrics(metric_dict)
    total_loss = 0
    for b_idx, data_label in enumerate(progress_bar):
        t_data_1 = time.time()
        global_step = epoch * len(data_loader) + b_idx

        t_net_0 = time.time()
        results = inference(net, data_label, cfg.use_aux)

        loss = calc_loss(loss_dict, results, logger, global_step, "train")
        optimizer.zero_grad()
        loss.backward()
        # Pyten-20210201-ClipGrad: clip the global gradient norm to stabilize
        # training (needs `from torch.nn.utils import clip_grad_norm_`).
        clip_grad_norm_(net.parameters(), max_norm=10.0)
        optimizer.step()
        total_loss = total_loss + loss.detach()
        scheduler.step(global_step)
        t_net_1 = time.time()

        results = resolve_val_data(results, cfg.use_aux)
        update_metrics(metric_dict, results)
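        # Every 20 steps, write qualitative samples (input image plus seg/cls
        # color maps) to TensorBoard.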
        if global_step % 20 == 0:
            # Pyten-20210201-TransformImg: undo the input normalization so the
            # logged image is human-readable.
            img = img_detrans(data_label[0][0])
            logger.add_image("train_image/org", img, global_step=global_step)
            logger.add_image("train_image/std",
                             data_label[0][0],
                             global_step=global_step)
            if cfg.use_aux:
                seg_color_out = decode_seg_color_map(results["seg_out"][0])
                seg_color_label = decode_seg_color_map(data_label[2][0])
                logger.add_image("train_seg/predict",
                                 seg_color_out,
                                 global_step=global_step,
                                 dataformats='HWC')
                logger.add_image("train_seg/label",
                                 seg_color_label,
                                 global_step=global_step,
                                 dataformats='HWC')
            cls_color_out = decode_cls_color_map(data_label[0][0],
                                                 results["cls_out"][0], cfg)
            cls_color_label = decode_cls_color_map(data_label[0][0],
                                                   data_label[1][0], cfg)
            logger.add_image("train_cls/predict",
                             cls_color_out,
                             global_step=global_step,
                             dataformats='HWC')
            logger.add_image("train_cls/label",
                             cls_color_label,
                             global_step=global_step,
                             dataformats='HWC')

            for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
                logger.add_scalar('train_metric/' + me_name,
                                  me_op.get(),
                                  global_step=global_step)
        logger.add_scalar('train/meta/lr',
                          optimizer.param_groups[0]['lr'],
                          global_step=global_step)

        if hasattr(progress_bar, 'set_postfix'):
            kwargs = {
                me_name: '%.4f' % me_op.get()
                for me_name, me_op in zip(metric_dict['name'],
                                          metric_dict['op'])
            }
            progress_bar.set_postfix(
                loss='%.3f' % float(loss),
                avg_loss='%.3f' % float(total_loss / (b_idx + 1)),
                net_time='%.3f' % float(t_net_1 - t_net_0),
                **kwargs)
        t_data_0 = time.time()

    dist_print("avg_loss_over_epoch", total_loss / len(data_loader))
def val(net, data_loader, loss_dict, scheduler, logger, epoch, metric_dict,
        cfg):
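    """Validate for one epoch under torch.no_grad(), tracking the average loss
    and updating the best-metric snapshot at the end."""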
    net.eval()
    progress_bar = dist_tqdm(data_loader)
    t_data_0 = time.time()
    reset_metrics(metric_dict)
    total_loss = 0
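    # Disable autograd for the whole validation pass to save memory and compute.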
    with torch.no_grad():
        for b_idx, data_label in enumerate(progress_bar):
            t_data_1 = time.time()
            global_step = epoch * len(data_loader) + b_idx

            t_net_0 = time.time()
            results = inference(net, data_label, cfg.use_aux)
            loss = calc_loss(loss_dict, results, logger, global_step, "val")
            total_loss = total_loss + loss.detach()

            t_net_1 = time.time()

            results = resolve_val_data(results, cfg.use_aux)

            update_metrics(metric_dict, results)
            if global_step % 20 == 0:
                # Pyten-20210201-TransformImg
                img = img_detrans(data_label[0][0])
                logger.add_image("val_image/org", img, global_step=global_step)
                logger.add_image("val_image/std",
                                 data_label[0][0],
                                 global_step=global_step)
                if cfg.use_aux:
                    seg_color_out = decode_seg_color_map(results["seg_out"][0])
                    seg_color_label = decode_seg_color_map(data_label[2][0])
                    logger.add_image("val_seg/predict",
                                     seg_color_out,
                                     global_step=global_step,
                                     dataformats='HWC')
                    logger.add_image("val_seg/label",
                                     seg_color_label,
                                     global_step=global_step,
                                     dataformats='HWC')

                cls_color_out = decode_cls_color_map(data_label[0][0],
                                                     results["cls_out"][0],
                                                     cfg)
                cls_color_label = decode_cls_color_map(data_label[0][0],
                                                       data_label[1][0], cfg)
                logger.add_image("val_cls/predict",
                                 cls_color_out,
                                 global_step=global_step,
                                 dataformats='HWC')
                logger.add_image("val_cls/label",
                                 cls_color_label,
                                 global_step=global_step,
                                 dataformats='HWC')

            if hasattr(progress_bar, 'set_postfix'):
                kwargs = {
                    me_name: '%.4f' % me_op.get()
                    for me_name, me_op in zip(metric_dict['name'],
                                              metric_dict['op'])
                }
                progress_bar.set_postfix(
                    loss='%.3f' % float(loss),
                    avg_loss='%.3f' % float(total_loss / (b_idx + 1)),
                    net_time='%.3f' % float(t_net_1 - t_net_0),
                    **kwargs)
            t_data_0 = time.time()

    dist_print("avg_loss_over_epoch", total_loss / len(data_loader))
    for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
        logger.add_scalar('val_metric/' + me_name,
                          me_op.get(),
                          global_step=epoch)
    # Pyten-20201019-SaveBestMetric: refresh the best-metric snapshot only when
    # no tracked metric (IoU is excluded from the comparison) regresses.
    update_best_metric = True
    for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
        if me_name == "iou":
            continue
        cur_metric = me_op.get()
        if cur_metric < metric_dict["best_metric"][me_name]:
            update_best_metric = False
    if update_best_metric:
        for me_name, me_op in zip(metric_dict['name'], metric_dict['op']):
            metric_dict["best_metric"][me_name] = me_op.get()
        cfg.best_epoch = epoch
        dist_print("best metric updated!(epoch%d)" % epoch)