def main(args):
    """Parallel multi-GPU evaluation driver.

    Splits the validation list evenly across the configured devices, spawns
    one worker process per device, aggregates per-image results (accuracy,
    intersection, union) from a shared queue, and prints per-class IoU plus
    an overall summary.

    Args:
        args: parsed CLI namespace; uses .devices, .list_val and .num_val.
    """
    # Parse device ids, e.g. "gpu0,gpu1" -> [0, 1]
    default_dev, *parallel_dev = parse_devices(args.devices)
    all_devs = parallel_dev + [default_dev]
    all_devs = [int(x.replace('gpu', '')) for x in all_devs]
    nr_devs = len(all_devs)

    with open(args.list_val, 'r') as f:
        nr_files = len(f.readlines())
    if args.num_val > 0:
        nr_files = min(nr_files, args.num_val)
    nr_files_per_dev = math.ceil(nr_files / nr_devs)

    pbar = tqdm(total=nr_files)

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    result_queue = Queue(500)
    procs = []
    for dev_id in range(nr_devs):
        start_idx = dev_id * nr_files_per_dev
        end_idx = min(start_idx + nr_files_per_dev, nr_files)
        proc = Process(target=worker, args=(args, dev_id, start_idx, end_idx, result_queue))
        print('process:%d, start_idx:%d, end_idx:%d' % (dev_id, start_idx, end_idx))
        proc.start()
        procs.append(proc)

    # Master fetches results. Queue.get() blocks until an item is available,
    # so no empty()-polling is needed -- the original spin loop
    # (`if result_queue.empty(): continue`) burned a full CPU core while
    # waiting for workers.
    processed_counter = 0
    while processed_counter < nr_files:
        (acc, pix, intersection, union) = result_queue.get()
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)
        processed_counter += 1
        pbar.update(1)
    pbar.close()

    for p in procs:
        p.join()

    # Summary: per-class IoU, then mean IoU / overall pixel accuracy.
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4}, Accuracy: {:.2f}%'
          .format(iou.mean(), acc_meter.average()*100))

    print('Evaluation Done!')
