Example #1
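# Shared context (assumed; not shown in the original snippets): the examples
# below rely on `import gc, logging, time`, `import numpy as np`,
# `import torch, torchvision`, and project-local helpers such as transforms,
# numpy_eval, estimate_to_image, nmse, batch_nmse, adjusted_mask, and
# perturb_noise_init.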
def visualize(args, epoch, model, data_loader, writer):
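    """Log grids of target, reconstruction, and error images to TensorBoard."""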
    def save_image(image, tag):
        # Normalize the batch to [0, 1] so TensorBoard renders it consistently.
        image -= image.min()
        image /= image.max()
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    model.eval()
    output_images = []
    target_images = []
    with torch.no_grad():
        for data in data_loader:
            y, mask, target, metadata = data[:4]
            y = y.to(args.device)
            mask = mask.to(args.device)
            target = target.to(args.device)

            n_slices = target.size(-3)
            if args.n_slices == 1:
                mask = mask.squeeze(-4)
                y = y[..., n_slices // 2, :, :, :]

            # No explicit detach is needed: torch.no_grad() already disables
            # graph construction for this forward pass.
            estimate = model(y=y, mask=mask, metadata=metadata)

            target = target[..., n_slices // 2, :, :]
            target_images.append(target)

            if args.n_slices > 1:
                output_images.append(
                    estimate_to_image(estimate[..., n_slices // 2, :, :, :],
                                      target.size()[-2:]).clone().detach())
            else:
                output_images.append(
                    estimate_to_image(estimate,
                                      target.size()[-2:]).clone().detach())

    output = torch.cat(output_images, 0)[:16].unsqueeze(1)
    target = torch.cat(target_images, 0)[:16].unsqueeze(1)

    save_image(target, 'Target')
    save_image(output, 'Reconstruction')
    save_image(target - output, 'Error')
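
A minimal sketch of how visualize could be wired into a training run, assuming a torch.utils.tensorboard.SummaryWriter and an args namespace with the fields used above; model, val_loader, and args.num_epochs are hypothetical stand-ins:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/example')  # hypothetical log directory
for epoch in range(args.num_epochs):
    # ... run train_epoch(...) as in Example #4 ...
    visualize(args, epoch, model, val_loader, writer)
writer.close()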
Example #2
def evaluate(args, epoch, model, data_loader, writer):
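    """Run validation: compute MSE/PSNR/NMSE/SSIM over the loader, log the
    means to TensorBoard, and return them with wall-clock time and peak
    GPU memory (MB)."""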
    model.eval()
    mse_losses = []
    psnr_losses = []
    nmse_losses = []
    ssim_losses = []
    memory_allocated = []

    start = time.perf_counter()
    with torch.no_grad():
        for data in data_loader:
            y, mask, target, metadata = data[:4]
            y = y.to(args.device)
            mask = mask.to(args.device)
            target = target.to(args.device)
            if args.n_slices > 1:
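                # Multi-slice volumes go through the model in one pass; the
                # single-slice path below folds slices into the batch dim.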
                output = model(y=y, mask=mask, metadata=metadata)
                output = estimate_to_image(output, args.resolution)
                output_np = output.to('cpu').transpose(0, -4).squeeze(-4)
                del output
            else:
                y = y.transpose(0, -4).squeeze(-4)
                mask = mask.squeeze(-4).repeat(y.size(0), 1, 1, 1, 1)
                metadata = metadata.repeat(y.size(0), 1)
                output_np = []
                # Push the slices through the model in chunks of args.batch_size.
                for k in range(0, y.size(0), args.batch_size):
                    output = model(y=y[k:k + args.batch_size],
                                   mask=mask[k:k + args.batch_size],
                                   metadata=metadata[k:k + args.batch_size])
                    output = estimate_to_image(output, args.resolution)
                    output_np.append(output.to('cpu'))
                output_np = torch.cat(output_np, 0)

            # Flatten to (n_slices, H, W); output_np is still a torch tensor here.
            output_np = output_np.reshape(-1, output_np.size(-2), output_np.size(-1))
            target = target.reshape_as(output_np)

            output_np = output_np.numpy()
            target_np = target.to('cpu').numpy()
            mse_losses.append(numpy_eval.mse(target_np, output_np))
            psnr_losses.append(numpy_eval.psnr(target_np, output_np))
            nmse_losses.append(numpy_eval.nmse(target_np, output_np))
            ssim_losses.append(numpy_eval.ssim(target_np, output_np))

            if args.device == 'cuda':
                memory_allocated.append(torch.cuda.max_memory_allocated() * 1e-6)
                torch.cuda.reset_max_memory_allocated()

            del data, y, mask, target, metadata
            if args.device == 'cuda':
                torch.cuda.empty_cache()

        writer.add_scalar('Val_MSE', np.mean(mse_losses), epoch)
        writer.add_scalar('Val_PSNR', np.mean(psnr_losses), epoch)
        writer.add_scalar('Val_NMSE', np.mean(nmse_losses), epoch)
        writer.add_scalar('Val_SSIM', np.mean(ssim_losses), epoch)
        # memory_allocated is only populated on CUDA runs; report 0 otherwise.
        max_memory = np.max(memory_allocated) if memory_allocated else 0.0
        writer.add_scalar('Val_memory', max_memory, epoch)

    return np.mean(nmse_losses), np.mean(psnr_losses), np.mean(mse_losses), \
           np.mean(ssim_losses), time.perf_counter() - start, max_memory
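
A minimal sketch of a call site consuming evaluate's metric tuple; val_loader is a hypothetical stand-in:

nmse_val, psnr_val, mse_val, ssim_val, eval_time, peak_mem = evaluate(
    args, epoch, model, val_loader, writer)
logging.info(f'Val NMSE = {nmse_val:.4g} ({eval_time:.1f}s, {peak_mem:.0f} MB)')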
Example #3
def evaluate(args, epoch, model, data_loader, writer, fnaf_mask=None):
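    """Validation loop with an optional FNAF mode: when args.fnaf_eval is set,
    returns per-example worst-case adversarial losses instead of the usual
    metric tuple."""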
    model.eval()
    mse_losses = []
    psnr_losses = []
    nmse_losses = []
    ssim_losses = []
    memory_allocated = []

    fnaf_losses = []

    start = time.perf_counter()
    with torch.no_grad():
        for data in data_loader:
            if not args.fnaf_eval:
                y, mask, target, metadata = data[:4]
                y = y.to(args.device)
                mask = mask.to(args.device)
                target = target.to(args.device)
                if args.n_slices > 1:
                    output = model(y=y, mask=mask, metadata=metadata)
                    output = estimate_to_image(output, args.resolution)
                    output_np = output.to('cpu').transpose(0, -4).squeeze(-4)
                    del output
                else:
                    y = y.transpose(0, -4).squeeze(-4)
                    mask = mask.squeeze(-4).repeat(y.size(0), 1, 1, 1, 1)
                    metadata = metadata.repeat(y.size(0), 1)
                    output_np = []
                    # Push the slices through the model in chunks of args.batch_size.
                    for k in range(0, y.size(0), args.batch_size):
                        output = model(y=y[k:k + args.batch_size],
                                       mask=mask[k:k + args.batch_size],
                                       metadata=metadata[k:k + args.batch_size])
                        output = estimate_to_image(output, args.resolution)
                        output_np.append(output.to('cpu'))
                    output_np = torch.cat(output_np, 0)

                output_np = output_np.reshape(-1, output_np.size(-2),
                                              output_np.size(-1))
                target = target.reshape_as(output_np)

                output_np = output_np.numpy()
                target_np = target.to('cpu').numpy()
                mse_losses.append(numpy_eval.mse(target_np, output_np))
                psnr_losses.append(numpy_eval.psnr(target_np, output_np))
                nmse_losses.append(numpy_eval.nmse(target_np, output_np))
                ssim_losses.append(numpy_eval.ssim(target_np, output_np))

            else:
                loss_f = batch_nmse

                y, mask, target, metadata, target_norm, target_max = data
                labels = target, metadata, target_norm, y
                vis = False
                n_pixel_range = (10, 11)

                fnaf_loss_list = []

                # Sample the random attack 11 times and keep the per-example
                # worst case (max) as the robustness loss.
                for _ in range(11):
                    fnaf_loss = get_attack_loss(
                        args,
                        model,
                        labels,
                        loss_f=loss_f,
                        xs=np.random.randint(low=100 + 24,
                                             high=368 - (100 + 24),
                                             size=(y.size(0), )),
                        ys=np.random.randint(low=100 + 24,
                                             high=368 - (100 + 24),
                                             size=(y.size(0), )),
                        shape=(368, 368),
                        n_pixel_range=n_pixel_range,
                        fnaf_mask=fnaf_mask,
                        vis=vis)

                    fnaf_loss_list.append(fnaf_loss.cpu().numpy())

                fnaf_loss = np.max(fnaf_loss_list, axis=0)
                fnaf_losses += list(fnaf_loss)

        if not args.fnaf_eval:
            writer.add_scalar('Val_MSE', np.mean(mse_losses), epoch)
            writer.add_scalar('Val_PSNR', np.mean(psnr_losses), epoch)
            writer.add_scalar('Val_NMSE', np.mean(nmse_losses), epoch)
            writer.add_scalar('Val_SSIM', np.mean(ssim_losses), epoch)
            # memory_allocated is never filled in this variant; guard np.max.
            max_memory = np.max(memory_allocated) if memory_allocated else 0.0
            writer.add_scalar('Val_memory', max_memory, epoch)
    if args.fnaf_eval:
        out = fnaf_losses
    else:
        out = (np.mean(nmse_losses), np.mean(psnr_losses), np.mean(mse_losses),
               np.mean(ssim_losses), time.perf_counter() - start, max_memory)

    return out
Example #4
def train_epoch(args,
                epoch,
                model,
                train_loader,
                optimizer,
                writer,
                fnaf_iterloader=None,
                fnaf_loader=None,
                fnaf_mask=None):
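    """Train for one epoch, optionally adding a segmentation bounding-box loss
    (args.bbox_root) and an FNAF adversarial loss (args.fnaf_train); returns
    (avg_loss, epoch_time)."""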
    model.train()
    avg_loss = 0.
    start_epoch = start_iter = time.perf_counter()
    global_step = epoch * len(train_loader)

    loss_f = nmse

    memory_allocated = []
    if args.fnaf_train:
        loader = zip(train_loader, fnaf_loader)
    else:
        loader = train_loader
    for i, data in enumerate(loader):
        if args.fnaf_train:
            data, batch = data
        if args.bbox_root:
            (y, mask, target, metadata, target_norm, target_max), seg = data
        else:
            y, mask, target, metadata, target_norm, target_max = data
        y = y.to(args.device)
        mask = mask.to(args.device)
        target = target.to(args.device)
        target_norm = target_norm.to(args.device)
        target_max = target_max.to(args.device)

        optimizer.zero_grad()
        model.zero_grad()
        estimate = model(y=y, mask=mask, metadata=metadata)
        if isinstance(estimate, list):
            loss = [
                image_loss(e, target, args, target_norm, target_max)
                for e in estimate
            ]
            loss = sum(loss) / len(loss)
        else:
            loss = image_loss(estimate, target, args, target_norm, target_max)

        if args.bbox_root:
            writer.add_scalar('SSIM_Loss', loss.item(), global_step + i)

            bbox_loss = []
            # NMSE restricted to each of the 11 segmentation channels that
            # contain foreground pixels.
            for j in range(11):
                seg_mask = seg[:, :, :, j]
                if seg_mask.sum() > 0:
                    seg_mask = seg_mask.to(args.device)
                    bbox_output = estimate_to_image(estimate,
                                                    args.resolution) * seg_mask
                    bbox_target = target * seg_mask
                    bbox_loss.append(nmse(bbox_target, bbox_output))

            if bbox_loss:
                bbox_loss = 10 * torch.stack(bbox_loss).mean()
                writer.add_scalar('BBOX_Loss', bbox_loss.item(),
                                  global_step + i)
                loss += bbox_loss

        loss.backward()
        optimizer.step()

        # Exponential moving average of the training loss.
        avg_loss = loss.item() if i == 0 else 0.99 * avg_loss + 0.01 * loss.item()
        writer.add_scalar('Loss', loss.item(), global_step + i)

        if args.fnaf_train:
            optimizer.zero_grad()
            model.zero_grad()
            y, mask, target, metadata, target_norm, target_max = batch
            labels = target, metadata, target_norm, y
            vis = False
            n_pixel_range = (10, 5001)
            fnaf_loss = get_attack_loss(
                args,
                model,
                labels,
                loss_f=loss_f,
                xs=np.random.randint(low=100 + 24,
                                     high=368 - (100 + 24),
                                     size=(y.size(0), )),
                ys=np.random.randint(low=100 + 24,
                                     high=368 - (100 + 24),
                                     size=(y.size(0), )),
                shape=(368, 368),
                n_pixel_range=n_pixel_range,
                fnaf_mask=fnaf_mask,
                vis=vis)

            fnaf_loss = 10 * fnaf_loss
            writer.add_scalar('FNAF_Loss', fnaf_loss.item(), global_step + i)

            fnaf_loss.backward()
            optimizer.step()

        if args.device == 'cuda':
            memory_allocated.append(torch.cuda.max_memory_allocated() * 1e-6)
            torch.cuda.reset_max_memory_allocated()
            torch.cuda.empty_cache()
        gc.collect()

        if i % args.report_interval == 0:
            # memory_allocated is empty on CPU runs; report 0 in that case.
            min_memory = np.min(memory_allocated) if memory_allocated else 0.0
            logging.info(
                f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
                f'Iter = [{i:4d}/{len(train_loader):4d}] '
                f'Loss = {loss.item():.4g} Avg Loss = {avg_loss:.4g} '
                f'Time = {time.perf_counter() - start_iter:.4f}s '
                f'Memory allocated (MB) = {min_memory:.2f}')
            memory_allocated = []
        start_iter = time.perf_counter()
    optimizer.zero_grad()

    return avg_loss, time.perf_counter() - start_epoch
Example #5
def get_attack_loss(args,
                    model,
                    ori_target,
                    fnaf_mask,
                    loss_f=torch.nn.MSELoss(reduction='none'),
                    # NOTE: these random defaults are evaluated once, at
                    # function definition time; callers in these examples
                    # always pass xs/ys explicitly.
                    xs=np.random.randint(low=100, high=320 - 100),
                    ys=np.random.randint(low=100, high=320 - 100),
                    shape=(320, 320),
                    n_pixel_range=(10, 11),
                    vis=False):
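    """Inject a small random pixel perturbation into the clean image,
    re-simulate the undersampled measurement, and return loss_f between the
    model's reconstruction and the perturbed target within the noise mask."""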

    # ori_target arrives packed as (target, metadata, target_norm, input).
    ori_target, metadata, target_norm, ori_input = ori_target

    input_o = transforms.complex_abs(ori_input.clone())

    p_max = input_o.max()
    p_min = input_o.min()
    perturb_noise = [
        perturb_noise_init(x=x,
                           y=y,
                           shape=shape,
                           n_pixel_range=n_pixel_range,
                           pixel_value_range=(p_min, p_max))
        for x, y in zip(xs, ys)
    ]
    perturb_noise = np.stack(perturb_noise)

    seed = np.random.randint(9999999)

    # Perturb the target image with the sampled noise pattern.
    perturb_noise = transforms.to_tensor(perturb_noise).unsqueeze(1)

    if not args.fnaf_eval_control:
        input_o += perturb_noise
    target = input_o.clone()

    # Re-simulate the undersampled k-space measurement from the perturbed image.
    input_o = np.complex64(input_o.numpy())
    input_o = transforms.to_tensor(input_o)
    input_o = transforms.fft2(input_o)
    input_o, mask = transforms.apply_mask(input_o, fnaf_mask, seed=seed)
    input_o = transforms.ifft2(input_o)

    # Run the perturbed measurement through the model.
    output = model(y=input_o, mask=mask, metadata=metadata)
    output = estimate_to_image(output, args.resolution)

    # Map the noise pattern into image space to localize the perturbed pixels.
    perturb_noise = torch.stack([perturb_noise] * 2, -1)
    perturb_noise = estimate_to_image(perturb_noise, args.resolution).numpy()

    mask = adjusted_mask((perturb_noise != 0))

    mask = transforms.to_tensor(mask).to(args.device)

    # Compare reconstruction and perturbed target only inside the noise mask.
    target = torch.stack([target] * 2, -1)
    target = estimate_to_image(target, args.resolution).to(args.device)

    loss = loss_f(target * mask, output * mask)


    if vis and loss.max() >= 0.001:
        print('vis!')
        print(output.min(), output.max())
        print(target.min(), target.max())

    return loss
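
get_attack_loss depends on a perturb_noise_init helper that is not shown in these examples. The sketch below is only a guess at its contract, inferred from the call site (center coordinates, target shape, pixel-count range, pixel-value range); the neighborhood size and sampling scheme are assumptions, not the original implementation:

import numpy as np

def perturb_noise_init(x, y, shape, n_pixel_range, pixel_value_range):
    # Hypothetical sketch: scatter a random number of random-valued pixels
    # in a small neighborhood around (x, y).
    noise = np.zeros(shape, dtype=np.float32)
    n_pixels = np.random.randint(*n_pixel_range)
    low, high = float(pixel_value_range[0]), float(pixel_value_range[1])
    for _ in range(n_pixels):
        px = int(np.clip(x + np.random.randint(-5, 6), 0, shape[0] - 1))
        py = int(np.clip(y + np.random.randint(-5, 6), 0, shape[1] - 1))
        noise[px, py] = np.random.uniform(low, high)
    return noise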