Example #1
def train(args, model, device, train_loader, optimizer, epoch, criterion):
    model.train()
    epoch_loss = 0
    for batch_idx, batch_data in enumerate(train_loader):
        batch_ldr0 = batch_data['input0'].to(device)
        batch_ldr1 = batch_data['input1'].to(device)
        batch_ldr2 = batch_data['input2'].to(device)
        label = batch_data['label'].to(device)

        # Predict the HDR output from the three LDR exposures, then compare
        # against the label in the tonemapped domain.
        pred = model(batch_ldr0, batch_ldr1, batch_ldr2)
        pred = range_compressor_tensor(pred)
        pred = torch.clamp(pred, 0., 1.)
        loss = criterion(pred, label)
        psnr = batch_PSNR(pred, label, 1.0)
        # psnr = batch_PSNR(torch.clamp(pred, 0., 1.), label, 1.0)

        epoch_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_value_(model.parameters(), 0.01)
        optimizer.step()

        iteration = (epoch - 1) * len(train_loader) + batch_idx
        if batch_idx % args.log_interval == 0:
            logx.msg('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_ldr0.size(0),
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            logx.add_scalar('train/learning_rate',
                            optimizer.param_groups[0]['lr'], iteration)
            logx.add_scalar('train/psnr', psnr, iteration)
            logx.add_image('train/input1', batch_ldr0[0][[2, 1, 0], :, :],
                           iteration)
            logx.add_image('train/input2', batch_ldr1[0][[2, 1, 0], :, :],
                           iteration)
            logx.add_image('train/input3', batch_ldr2[0][[2, 1, 0], :, :],
                           iteration)
            logx.add_image('train/pred', pred[0][[2, 1, 0], :, :], iteration)
            logx.add_image('train/gt', label[0][[2, 1, 0], :, :], iteration)

        # capture metrics
        metrics = {'loss': loss.item()}
        logx.metric('train', metrics, iteration)
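
The two helpers called above are not shown in this snippet. Below is a minimal sketch of what they commonly look like in HDR deghosting code, assuming the usual μ-law tonemapper with μ = 5000 and a per-batch mean PSNR; the bodies are illustrative, only the names `range_compressor_tensor` and `batch_PSNR` come from the call sites.

import math

import torch


def range_compressor_tensor(x, mu=5000.0):
    # mu-law range compression commonly used to tonemap HDR predictions
    # before computing losses/metrics; mu=5000 is an assumed default.
    return torch.log(1.0 + mu * x) / math.log(1.0 + mu)


def batch_PSNR(pred, target, data_range):
    # Mean PSNR over the batch, assuming 4-D (N, C, H, W) inputs already
    # clamped to [0, data_range].
    mse = torch.mean((pred - target) ** 2, dim=[1, 2, 3])
    psnr = 10.0 * torch.log10(data_range ** 2 / (mse + 1e-12))
    return psnr.mean().item()
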
Example #2
def print_evaluate_results(hist,
                           iu,
                           epoch=0,
                           iou_per_scale=None,
                           log_multiscale_tb=False):
    """
    If single scale:
       just print results for default scale
    else
       print all scale results

    Inputs:
    hist = histogram for default scale
    iu = IOU for default scale
    iou_per_scale = iou for all scales
    """
    id2cat = cfg.DATASET_INST.trainid_to_name
    # id2cat = {i: i for i in range(cfg.DATASET.NUM_CLASSES)}

    # Derive per-class TP/FP/FN counts from the confusion-matrix histogram.
    iu_FP = hist.sum(axis=1) - np.diag(hist)
    iu_FN = hist.sum(axis=0) - np.diag(hist)
    iu_TP = np.diag(hist)

    logx.msg('IoU:')

    header = ['Id', 'label']
    header.extend(['iU_{}'.format(scale) for scale in iou_per_scale])
    header.extend(['TP', 'FP', 'FN', 'Precision', 'Recall'])

    tabulate_data = []

    for class_id in range(len(iu)):
        class_data = []
        class_data.append(class_id)
        class_name = "{}".format(
            id2cat[class_id]) if class_id in id2cat else ''
        class_data.append(class_name)
        for scale in iou_per_scale:
            class_data.append(iou_per_scale[scale][class_id] * 100)

        total_pixels = hist.sum()
        # TP as a percentage of all pixels; FP and FN as ratios to TP.
        class_data.append(100 * iu_TP[class_id] / total_pixels)
        class_data.append(iu_FP[class_id] / iu_TP[class_id])
        class_data.append(iu_FN[class_id] / iu_TP[class_id])
        # Precision = TP / (TP + FP), recall = TP / (TP + FN).
        class_data.append(iu_TP[class_id] /
                          (iu_TP[class_id] + iu_FP[class_id]))
        class_data.append(iu_TP[class_id] /
                          (iu_TP[class_id] + iu_FN[class_id]))
        tabulate_data.append(class_data)

        if log_multiscale_tb:
            logx.add_scalar("xscale_%0.1f/%s" % (0.5, str(id2cat[class_id])),
                            float(iou_per_scale[0.5][class_id] * 100), epoch)
            logx.add_scalar("xscale_%0.1f/%s" % (1.0, str(id2cat[class_id])),
                            float(iou_per_scale[1.0][class_id] * 100), epoch)
            logx.add_scalar("xscale_%0.1f/%s" % (2.0, str(id2cat[class_id])),
                            float(iou_per_scale[2.0][class_id] * 100), epoch)

    print_str = tabulate(tabulate_data, headers=header, floatfmt='1.2f')
    logx.msg(print_str)
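
`print_evaluate_results` expects a class-confusion histogram `hist` and per-class IoU values `iu`. A minimal sketch of how such inputs are commonly built for semantic segmentation follows; the helper names `fast_hist` and `compute_iu` are illustrative, not taken from this source.

import numpy as np


def fast_hist(pred, gtruth, num_classes):
    # Per-pixel confusion matrix from flattened integer label maps.
    # Here rows index the ground truth and columns the prediction
    # (conventions vary between codebases); only valid labels in
    # [0, num_classes) are counted.
    mask = (gtruth >= 0) & (gtruth < num_classes)
    hist = np.bincount(
        num_classes * gtruth[mask].astype(int) + pred[mask].astype(int),
        minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return hist


def compute_iu(hist):
    # Per-class IoU = TP / (TP + FP + FN); classes absent from both
    # prediction and label produce NaN here.
    return np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0)
                            - np.diag(hist))
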
Example #3
def train_net():
    header = [
        'epoch', 'train_loss', 'val_loss', 'val_dice', 'val_iou', 'lr',
        'time(s)'
    ]
    start_epoch, global_step, best_score, total_list = -1, 1, 0.0, []
    if args.vis:
        viz = Visualizer(port=args.port,
                         env=f"EXP_{args.exp_id}_NET_{args.arch}")

    # Resume the training process
    if args.resume:
        start_epoch = resume(args=args)

    # automatic mixed-precision training
    if args.amp_available:
        scaler = torch.cuda.amp.GradScaler()

    for epoch in range(start_epoch + 1, args.epochs):
        args.net.train()

        epoch_loss, epoch_start_time, rows = 0., time(), [epoch + 1]

        # get the current learning rate
        new_lr = get_lr(args=args, epoch=epoch)

        # Training process
        with tqdm(total=n_train,
                  desc=f'Epoch-{epoch + 1}/{args.epochs}',
                  unit='img') as p_bar:
            for batch in train_loader:
                # args.optimizer.zero_grad()
                image, label = batch['image'], batch['label']
                assert image.shape[1] == args.n_channels

                # Prepare the image and the corresponding label.
                image = image.to(device=args.device, dtype=torch.float32)
                mask_type = torch.float32 if args.n_classes == 1 else torch.long
                label = label.to(device=args.device, dtype=mask_type)

                # Forward propagation.
                if args.amp_available:
                    with torch.cuda.amp.autocast():
                        try:
                            output = args.net(image)
                        except RuntimeError as exception:
                            if "out of memory" in str(exception):
                                print("WARNING: out of memory")
                                if hasattr(torch.cuda, 'empty_cache'):
                                    torch.cuda.empty_cache()
                                exit(0)
                            else:
                                raise exception
                        loss = criterion(output, label)
                else:
                    output = args.net(image)
                    loss = criterion(output, label)

                # Visualize the ground truth and the binarized prediction.
                if args.vis:
                    try:
                        viz.img(name='ground_truth', img_=label[0])
                        # Binarize a detached copy so the tensor that is part
                        # of the autograd graph is not modified in place
                        # before backward.
                        tmp = (output[0].detach() > 0.5).float()
                        viz.img(name='prediction', img_=tmp)
                    except ConnectionError:
                        pass

                args.optimizer.zero_grad()
                # Back propagation.
                if args.amp_available:
                    scaler.scale(loss).backward()
                    scaler.step(args.optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    args.optimizer.step()

                global_step += 1
                epoch_loss += loss.item()
                logx.add_scalar('Loss/train', loss.item(), global_step)
                p_bar.set_postfix(**{'loss (batch)': loss.item()})
                p_bar.update(image.shape[0])

        # Calculate the average training loss for the epoch
        train_loss = epoch_loss / (n_train // args.batch_size)
        metrics = {'train_loss': train_loss}
        logx.metric(phase='train', metrics=metrics, epoch=epoch)

        # Validation pass
        val_score, val_loss = eval_net(criterion, logx, epoch, val_loader,
                                       n_val, args)

        # Step the learning-rate scheduler. If you use the ReduceLROnPlateau
        # scheduler, pass the monitored metric to step() instead.
        if args.sche != "Poly":
            args.scheduler.step()

        # Calculate and log the validation metrics
        metrics = {
            'val_loss': val_loss,
            'iou': val_score['iou'],
            'dc': val_score['dc'],
            'sp': val_score['sp'],
            'se': val_score['se'],
            'acc': val_score['acc'],
        }
        logx.metric(phase='val', metrics=metrics, epoch=epoch)

        # Print the metrics
        print(
            "\033[1;33;44m=============================Evaluation result=============================\033[0m"
        )
        logx.msg("[Train] Loss: %.4f | LR: %.6f" % (train_loss, new_lr))
        logx.msg("[Valid] Loss: %.4f | ACC: %.4f | IoU: %.4f | DC: %.4f" % (
            val_loss,
            metrics['acc'],
            metrics['iou'],
            metrics['dc'],
        ))
        rows += [train_loss, val_loss, metrics['dc'], metrics['iou'], new_lr]

        # Log the input image, ground truth and prediction to TensorBoard
        logx.add_image('image', torch.cat([i for i in image], 2), epoch)
        logx.add_image('label/gt', torch.cat([j for j in label], 2), epoch)
        logx.add_image('label/pd', torch.cat([k > 0.5 for k in output], 2),
                       epoch)

        # Update the best score
        best_score, tm = update_score(args, best_score, val_score, logx, epoch,
                                      epoch_start_time)
        rows.append(tm)
        total_list.append(rows)

        # Save the model along with the relevant parameters
        save_model(args, epoch, new_lr, interval=10)

    data = pd.DataFrame(total_list)
    file_path = os.path.join(args.dir_log, 'metrics.csv')
    data.to_csv(file_path,
                header=header,
                index=False,
                mode='w',
                encoding='utf-8')
    plot_curve(file_path, args.dir_log, show=True)
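
The scheduler step above is skipped when `args.sche == "Poly"`, which suggests `get_lr` applies a polynomial decay manually each epoch. A hedged sketch of such a schedule, assuming the common `power = 0.9`; the function name `poly_lr` and its signature are illustrative, not taken from this source.

def poly_lr(optimizer, base_lr, epoch, max_epochs, power=0.9):
    # Polynomial decay: lr = base_lr * (1 - epoch / max_epochs) ** power,
    # written directly into every parameter group of the optimizer.
    lr = base_lr * (1.0 - epoch / max_epochs) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
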