Пример #2
0
def train(segmentation_module, iterator, optimizers, history, epoch, args):
    """Run one training epoch of `args.epoch_iters` iterations.

    Pulls batches from `iterator`, runs forward/backward on
    `segmentation_module`, steps every optimizer, tracks per-task losses and
    metrics, periodically prints progress, appends the training loss to
    `history`, and anneals the learning rate after every iteration.

    Args:
        segmentation_module: model wrapper; returns {'loss': {...}, 'metric': {...}}.
        iterator: batch iterator yielding (batch_data, src_idx).
        optimizers: iterable of optimizers (e.g. encoder + decoder).
        history: dict; history['train']['epoch'/'loss'] are appended in place.
        epoch: 1-based epoch index (used for the LR schedule and logging).
        args: config namespace (epoch_iters, disp_iter, fix_bn, running_lr_*).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()

    # One meter per task head, plus a combined total-loss meter.
    names = ['object', 'part', 'scene', 'material']
    ave_losses = {n: AverageMeter() for n in names}
    ave_metric = {n: AverageMeter() for n in names}
    ave_losses['total'] = AverageMeter() 

    # Keep BatchNorm frozen when fix_bn is set (train(False) == eval mode).
    segmentation_module.train(not args.fix_bn)

    # main loop
    tic = time.time()
    for i in range(args.epoch_iters):

        batch_data, src_idx = next(iterator)

        data_time.update(time.time() - tic)

        segmentation_module.zero_grad()

        # forward pass
        ret = segmentation_module(batch_data)

        # Backward: only the combined total loss is backpropagated.
        loss = ret['loss']['total'].mean()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # measure losses
        for name in ret['loss'].keys():
            ave_losses[name].update(ret['loss'][name].mean().item())

        # measure metrics
        # NOTE: scene metric will be much lower than benchmark
        for name in ret['metric'].keys():
            ave_metric[name].update(ret['metric'][name].mean().item())

        # calculate accuracy, and display
        if i % args.disp_iter == 0:
            # A meter may not have been updated yet; fall back to 0 in that case.
            loss_info = "Loss: total {:.4f}, ".format(ave_losses['total'].average())
            loss_info += ", ".join(["{} {:.2f}".format(
                n[0], ave_losses[n].average() 
                if ave_losses[n].average() is not None else 0) for n in names])
            acc_info = "Accuracy: " + ", ".join(["{} {:4.2f}".format(
                n[0], ave_metric[n].average() 
                if ave_metric[n].average() is not None else 0) for n in names])
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'LR: encoder {:.6f}, decoder {:.6f}, {}, {}'
                  .format(epoch, i, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.running_lr_encoder, args.running_lr_decoder,
                          acc_info, loss_info))

            # Log a fractional epoch so training curves are smooth within an epoch.
            fractional_epoch = epoch - 1 + 1. * i / args.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(loss.item())

        # adjust learning rate based on the global iteration count
        cur_iter = i + (epoch - 1) * args.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, args)
Пример #3
0
def evaluate(segmentation_module, loader, cfg, gpu, results_file=None):
    """Multi-scale evaluation of a segmentation module over `loader`.

    Averages class scores over the resized image pyramid in
    batch_data['img_data'], computes pixel accuracy and per-class IoU,
    prints a summary, optionally visualizes predictions, and optionally
    dumps per-image records to a CSV file.

    Args:
        segmentation_module: model; called as module(feed_dict, segSize=...).
        loader: validation loader; each item is a single-element batch list.
        cfg: config node (DATASET.num_class, DATASET.imgSizes, VAL.visualize, DIR).
        gpu: device id passed to async_copy_to.
        results_file: optional CSV path for per-image accuracy/IoU records.
    """
    results = []

    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    # NOTE(review): pbar is never closed; harmless but leaves the bar open.
    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data (loader wraps each sample in a 1-element list)
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        # synchronize so the timer measures actual GPU work
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            # accumulate class scores averaged over all pyramid scales
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0],
                                 segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                # drop CPU-only fields before copying the dict to the GPU
                del feed_dict['img_ori']
                del feed_dict['info']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            # per-pixel class prediction = argmax over the channel dim
            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label,
                                                   cfg.DATASET.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if cfg.VAL.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']), pred,
                os.path.join(cfg.DIR, 'result'))

        if results_file:
            # per-image record: [file info, acc, union_0, iou_0, union_1, iou_1, ...]
            ious = intersection / (union + 1e-10)
            recs = [batch_data["info"], acc] + np.column_stack(
                (union, ious)).ravel().tolist()
            results.append(recs)

        pbar.update(1)

    # summary (epsilon avoids division by zero for absent classes)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print(
        'Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'.format(
            iou.mean(),
            acc_meter.average() * 100, time_meter.average()))

    if results_file:
        import pandas as pd
        headers = ['File', 'Acc']
        # NOTE(review): `names` is not defined anywhere in this function; this
        # raises NameError when results_file is set unless `names` exists as a
        # module-level global (class-name list) -- confirm against the module.
        for i in range(len(names)):
            headers.extend((names[i] + '_union', names[i] + '_iou'))
        pd.DataFrame(results, columns=headers).to_csv(results_file,
                                                      index=False)
Пример #4
0
def calc_metrics(batch_data, outputs, args):
    """Compute source-separation metrics (mixture SDR, SDR, SIR, SAR) for a batch.

    Unwarps the predicted masks back to linear frequency (when trained on a
    log-frequency grid), reconstructs waveforms via inverse STFT, and scores
    them against the ground-truth sources with BSS eval.

    Args:
        batch_data: dict with 'mag_mix', 'phase_mix' (B, 1, F, T tensors --
            TODO confirm shapes against the dataset) and 'audios' (list of N
            waveform tensors).
        outputs: dict with 'pred_masks': list of N predicted mask tensors.
        args: config namespace (num_mix, log_freq, stft_frame, stft_hop,
            binary_mask, mask_thres, device).

    Returns:
        [sdr_mix, sdr, sir, sar] averaged over the valid samples of the batch
        (values come from AverageMeter.average()).
    """
    # meters
    sdr_mix_meter = AverageMeter()
    sdr_meter = AverageMeter()
    sir_meter = AverageMeter()
    sar_meter = AverageMeter()

    # fetch data and predictions
    mag_mix = batch_data['mag_mix']
    phase_mix = batch_data['phase_mix']
    audios = batch_data['audios']

    pred_masks_ = outputs['pred_masks']

    # unwarp log scale back to linear frequency
    N = args.num_mix
    B = mag_mix.size(0)
    if args.log_freq:
        # The unwarp grid is identical for every source, so build it once
        # instead of once per source as the original loop did.
        # NOTE(review): sized from pred_masks_[0] (exactly as the original);
        # if masks could differ in time length this would need
        # pred_masks_[n].size(3) per source -- confirm upstream shapes.
        grid_unwarp = torch.from_numpy(
            warpgrid(B,
                     args.stft_frame // 2 + 1,
                     pred_masks_[0].size(3),
                     warp=False)).to(args.device)
        pred_masks_linear = [F.grid_sample(pred_masks_[n], grid_unwarp)
                             for n in range(N)]
    else:
        pred_masks_linear = [pred_masks_[n] for n in range(N)]

    # convert into numpy
    mag_mix = mag_mix.numpy()
    phase_mix = phase_mix.numpy()
    for n in range(N):
        pred_masks_linear[n] = pred_masks_linear[n].detach().cpu().numpy()

        # threshold if binary mask
        if args.binary_mask:
            pred_masks_linear[n] = (pred_masks_linear[n] >
                                    args.mask_thres).astype(np.float32)

    # loop over each sample
    for j in range(B):
        # reconstruct the mixture waveform
        mix_wav = istft_reconstruction(mag_mix[j, 0],
                                       phase_mix[j, 0],
                                       hop_length=args.stft_hop)

        # reconstruct each predicted component (masked mixture magnitude +
        # mixture phase)
        preds_wav = [None for n in range(N)]
        for n in range(N):
            pred_mag = mag_mix[j, 0] * pred_masks_linear[n][j, 0]
            preds_wav[n] = istft_reconstruction(pred_mag,
                                                phase_mix[j, 0],
                                                hop_length=args.stft_hop)

        # skip (near-)silent references/predictions: BSS eval is undefined there
        L = preds_wav[0].shape[0]
        gts_wav = [None for n in range(N)]
        valid = True
        for n in range(N):
            gts_wav[n] = audios[n][j, 0:L].numpy()
            valid *= np.sum(np.abs(gts_wav[n])) > 1e-5
            valid *= np.sum(np.abs(preds_wav[n])) > 1e-5
        if valid:
            sdr, sir, sar, _ = bss_eval_sources(np.asarray(gts_wav),
                                                np.asarray(preds_wav), False)
            # mixture baseline: score the raw mixture against every source
            sdr_mix, _, _, _ = bss_eval_sources(
                np.asarray(gts_wav),
                np.asarray([mix_wav[0:L] for n in range(N)]), False)
            sdr_mix_meter.update(sdr_mix.mean())
            sdr_meter.update(sdr.mean())
            sir_meter.update(sir.mean())
            sar_meter.update(sar.mean())

    return [
        sdr_mix_meter.average(),
        sdr_meter.average(),
        sir_meter.average(),
        sar_meter.average()
    ]
Пример #5
0
def evaluate(netWrapper, loader, history, epoch, args):
    """Validate the separation model for one epoch.

    Runs the wrapped network over `loader` with autograd disabled, accumulates
    loss and BSS metrics (SDR/SIR/SAR), builds an HTML visualization page,
    records results in `history`, and plots loss/metric curves for epoch > 0.

    Args:
        netWrapper: model wrapper; .forward(batch_data, args) -> (err, outputs).
        loader: validation data loader.
        history: dict of lists under history['val'][...]; appended in place.
        epoch: current epoch number (0 skips the figure plotting).
        args: config namespace (vis, num_mix, num_vis, ckpt, ...).
    """
    print('Evaluating at {} epochs...'.format(epoch))
    # Disable autograd for the whole evaluation but restore the previous state
    # afterwards -- the original left gradients globally disabled, which
    # silently breaks any training step executed after this call.
    prev_grad_enabled = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    try:
        # remove previous viz results
        makedirs(args.vis, remove=True)

        # switch to eval mode
        netWrapper.eval()

        # initialize meters
        loss_meter = AverageMeter()
        sdr_mix_meter = AverageMeter()
        sdr_meter = AverageMeter()
        sir_meter = AverageMeter()
        sar_meter = AverageMeter()

        # initialize HTML header
        visualizer = HTMLVisualizer(os.path.join(args.vis, 'index.html'))
        header = ['Filename', 'Input Mixed Audio']
        for n in range(1, args.num_mix + 1):
            header += [
                'Video {:d}'.format(n), 'Predicted Audio {:d}'.format(n),
                'GroundTruth Audio {}'.format(n), 'Predicted Mask {}'.format(n),
                'GroundTruth Mask {}'.format(n)
            ]
        header += ['Loss weighting']
        visualizer.add_header(header)
        vis_rows = []

        for i, batch_data in enumerate(loader):
            # forward pass
            err, outputs = netWrapper.forward(batch_data, args)
            err = err.mean()

            loss_meter.update(err.item())
            print('[Eval] iter {}, loss: {:.4f}'.format(i, err.item()))

            # calculate metrics
            sdr_mix, sdr, sir, sar = calc_metrics(batch_data, outputs, args)
            sdr_mix_meter.update(sdr_mix)
            sdr_meter.update(sdr)
            sir_meter.update(sir)
            sar_meter.update(sar)

            # output visualization (only the first num_vis rows)
            if len(vis_rows) < args.num_vis:
                output_visuals(vis_rows, batch_data, outputs, args)

        print('[Eval Summary] Epoch: {}, Loss: {:.4f}, '
              'SDR_mixture: {:.4f}, SDR: {:.4f}, SIR: {:.4f}, SAR: {:.4f}'.format(
                  epoch, loss_meter.average(), sdr_mix_meter.average(),
                  sdr_meter.average(), sir_meter.average(), sar_meter.average()))
        history['val']['epoch'].append(epoch)
        history['val']['err'].append(loss_meter.average())
        history['val']['sdr'].append(sdr_meter.average())
        history['val']['sir'].append(sir_meter.average())
        history['val']['sar'].append(sar_meter.average())

        print('Plotting html for visualization...')
        visualizer.add_rows(vis_rows)
        visualizer.write_html()

        # Plot figure
        if epoch > 0:
            print('Plotting figures...')
            plot_loss_metrics(args.ckpt, history)
    finally:
        torch.set_grad_enabled(prev_grad_enabled)
Пример #6
0
def evaluate(segmentation_module, loader, cfg, gpu):
    """Multi-scale evaluation with per-class IoU / pixel accuracy reporting.

    Averages class scores over the resized image pyramid, computes pixel
    accuracy and per-class IoU, and prints a summary; optionally visualizes
    the predictions along with the raw score maps.

    Args:
        segmentation_module: model; called as module(feed_dict, segSize=...).
        loader: validation loader; each item is a single-element batch list.
        cfg: config node (DATASET.*, OOD.exclude_back, VAL.visualize, TEST.result).
        gpu: device id passed to async_copy_to.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    # NOTE(review): pbar is never closed; harmless but leaves the bar open.
    pbar = tqdm(total=len(loader))
    for batch_data in loader:
        # process data (loader wraps each sample in a 1-element list)
        batch_data = batch_data[0]
        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        # synchronize so the timer measures actual GPU work
        torch.cuda.synchronize()
        tic = time.perf_counter()
        with torch.no_grad():
            segSize = (seg_label.shape[0], seg_label.shape[1])
            # accumulate class scores averaged over all pyramid scales
            scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])
            scores = async_copy_to(scores, gpu)

            for img in img_resized_list:
                feed_dict = batch_data.copy()
                feed_dict['img_data'] = img
                # drop CPU-only fields before copying the dict to the GPU
                del feed_dict['img_ori']
                del feed_dict['info']
                del feed_dict['name']
                feed_dict = async_copy_to(feed_dict, gpu)

                # forward pass
                scores_tmp = segmentation_module(feed_dict, segSize=segSize)
                scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            # NOTE(review): tmp_scores is computed but never used below --
            # torch.max reads `scores`, so cfg.OOD.exclude_back currently has
            # no effect on `pred`. Confirm whether the argmax was meant to run
            # on tmp_scores (which would also shift class indices by 1).
            tmp_scores = scores
            if cfg.OOD.exclude_back:
                tmp_scores = tmp_scores[:,1:]

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())


        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label, cfg.DATASET.num_class)
        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if cfg.VAL.visualize:
            # NOTE(review): single-argument os.path.join is a no-op wrapper
            # around cfg.TEST.result; possibly a second path component was
            # intended (cf. the sibling evaluate(), which joins 'result').
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']),
                pred,
                os.path.join(cfg.TEST.result),
                as_numpy(scores.squeeze(0).cpu())
            )

        pbar.update(1)

    # summary (epsilon avoids division by zero for absent classes)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
          .format(iou.mean(), acc_meter.average()*100, time_meter.average()))
Пример #7
0
def evaluate(model, loader, gpu_mode, num_class=7):
    """Evaluate a segmentation model over `loader`.

    Args:
        model: segmentation network; model(img) yields per-class score maps.
        loader: yields (img, mask, _) batches.
        gpu_mode: if truthy, move each batch to CUDA.
        num_class: number of segmentation classes.

    Returns:
        dict with keys:
            'acc': mean pixel accuracy,
            'iou': per-class IoU array,
            'iou_mean': mean IoU over classes.
    """
    # metric meters
    acc_meter = AverageMeter()
    inter_meter = AverageMeter()
    union_meter = AverageMeter()

    # NOTE(review): model.eval() is not called here; if the network contains
    # BatchNorm/Dropout the caller must switch modes itself -- confirm.
    with torch.no_grad():  # evaluation needs no autograd graph; saves memory
        for i_batch, (img, mask, _) in enumerate(loader):
            if gpu_mode:
                img = img.cuda()
                mask = mask.cuda()

            output = model(img)
            # per-pixel class prediction = argmax over the channel dim
            output = output.max(1)[1]

            # pixel accuracy
            acc = accuracy(output, mask)
            acc_meter.update(acc)

            # per-class intersection / union for IoU
            intersection, union = intersectionAndUnion(output, mask, num_class)
            inter_meter.update(intersection)
            union_meter.update(union)
            # free the prediction tensor promptly on GPU
            del output

    # summary (epsilon avoids division by zero for absent classes)
    iou = inter_meter.sum / (union_meter.sum + 1e-10)
    return {
        'acc': acc_meter.average(),
        'iou': iou,
        'iou_mean': iou.mean(),
    }
Пример #8
0
def train(segmentation_module, iterator, optimizers, history, epoch, args):
    """Train the segmentation module for one epoch of args.epoch_iters steps.

    Each step: fetch a batch, forward, backpropagate the combined total loss,
    step all optimizers, accumulate per-task loss/metric meters, periodically
    print a status line and record the loss in `history`, then anneal the
    learning rate.
    """
    timer_batch = AverageMeter()
    timer_data = AverageMeter()

    task_names = ['object', 'part', 'scene', 'material']
    loss_meters = {key: AverageMeter() for key in task_names}
    metric_meters = {key: AverageMeter() for key in task_names}
    loss_meters['total'] = AverageMeter()

    # honour fix_bn: train(False) keeps BatchNorm in eval mode
    segmentation_module.train(not args.fix_bn)

    checkpoint = time.time()
    for step in range(args.epoch_iters):
        batch_data, src_idx = next(iterator)
        timer_data.update(time.time() - checkpoint)

        segmentation_module.zero_grad()

        # forward
        ret = segmentation_module(batch_data)

        # backward on the combined loss, then update all parameter groups
        loss = ret['loss']['total'].mean()
        loss.backward()
        for opt in optimizers:
            opt.step()

        timer_batch.update(time.time() - checkpoint)
        checkpoint = time.time()

        # accumulate running losses
        for key, value in ret['loss'].items():
            loss_meters[key].update(value.mean().item())

        # accumulate running metrics
        # NOTE: scene metric will be much lower than benchmark
        for key, value in ret['metric'].items():
            metric_meters[key].update(value.mean().item())

        # periodic console report
        if step % args.disp_iter == 0:
            loss_parts = []
            for key in task_names:
                avg = loss_meters[key].average()
                loss_parts.append("{} {:.2f}".format(
                    key[0], avg if avg is not None else 0))
            loss_info = "Loss: total {:.4f}, ".format(
                loss_meters['total'].average()) + ", ".join(loss_parts)

            acc_parts = []
            for key in task_names:
                avg = metric_meters[key].average()
                acc_parts.append("{} {:4.2f}".format(
                    key[0], avg if avg is not None else 0))
            acc_info = "Accuracy: " + ", ".join(acc_parts)

            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'LR: encoder {:.6f}, decoder {:.6f}, {}, {}'.format(
                      epoch, step, args.epoch_iters, timer_batch.average(),
                      timer_data.average(), args.running_lr_encoder,
                      args.running_lr_decoder, acc_info, loss_info))

            # log a fractional epoch so curves are smooth within an epoch
            fractional_epoch = epoch - 1 + 1. * step / args.epoch_iters
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(loss.item())

        # LR schedule runs on the global iteration count
        cur_iter = step + (epoch - 1) * args.epoch_iters
        adjust_learning_rate(optimizers, cur_iter, args)
def mixup_train(loader, model, criterion, optimizer, epoch, use_cuda):
    """One epoch of training with triple-mixup data augmentation.

    Mixes each batch into three shuffled target sets via mixup_data_triple,
    trains with the corresponding mixup criterion, and logs progress every
    cfg.CLS.disp_iter batches.

    Args:
        loader: training loader yielding (inputs, targets).
        model: network to train.
        criterion: base loss passed into the mixup criterion.
        optimizer: optimizer stepped once per batch.
        epoch: 0-based epoch index (printed 1-based).
        use_cuda: move batches to GPU when True.

    Returns:
        (average loss, average top-1 accuracy) for the epoch.
    """
    global BEST_ACC, LR_STATE
    # switch to train mode (keep BN statistics frozen when fix_bn is set)
    if not cfg.CLS.fix_bn:
        model.train()
    else:
        model.eval()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(loader):
        # adjust learning rate per batch
        adjust_learning_rate(optimizer,
                             epoch,
                             batch=batch_idx,
                             batch_per_epoch=len(loader))

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # mixup: blend each sample with two shuffled partners.
        # (The deprecated torch.autograd.Variable wrapping was removed --
        # tensors carry autograd directly since PyTorch 0.4.)
        inputs, targets_a, targets_b, targets_c, lam = mixup_data_triple(
            inputs, targets, ALPHA, use_cuda)
        optimizer.zero_grad()

        # measure data loading time
        data_time.update(time.time() - end)

        # forward pass: compute output
        outputs = model(inputs)
        # loss over the three mixed target sets
        loss_func = mixup_criterion_triple(targets_a, targets_b, targets_c,
                                           lam)
        loss = loss_func(criterion, outputs)
        # backward
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # record loss; accuracy is not well-defined for mixed targets, so the
        # top-1/top-5 meters keep the original placeholder zeros.
        prec1, prec5 = [0.0], [0.0]
        # .item() replaces `loss.data[0]`, which raises on 0-dim tensors in
        # PyTorch >= 0.5 (and matches the .item() usage elsewhere in this file)
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if (batch_idx + 1) % cfg.CLS.disp_iter == 0:
            print(
                'Training: [{}/{}][{}/{}] | Best_Acc: {:4.2f}% | Time: {:.2f} | Data: {:.2f} | '
                'LR: {:.8f} | Top1: {:.4f}% | Top5: {:.4f}% | Loss: {:.4f} | Total: {:.2f}'
                .format(epoch + 1, cfg.CLS.epochs, batch_idx + 1,
                        len(loader), BEST_ACC, batch_time.average(),
                        data_time.average(), LR_STATE, top1.avg, top5.avg,
                        losses.avg, batch_time.sum + data_time.sum))

    return (losses.avg, top1.avg)
Пример #10
0
def evaluate(nets, loader, loader_2, history, epoch, args, isVis=True):
    """Evaluate `nets` on two validation sets and log both summaries.

    The first loop evaluates on `loader` (summarized as "Cityscapes" in the
    printout, logged under history['val']); the second on `loader_2` ("BDD",
    logged under history['val_2']). Each loop accumulates loss, pixel
    accuracy, and per-class IoU, and optionally visualizes predictions and
    reconstructions.

    NOTE(review): the two loops are near-duplicates; the only difference is
    the order of visualize()/visualize_recon(), which looks unintentional --
    confirm before unifying. Also, no torch.no_grad() appears here;
    presumably forward_with_loss(is_train=False) disables autograd
    internally -- verify, otherwise eval builds gradient graphs.

    Args:
        nets: iterable of networks; all are switched to eval mode.
        loader: first validation loader.
        loader_2: second validation loader.
        history: dict; history['val'] and history['val_2'] lists are appended.
        epoch: epoch number recorded in history.
        args: config namespace (num_class, visualization settings).
        isVis: when True, dump prediction/reconstruction visualizations.
    """
    print('Evaluating at {} epochs...'.format(epoch))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    # separate meters for the second dataset
    loss_meter_2 = AverageMeter()
    acc_meter_2 = AverageMeter()
    intersection_meter_2 = AverageMeter()
    union_meter_2 = AverageMeter()

    # switch to eval mode
    for net in nets:
        net.eval()

    for i, batch_data in enumerate(loader):
        # forward pass (free cached GPU memory first)
        torch.cuda.empty_cache()
        pred, recon, err = forward_with_loss(nets, batch_data, is_train=False)
        loss_meter.update(err.data.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.data.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter.update(acc, pix)

        intersection, union = intersectionAndUnion(batch_data, pred,
                                                   args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if isVis:
            visualize(batch_data, pred, args)
            visualize_recon(batch_data, recon, args)

    for i, batch_data in enumerate(loader_2):
        # forward pass (free cached GPU memory first)
        torch.cuda.empty_cache()
        pred, recon, err = forward_with_loss(nets, batch_data, is_train=False)
        loss_meter_2.update(err.data.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.data.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter_2.update(acc, pix)

        intersection, union = intersectionAndUnion(batch_data, pred,
                                                   args.num_class)
        intersection_meter_2.update(intersection)
        union_meter_2.update(union)

        # visualization (order swapped relative to the first loop -- see NOTE)
        if isVis:
            visualize_recon(batch_data, recon, args)
            visualize(batch_data, pred, args)

    # first dataset summary (epsilon avoids division by zero)
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(trainID2Class[i], _iou))

    print('[Cityscapes Eval Summary]:')
    print('Epoch: {}, Loss: {}, Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        epoch, loss_meter.average(), iou.mean(),
        acc_meter.average() * 100))

    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['acc'].append(acc_meter.average())
    history['val']['mIoU'].append(iou.mean())

    # second dataset summary (iou is rebound to the loader_2 statistics)
    iou = intersection_meter_2.sum / (union_meter_2.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(trainID2Class[i], _iou))

    print('[BDD Eval Summary]:')
    print('Epoch: {}, Loss: {}, Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        epoch, loss_meter_2.average(), iou.mean(),
        acc_meter_2.average() * 100))

    history['val_2']['epoch'].append(epoch)
    history['val_2']['err'].append(loss_meter_2.average())
    history['val_2']['acc'].append(acc_meter_2.average())
    history['val_2']['mIoU'].append(iou.mean())
Пример #11
0
def train(segmentation_module,
          iterator,
          optimizers,
          epoch,
          cfg,
          history=None,
          foveation_module=None):
    """Train the segmentation model for one epoch.

    When ``cfg.MODEL.foveation`` is set, each image is processed as a stream
    of patches selected on a low-resolution grid by ``foveation_module``:
    grid coordinates are shuffled, grouped into mini-batches of size
    ``cfg.TRAIN.mini_batch_size``, and losses are back-propagated per
    mini-batch with synchronisation controlled by ``cfg.TRAIN.sync_location``
    ('rand' | 'mean_mbs' | 'none_sync').  Otherwise a plain forward/backward
    pass is run per batch.

    Args:
        segmentation_module: model returning ``(loss, acc)``.
        iterator: yields either a dict (single GPU) or a list of per-GPU
            dicts produced by a user-modified DataParallel.
        optimizers: sequence of optimizers stepped after backward passes.
        epoch: 1-based epoch index.
        cfg: experiment configuration node.
        history: optional dict of training curves, updated in place.
        foveation_module: patch-selection network; required when
            ``cfg.MODEL.foveation`` is enabled.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()

    segmentation_module.train(not cfg.TRAIN.fix_bn)
    if cfg.MODEL.foveation:
        foveation_module.train(not cfg.TRAIN.fix_bn)

    # main loop
    tic = time.time()
    for i in range(cfg.TRAIN.epoch_iters):
        # load a batch of data; a bare dict means single-GPU mode, so move
        # tensors to the GPU and wrap in a one-element list for uniformity
        batch_data = next(iterator)
        if type(batch_data) is not list:
            single_gpu_mode = True
            batch_data['img_data'] = batch_data['img_data'][0].cuda()
            batch_data['seg_label'] = batch_data['seg_label'][0].cuda()
            batch_data = [batch_data]
        else:
            single_gpu_mode = False
        data_time.update(time.time() - tic)
        segmentation_module.zero_grad()
        if cfg.MODEL.foveation:
            foveation_module.zero_grad()

        # adjust learning rate non_foveation
        if not cfg.MODEL.foveation:
            cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters
            adjust_learning_rate(optimizers, cur_iter, cfg)

        # Foveation
        if cfg.MODEL.foveation:
            # Note by sudo_ means here is only for size estimation purpose
            # because batch_data is obtained by user modified DataParallel, s.t. batch_data is a list with length as len(gpus)
            # and each batch_data[i] is the actualy dict(batch_data) returned in dataset.TrainDataset
            # for ib in range(len(batch_data)):
            # print('img_data shape: ',  batch_data[ib]['img_data'].shape)
            sudo_X, sudo_Y = batch_data[0]['img_data'], batch_data[0][
                'seg_label']
            fov_map_scale = cfg.MODEL.fov_map_scale
            # NOTE: although here we use batch imresize yet in practical batch size for X = 1
            sudo_X_lr = b_imresize(
                sudo_X, (round(sudo_X.shape[2] / fov_map_scale),
                         round(sudo_X.shape[3] /
                               (fov_map_scale * cfg.MODEL.patch_ap))),
                interp='bilinear')
            if cfg.TRAIN.auto_fov_location_step:
                cfg.TRAIN.fov_location_step = round(
                    sudo_X.shape[2] / fov_map_scale) * round(
                        sudo_X.shape[3] / (fov_map_scale * cfg.MODEL.patch_ap))
            # foveation (crop as you go)
            fov_location_batch_step = 0
            if cfg.TRAIN.sync_location == 'rand':  # bp at each step and sync at random
                rand_location = random.randint(1,
                                               cfg.TRAIN.fov_location_step - 1)
            elif cfg.TRAIN.sync_location == 'mean_mbs':  # bp and opt at each step and sync at random (last of random X_lr_cord list) with average loss
                rand_location = cfg.TRAIN.fov_location_step
            elif cfg.TRAIN.sync_location == 'none_sync':  # bp and opt at each step
                rand_location = cfg.TRAIN.fov_location_step

            # mini_batch: shuffle all low-res grid coordinates once per image
            X_lr_cord = []
            for xi in range(sudo_X_lr.shape[2]):
                for yi in range(sudo_X_lr.shape[3]):
                    X_lr_cord.append((xi, yi))
            random.shuffle(X_lr_cord)
            mbs = cfg.TRAIN.mini_batch_size
            mb_iter_count = 0
            mb_idx = 0
            mb_idx_count = 0
            while mb_idx < len(X_lr_cord) and mb_idx_count < rand_location:
                # correct zero_grad https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903
                # https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch
                # https://discuss.pytorch.org/t/whats-the-difference-between-optimizer-zero-grad-vs-nn-module-zero-grad/59233
                segmentation_module.zero_grad()
                foveation_module.zero_grad()

                batch_iters = rand_location
                cur_iter = fov_location_batch_step + (i - 1) * batch_iters + (
                    epoch - 1) * cfg.TRAIN.epoch_iters * batch_iters
                # print('original max_iter:', cfg.TRAIN.max_iters)
                if cfg.TRAIN.fov_scale_lr != '' or cfg.TRAIN.fov_scale_weight_decay != '':
                    # weighted patch size normalized _ mini_batch average
                    if mb_idx == 0:
                        wpsn_mb = 1
                    else:
                        wpsn_mb = wpsn_mb / mbs
                if cfg.TRAIN.sync_location != 'rand':
                    fov_max_iters = batch_iters * cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch

                    if cfg.TRAIN.fov_scale_lr == 'pen_sp':  # penalty small patch, smaller average patch size smaller learning rate
                        lr_scale = float(wpsn_mb)
                    elif cfg.TRAIN.fov_scale_lr == 'pen_lp':  # penalty large patch, larger average patch size smaller learning rate
                        lr_scale = float(1 - wpsn_mb)
                    else:
                        lr_scale = 1.
                    if cfg.TRAIN.fov_scale_weight_decay == 'reg_sp':  # regularise small patch, smaller average patch size larger regularisation
                        wd_scale = float(1 - wpsn_mb)
                    elif cfg.TRAIN.fov_scale_weight_decay == 'reg_lp':  # regularise large patch, larger average patch size larger regularisation
                        wd_scale = float(wpsn_mb)
                    else:
                        wd_scale = 1.

                    if cfg.TRAIN.fov_scale_lr != '' or cfg.TRAIN.fov_scale_weight_decay != '':
                        wpsn_mb = 0

                    # print('before fov_pow lr_scale={}, wd_scale={}'.format(lr_scale, wd_scale))
                    adjust_learning_rate(optimizers,
                                         cur_iter,
                                         cfg,
                                         lr_mbs=True,
                                         f_max_iter=fov_max_iters,
                                         lr_scale=lr_scale,
                                         wd_scale=wd_scale)
                    if cfg.MODEL.gumbel_tau_anneal:
                        adjust_gms_tau(cur_iter, cfg, r=1. / fov_max_iters)
                if cfg.TRAIN.entropy_regularisation:
                    mbs_mean_entropy_reg = 0
                # gather up to `mbs` shuffled grid coordinates for this mini-batch
                xi = []
                yi = []
                mini_batch_sample = 0
                while mini_batch_sample < mbs and mb_idx < len(X_lr_cord):
                    xi.append(X_lr_cord[mb_idx][0])
                    yi.append(X_lr_cord[mb_idx][1])
                    mb_idx += 1
                    fov_location_batch_step += 1
                    mb_idx_count += 1
                    mini_batch_sample += 1
                xi = tuple(xi)
                yi = tuple(yi)

                for idx in range(len(batch_data)):
                    batch_data[idx]['cor_info'] = (xi, yi, rand_location,
                                                   fov_location_batch_step)
                # at the sync location the foveation module additionally
                # returns gradient diagnostics (print_grad)
                if fov_location_batch_step == rand_location:
                    if single_gpu_mode:
                        patch_data, F_Xlr, print_grad = foveation_module(
                            batch_data[0])
                    else:
                        patch_data, F_Xlr, print_grad = foveation_module(
                            batch_data)
                else:
                    if single_gpu_mode:
                        patch_data, F_Xlr = foveation_module(batch_data[0])
                    else:
                        patch_data, F_Xlr = foveation_module(batch_data)

                # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
                # by set base = len(patch_bank), uniform distribution will have entropy = 1 (so absolute uncertain)
                if cfg.TRAIN.entropy_regularisation:
                    # comprosed solution consider batch size != 1
                    F_Xlr_c = F_Xlr.clone()
                    if cfg.MODEL.gumbel_softmax:
                        F_Xlr_c = F_Xlr_c.exp()

                    mean_entropy_reg = 0
                    for i_batch in range(F_Xlr_c.shape[0]):
                        mean_entropy_reg += (
                            -F_Xlr_c[i_batch, :, xi, yi] *
                            F_Xlr_c[i_batch, :, xi, yi].log()).sum()
                    mbs_mean_entropy_reg += mean_entropy_reg / (
                        rand_location // mbs)

                if cfg.TRAIN.entropy_regularisation:
                    # comprosed solution consider batch size != 1
                    # NOTE(review): mean_entropy is computed but never used
                    # below — confirm whether it was meant for logging
                    mean_entropy = 0
                    for i_batch in range(F_Xlr.shape[0]):
                        mean_entropy += (entropy(
                            F_Xlr[i_batch, :, xi, yi].cpu().detach().numpy(),
                            base=len(
                                cfg.MODEL.patch_bank)).mean()) / F_Xlr.shape[0]

                if cfg.TRAIN.fov_scale_lr != '':
                    print(F_Xlr.shape)
                    pb = cfg.MODEL.patch_bank
                    wps = torch.sum(
                        F_Xlr[:, :, xi, yi] *
                        torch.tensor(pb).float().unsqueeze(0).unsqueeze(
                            -1).unsqueeze(-1).to(F_Xlr.device),
                        dim=1).mean()
                    wpsn = (wps - pb[0]) / (pb[-1] - pb[0])
                    print('wpsn: ', wpsn)
                    wpsn_mb += wpsn

                # split multi gpu collected dict into list to keep DataParall work for segmentation_module
                # print('patch_data_img_data_shape: ', patch_data['img_data'].shape)
                # NOTE(review): mb_iter_count is unconditionally reset to 0
                # after this if/else, so this branch is ALWAYS taken and
                # patch_data_list is rebuilt every pass — confirm whether
                # accumulation across mini-batches (the else branch) was the
                # intended behavior
                if mb_iter_count == 0:
                    patch_data_list = []
                    for idx in range(len(batch_data)):
                        patch_data_temp = dict()
                        patch_data_temp['img_data'] = torch.split(
                            patch_data['img_data'],
                            patch_data['img_data'].shape[0] // len(batch_data),
                            dim=0)[idx]
                        patch_data_temp['seg_label'] = torch.split(
                            patch_data['seg_label'],
                            patch_data['seg_label'].shape[0] //
                            len(batch_data),
                            dim=0)[idx]
                        if cfg.MODEL.hard_fov_pred:
                            patch_data_temp['hard_max_idx'] = torch.split(
                                patch_data['hard_max_idx'],
                                patch_data['hard_max_idx'].shape[0] //
                                len(batch_data),
                                dim=0)[idx]
                        patch_data_list.append(patch_data_temp)
                else:
                    for idx in range(len(batch_data)):
                        patch_data_temp['img_data'] = torch.split(
                            patch_data['img_data'],
                            patch_data['img_data'].shape[0] // len(batch_data),
                            dim=0)[idx]
                        patch_data_temp['seg_label'] = torch.split(
                            patch_data['seg_label'],
                            patch_data['seg_label'].shape[0] //
                            len(batch_data),
                            dim=0)[idx]
                        patch_data_list[idx]['img_data'] = torch.cat([
                            patch_data_list[idx]['img_data'],
                            patch_data_temp['img_data']
                        ])
                        patch_data_list[idx]['seg_label'] = torch.cat([
                            patch_data_list[idx]['seg_label'],
                            patch_data_temp['seg_label']
                        ])
                        if cfg.MODEL.hard_fov_pred:
                            patch_data_temp['hard_max_idx'] = torch.split(
                                patch_data['hard_max_idx'],
                                patch_data['hard_max_idx'].shape[0] //
                                len(batch_data),
                                dim=0)[idx]
                            patch_data_list[idx]['hard_max_idx'] = torch.cat([
                                patch_data_list[idx]['hard_max_idx'],
                                patch_data_temp['hard_max_idx']
                            ])
                    mb_iter_count += 1
                mb_iter_count = 0
                # forward pass
                # print('[patch_data_list_img_data_shape: ]', patch_data_list[0]['img_data'].shape)
                if single_gpu_mode:
                    loss, acc = segmentation_module(patch_data_list[0])
                else:
                    loss, acc = segmentation_module(patch_data_list)
                if cfg.MODEL.categorical:
                    # print('log_prob_act:', patch_data['log_prob_act'])
                    # print('ori loss:', loss)
                    # REINFORCE-style weighting of the loss by the action
                    # log-probability
                    if cfg.MODEL.inv_categorical:
                        loss = -patch_data['log_prob_act'] * loss
                    else:
                        loss = patch_data['log_prob_act'] * loss
                    # print('reinforced loss:', loss)
                if not single_gpu_mode:
                    loss = loss.mean()
                    acc = acc.mean()
                if cfg.TRAIN.entropy_regularisation:
                    loss += cfg.TRAIN.entropy_regularisation_weight * mbs_mean_entropy_reg
                # NOTE(review): this test detects the first mini-batch only
                # when exactly `mbs` samples were drawn; with a short final
                # mini-batch `fov_location_batch_step // mbs` may never equal
                # 1 — confirm intended
                if fov_location_batch_step // mbs == 1:
                    loss_step = loss.data
                    acc_step = acc.data
                else:
                    loss_step += loss.data
                    acc_step += acc.data

                if fov_location_batch_step == rand_location:
                    # hold this loss for the synchronised backward below
                    loss_retain = loss
                elif fov_location_batch_step != cfg.TRAIN.fov_location_step:
                    loss.backward()
                    if cfg.TRAIN.sync_location != 'rand':
                        for optimizer in optimizers:
                            optimizer.step()

                if fov_location_batch_step == cfg.TRAIN.fov_location_step:

                    if cfg.TRAIN.sync_location != 'none_sync':
                        # print('iter {}: bp at random retained location {}/{}, xi={}, yi={}'.format(i, rand_location, cfg.TRAIN.fov_location_step, xi, yi))
                        if cfg.TRAIN.sync_location == 'mean_mbs':
                            loss_retain.data = loss_step / (
                                cfg.TRAIN.fov_location_step / mbs)
                        loss_retain.backward()
                    else:
                        loss.backward()
                    for optimizer in optimizers:
                        optimizer.step()
                    loss_step /= (cfg.TRAIN.fov_location_step / mbs)
                    acc_step /= (cfg.TRAIN.fov_location_step / mbs)
                    ave_total_loss.update(loss_step.data.item())
                    ave_acc.update(acc_step.data.item() * 100)
                    fov_location_batch_step = 0
                    if not cfg.TRAIN.auto_fov_location_step and cfg.TRAIN.sync_location == 'rand':
                        rand_location = random.randint(
                            2, cfg.TRAIN.fov_location_step - 1)
                # print('iter {}: {}/{}/{} foveate points, xi={}, yi={}\n'.format(i, fov_location_batch_step, mb_idx, sudo_X_lr.shape[2]*sudo_X_lr.shape[3], xi, yi))

        else:
            # forward pass
            loss, acc = segmentation_module(batch_data)
            print()
            loss_step = loss.mean()
            acc_step = acc.mean()

            # Backward
            loss_step.backward()
            for optimizer in optimizers:
                optimizer.step()

            # update average loss and acc
            ave_total_loss.update(loss_step.data.item())
            ave_acc.update(acc_step.data.item() * 100)

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # calculate accuracy, and display
        if i % cfg.TRAIN.disp_iter == 0:
            if cfg.MODEL.foveation:
                print(
                    'iter {}: bp at random retained location {}/{}, xi={}, yi={}'
                    .format(i, rand_location, cfg.TRAIN.fov_location_step, xi,
                            yi))
            print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                  'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                  'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                      epoch, i, cfg.TRAIN.epoch_iters, batch_time.average(),
                      data_time.average(), cfg.TRAIN.running_lr_encoder,
                      cfg.TRAIN.running_lr_decoder, ave_acc.average(),
                      ave_total_loss.average()))

        fractional_epoch = epoch - 1 + 1. * i / cfg.TRAIN.epoch_iters
        if history is not None:
            history['train']['epoch'].append(fractional_epoch)
            history['train']['loss'].append(ave_total_loss.average())
            history['train']['acc'].append(ave_acc.average() / 100)
            # NOTE(review): print_grad is only bound when the foveation
            # branch reached the sync location at least once — this line
            # raises NameError in non-foveation mode; confirm
            history['train']['print_grad'] = print_grad
def evaluate(segmentation_module, loader, cfg, gpu):
    """Evaluate the model on ``loader``, reporting per-class IoU and accuracy.

    This variant is instrumented for input-gradient inspection: it enables
    ``requires_grad`` on the input image, back-propagates the loss during
    evaluation, and prints the gradient shape for each multi-scale image.

    Args:
        segmentation_module: model callable as ``(feed_dict, segSize=...)``.
        loader: validation loader yielding one-element batches with
            ``'img_data'`` (multi-scale list), ``'seg_label'``, ``'img_ori'``
            and ``'info'`` entries.
        cfg: experiment configuration node.
        gpu: device id used by ``async_copy_to``.
    """
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    time_meter = AverageMeter()

    segmentation_module.eval()

    pbar = tqdm(total=len(loader))

    for batch_data in loader:
        # process data

        batch_data = batch_data[0]

        print('Info:', batch_data['info'])

        # debug dump of the batch contents
        for key in batch_data:
            print(key, type(batch_data[key]))

            if isinstance(batch_data[key], torch.Tensor):
                print(batch_data[key].shape)

            if key == 'img_data':
                for i, data in enumerate(batch_data[key]):
                    # data.requires_grad = True

                    print(i, type(data), data.shape, data.requires_grad)

        seg_label = as_numpy(batch_data['seg_label'][0])
        img_resized_list = batch_data['img_data']

        print(seg_label.shape)

        torch.cuda.synchronize()
        tic = time.perf_counter()

        seg_size = (seg_label.shape[0], seg_label.shape[1])
        scores = torch.zeros(1, cfg.DATASET.num_class, seg_size[0],
                             seg_size[1])
        scores = async_copy_to(scores, gpu)

        # accumulate predictions over the multi-scale resized images
        for img in img_resized_list:
            feed_dict = batch_data.copy()
            feed_dict['img_data'] = img
            del feed_dict['img_ori']
            del feed_dict['info']
            feed_dict = async_copy_to(feed_dict, gpu)

            # enable input gradients so loss.backward() populates them
            feed_dict['img_data'].requires_grad = True

            print("Right before", feed_dict['img_data'].size())

            # forward pass
            # scores_tmp = segmentation_module(feed_dict, segSize=seg_size)

            segmentation_module.zero_grad()

            loss, acc = segmentation_module(feed_dict, segSize=seg_size)
            loss = loss.mean()

            loss.backward()

            print(feed_dict['img_data'].grad.data.size())

            # NOTE(review): scores_tmp is never defined in this function —
            # the forward that produced it is commented out above, so this
            # line raises NameError at runtime; restore the scores forward
            # or drop the accumulation
            scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)

            _, pred = torch.max(scores, dim=1)
            pred = as_numpy(pred.squeeze(0).cpu())

        torch.cuda.synchronize()
        time_meter.update(time.perf_counter() - tic)

        # calculate accuracy
        acc, pix = accuracy(pred, seg_label)
        intersection, union = intersectionAndUnion(pred, seg_label,
                                                   cfg.DATASET.num_class)

        acc_meter.update(acc, pix)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        if cfg.VAL.visualize:
            visualize_result(
                (batch_data['img_ori'], seg_label, batch_data['info']), pred,
                os.path.join(cfg.DIR, 'result'))

        pbar.update(1)

    # summary
    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {:.4f}'.format(i, _iou))

    print('[Eval Summary]:')
    print(
        'Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'.format(
            iou.mean(),
            acc_meter.average() * 100, time_meter.average()))
Пример #13
0
def train(segmentation_module, loader_train, optimizers, history, epoch, args):
    """Train ``segmentation_module`` for one epoch over ``loader_train``.

    Args:
        segmentation_module: model returning ``(loss, acc)`` where ``acc`` is
            ``(pixel_acc, per_class_jaccard)``.
        loader_train: iterable of batch dicts containing an ``"image"`` tensor.
        optimizers: sequence of optimizers stepped after each backward pass.
        history: dict of training curves, updated in place.
        epoch: 1-based epoch index.
        args: namespace with the training hyper-parameters used below.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    ave_j1 = AverageMeter()
    ave_j2 = AverageMeter()
    ave_j3 = AverageMeter()

    segmentation_module.train(not args.fix_bn)

    # main loop
    tic = time.time()
    iter_count = 0

    # When resuming from a checkpoint, re-apply the polynomial LR decay the
    # encoder would have reached at this epoch.
    if epoch == args.start_epoch and args.start_epoch > 1:
        scale_running_lr = ((1. - float(epoch - 1) /
                             (args.num_epoch))**args.lr_pow)
        args.running_lr_encoder = args.lr_encoder * scale_running_lr
        for param_group in optimizers[0].param_groups:
            param_group['lr'] = args.running_lr_encoder

    for batch_idx, batch_data in enumerate(loader_train):
        data_time.update(time.time() - tic)
        batch_data["image"] = batch_data["image"].cuda()
        segmentation_module.zero_grad()
        # forward pass
        loss, acc = segmentation_module(batch_data, epoch)
        loss = loss.mean()

        # BUGFIX: the original `for j in jaccard: j = j.float().mean()` only
        # rebound the loop variable and had no effect; reduce each per-class
        # Jaccard explicitly instead (identity for scalar tensors).
        jaccard = [j.float().mean() for j in acc[1]]
        acc = acc[0].float().mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()
        iter_count += args.batch_size_per_gpu

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        ave_j1.update(jaccard[0].data.item() * 100)
        ave_j2.update(jaccard[1].data.item() * 100)
        ave_j3.update(jaccard[2].data.item() * 100)

        if iter_count % (args.batch_size_per_gpu * 10) == 0:
            # calculate accuracy, and display
            if not args.unet:
                # BUGFIX: the original printed an undefined name `i`
                # (NameError); report the enumerate batch index instead.
                print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
                      'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
                      'Accuracy: {:4.2f}, Loss: {:.6f}'.format(
                          epoch, batch_idx, args.epoch_iters,
                          batch_time.average(), data_time.average(),
                          args.running_lr_encoder, args.running_lr_decoder,
                          ave_acc.average(), ave_total_loss.average()))
            else:
                print(
                    'Epoch: [{}/{}], Iter: [{}], Time: {:.2f}, Data: {:.2f},'
                    ' lr_unet: {:.6f}, Accuracy: {:4.2f}, Jaccard: [{:4.2f},{:4.2f},{:4.2f}], '
                    'Loss: {:.6f}'.format(epoch, args.max_iters, iter_count,
                                          batch_time.average(),
                                          data_time.average(),
                                          args.running_lr_encoder,
                                          ave_acc.average(), ave_j1.average(),
                                          ave_j2.average(), ave_j3.average(),
                                          ave_total_loss.average()))

    # Average jaccard across classes.
    j_avg = (ave_j1.average() + ave_j2.average() + ave_j3.average()) / 3

    # Update the training history (loss/acc of the last batch, per the
    # original behavior).
    history['train']['epoch'].append(epoch)
    history['train']['loss'].append(loss.data.item())
    history['train']['acc'].append(acc.data.item())
    history['train']['jaccard'].append(j_avg)
    # adjust learning rate
    adjust_learning_rate(optimizers, epoch, args)
Пример #14
0
def evaluate(nets, loader, history, epoch, args):
    """Evaluate ``nets`` on ``loader``; report loss, per-class IoU, accuracy.

    Appends the epoch's metrics to ``history['val']`` and, for ``epoch > 0``,
    saves loss/accuracy curves to ``args.ckpt``.

    Args:
        nets: iterable of networks consumed by ``forward_with_loss``.
        loader: validation data loader.
        history: dict of train/val curves, updated in place.
        epoch: current epoch number (plots are skipped for epoch 0).
        args: namespace providing ``num_class`` and ``ckpt``.
    """
    print('Evaluating at {} epochs...'.format(epoch))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()

    # switch to eval mode
    for net in nets:
        net.eval()

    for i, batch_data in enumerate(loader):
        # forward pass
        torch.cuda.empty_cache()
        pred, err = forward_with_loss(nets, batch_data, args, is_train=False)
        # BUGFIX: `err.data[0]` fails on 0-dim tensors (PyTorch >= 0.4);
        # use .item() as the rest of this file does.
        loss_meter.update(err.data.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.data.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter.update(acc, pix)

        intersection, union = intersectionAndUnion(batch_data, pred,
                                                   args.num_class)
        intersection_meter.update(intersection)
        union_meter.update(union)

        # visualization
        visualize(batch_data, pred, args)

    iou = intersection_meter.sum / (union_meter.sum + 1e-10)
    for i, _iou in enumerate(iou):
        print('class [{}], IoU: {}'.format(trainID2Class[i], _iou))

    print('[Eval Summary]:')
    # BUGFIX: fixed 'Accurarcy' typo in the summary line.
    print('Epoch: {}, Loss: {}, Mean IoU: {:.4}, Accuracy: {:.2f}%'.format(
        epoch, loss_meter.average(), iou.mean(),
        acc_meter.average() * 100))

    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['acc'].append(acc_meter.average())
    history['val']['mIoU'].append(iou.mean())

    # Plot figure
    if epoch > 0:
        print('Plotting loss figure...')
        fig = plt.figure()
        plt.plot(np.asarray(history['train']['epoch']),
                 np.log(np.asarray(history['train']['err'])),
                 color='b',
                 label='training')
        plt.plot(np.asarray(history['val']['epoch']),
                 np.log(np.asarray(history['val']['err'])),
                 color='c',
                 label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Log(loss)')
        fig.savefig('{}/loss.png'.format(args.ckpt), dpi=200)
        plt.close('all')

        fig = plt.figure()
        plt.plot(history['train']['epoch'],
                 history['train']['acc'],
                 color='b',
                 label='training')
        plt.plot(history['val']['epoch'],
                 history['val']['acc'],
                 color='c',
                 label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        fig.savefig('{}/accuracy.png'.format(args.ckpt), dpi=200)
        plt.close('all')
Пример #15
0
def evaluate(nets, loader, history, epoch, args):
    """Evaluate ``nets`` on ``loader``; report loss and pixel accuracy.

    Appends the epoch's metrics to ``history['val']`` and, for ``epoch > 0``,
    saves loss/accuracy curves to ``args.ckpt``.

    Args:
        nets: iterable of networks consumed by ``forward_with_loss``.
        loader: validation data loader.
        history: dict of train/val curves, updated in place.
        epoch: current epoch number (plots are skipped for epoch 0).
        args: namespace providing ``ckpt`` and forward options.
    """
    print('Evaluating at {} epochs...'.format(epoch))
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # switch to eval mode
    for net in nets:
        net.eval()

    for i, batch_data in enumerate(loader):
        # forward pass
        pred, err = forward_with_loss(nets, batch_data, args, is_train=False)
        # BUGFIX: `err.data[0]` fails on 0-dim tensors (PyTorch >= 0.4);
        # use .item() as the rest of this file does.
        loss_meter.update(err.data.item())
        print('[Eval] iter {}, loss: {}'.format(i, err.data.item()))

        # calculate accuracy
        acc, pix = accuracy(batch_data, pred)
        acc_meter.update(acc, pix)

        # visualization
        visualize(batch_data, pred, args)

    history['val']['epoch'].append(epoch)
    history['val']['err'].append(loss_meter.average())
    history['val']['acc'].append(acc_meter.average())
    # BUGFIX: fixed 'Accurarcy' typo in the summary line.
    print('[Eval Summary] Epoch: {}, Loss: {}, Accuracy: {:4.2f}%'.format(
        epoch, loss_meter.average(),
        acc_meter.average() * 100))

    # Plot figure
    if epoch > 0:
        print('Plotting loss figure...')
        fig = plt.figure()
        plt.plot(np.asarray(history['train']['epoch']),
                 np.log(np.asarray(history['train']['err'])),
                 color='b',
                 label='training')
        plt.plot(np.asarray(history['val']['epoch']),
                 np.log(np.asarray(history['val']['err'])),
                 color='c',
                 label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Log(loss)')
        fig.savefig('{}/loss.png'.format(args.ckpt), dpi=200)
        plt.close('all')

        fig = plt.figure()
        plt.plot(history['train']['epoch'],
                 history['train']['acc'],
                 color='b',
                 label='training')
        plt.plot(history['val']['epoch'],
                 history['val']['acc'],
                 color='c',
                 label='validation')
        plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        fig.savefig('{}/accuracy.png'.format(args.ckpt), dpi=200)
        plt.close('all')
def train(segmentation_module, loader_train, optimizers, epoch, space):
    """Train for one epoch with a hyper-parameter ``space`` and return the
    average loss.

    Args:
        segmentation_module: model returning ``(loss, acc)`` where ``acc`` is
            ``(pixel_acc, per_class_jaccard)``.
        loader_train: iterable of batch dicts with ``"image"`` and ``"mask"``.
        optimizers: sequence of optimizers stepped after each backward pass.
        epoch: 1-based epoch index, used for LR adjustment and logging.
        space: hyper-parameter dict; ``space['lr']`` feeds the LR schedule.

    Returns:
        float: running average of the training loss over the epoch.
    """
    adjust_learning_rate(optimizers, epoch, space['lr'])

    batch_time = AverageMeter()
    data_time = AverageMeter()
    ave_total_loss = AverageMeter()
    ave_acc = AverageMeter()
    ave_j1 = AverageMeter()
    ave_j2 = AverageMeter()
    ave_j3 = AverageMeter()

    segmentation_module.train()

    # main loop
    tic = time.time()
    iter_count = 0

    for batch_data in loader_train:
        data_time.update(time.time() - tic)
        batch_data["image"] = batch_data["image"].cuda()
        batch_data["mask"] = batch_data["mask"].cuda()
        segmentation_module.zero_grad()
        # forward pass
        loss, acc = segmentation_module(batch_data, epoch)
        loss = loss.mean()

        # BUGFIX: the original `for j in jaccard: j = j.float().mean()` only
        # rebound the loop variable and had no effect; reduce each per-class
        # Jaccard explicitly instead (identity for scalar tensors).
        jaccard = [j.float().mean() for j in acc[1]]
        acc = acc[0].float().mean()

        # Backward
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()
        # NOTE(review): 4 here and 40 below look like a hard-coded batch
        # size and display interval (report every 10 batches) — confirm
        iter_count += 4

        # update average loss and acc
        ave_total_loss.update(loss.data.item())
        ave_acc.update(acc.data.item() * 100)

        ave_j1.update(jaccard[0].data.item() * 100)
        ave_j2.update(jaccard[1].data.item() * 100)
        ave_j3.update(jaccard[2].data.item() * 100)

        if iter_count % 40 == 0:
            # calculate accuracy, and display
            print('Epoch: [{}/{}], Iter: [{}], Time: {:.2f}, Data: {:.2f},'
                  'Accuracy: {:4.2f}, Jaccard: [{:4.2f},{:4.2f},{:4.2f}], '
                  'Loss: {:.6f}'.format(epoch, 30, iter_count,
                                        batch_time.average(),
                                        data_time.average(), ave_acc.average(),
                                        ave_j1.average(), ave_j2.average(),
                                        ave_j3.average(),
                                        ave_total_loss.average()))

    # Average jaccard across classes.
    j_avg = (ave_j1.average() + ave_j2.average() + ave_j3.average()) / 3

    return ave_total_loss.average()