Example #1
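All of the snippets on this page are excerpts from full training scripts and omit their imports. A plausible shared preamble is sketched below; utils, models, logger, and mmd are the repository's own modules, so their exact contents are assumptions:

import argparse
import logging
import os
import time

import matplotlib
matplotlib.use('Agg')  # write loss plots to disk without a display
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal
from scipy import signal

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

import logger  # project-local
import mmd     # project-local
import models  # project-local
import utils   # project-local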
def main(args):
    utils.seedme(args.seed)
    cudnn.benchmark = True
    os.system('mkdir -p {}'.format(args.outf))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    images_train, images_test, masks_train, masks_test = utils.load_seismic_data(
        args.root_dir, test_size=.2, random_state=args.seed)
    images_train, masks_train = utils.concatenate_hflips(
        images_train, masks_train, shuffle=True, random_state=args.seed)
    images_test, masks_test = utils.concatenate_hflips(images_test,
                                                       masks_test,
                                                       shuffle=True,
                                                       random_state=args.seed)

    # transform = transforms.Compose([utils.augment(), utils.ToTensor()])
    transform = transforms.Compose([utils.ToTensor()])
    dataset_train = utils.SegmentationDataset(images_train,
                                              masks_train,
                                              transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset_train,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             drop_last=True,
                                             num_workers=1)
    dataiter = utils.dataiterator(dataloader)

    netF = models.choiceF[args.archF](num_features=args.num_features_F,
                                      num_residuals=args.num_residuals,
                                      gated=args.gated,
                                      gate_param=args.gate_param).to(device)
    optimizerF = optim.Adam(netF.parameters(), lr=args.lr, amsgrad=True)
    loss_func = torch.nn.BCELoss()

    log = logger.LoggerBCE(args.outf,
                           netF,
                           torch.from_numpy(images_train),
                           torch.from_numpy(masks_train),
                           torch.from_numpy(images_test),
                           torch.from_numpy(masks_test),
                           bcefunc=loss_func,
                           device=device)

    for i in range(args.niter):
        optimizerF.zero_grad()
        images, masks = next(dataiter)
        images, masks = images.to(device), masks.to(device)
        masks_pred = netF(images)
        loss = loss_func(masks_pred, masks)
        loss.backward()
        optimizerF.step()

        if (i + 1) % args.nprint == 0:
            print "[{}/{}] | loss: {:.3f}".format(i + 1, args.niter,
                                                  loss.item())
            log.flush(i + 1)

            if (i + 1) > 5000:
                torch.save(netF.state_dict(),
                           '{}/netF_iter_{}.pth'.format(args.outf, i + 1))
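Example #1's loop draws batches with next(dataiter) for a fixed number of steps instead of looping over epochs. utils.dataiterator is not shown on this page; a minimal sketch of the usual pattern it implements (an assumption, not the repository's exact code):

def dataiterator(dataloader):
    # Recycle a finite DataLoader forever so the training loop can call
    # next() for args.niter steps without any epoch bookkeeping.
    while True:
        for batch in dataloader:
            yield batch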
Example #2
def main(args):
    utils.seedme(args.seed)
    cudnn.benchmark = True
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.nocuda else 'cpu')

    os.system('mkdir -p {}'.format(args.outf))

    img = utils.load_image(
        args.image, resize=args.resize)  # (channel, height, width), [-1,1]
    x0 = torch.from_numpy(img).unsqueeze(0).to(
        device)  # (1, channel, height, width), torch

    args.img = img
    args.nc = img.shape[0]

    x = models.X(image_size=args.syn_size,
                 nc=args.nc,
                 batch_size=args.batch_size).to(device)
    optimizer = optim.Adam(x.parameters(), lr=args.lr)

    netE = models.choose_archE(args).to(device)
    print(netE)

    mmdrq = mmd.MMDrq(nu=args.nu, encoder=netE)
    loss_func = utils.Loss(x0, mmdrq, args.patch_size, args.npatch)

    losses = []
    start_time = time.time()
    for i in range(args.niter):
        optimizer.zero_grad()

        x1 = x()
        loss = loss_func(x1).mean()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        if (i + 1) % 500 == 0:
            print('[{}/{}] loss: {}'.format(i + 1, args.niter, loss.item()))
            fig, ax = plt.subplots()
            ax.plot(signal.medfilt(losses, 101)[50:-50])
            ax.set_yscale('symlog')
            fig.tight_layout()
            fig.savefig('{}/loss.png'.format(args.outf))
            plt.close(fig)
            logger.vutils.save_image(x1,
                                     '{}/x_{}.png'.format(args.outf, i + 1),
                                     normalize=True,
                                     nrow=10)
            print('This round took {0} secs'.format(time.time() - start_time))
            start_time = time.time()

    np.save('{}/x1.npy'.format(args.outf), x1.detach().cpu().numpy().squeeze())
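In Example #2 the object being optimized is the image itself: optim.Adam updates x.parameters() and x() returns the current synthesis. models.X is not shown on this page; a minimal sketch consistent with that usage (assumed, not the actual class):

class X(nn.Module):
    # A batch of synthesized images stored as a free parameter, squashed
    # into [-1, 1] to match the range of the reference image x0.
    def __init__(self, image_size, nc, batch_size):
        super(X, self).__init__()
        self.x = nn.Parameter(torch.randn(batch_size, nc, image_size, image_size))

    def forward(self):
        return torch.tanh(self.x)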
Example #3
def main(args):
    utils.seedme(args.seed)
    cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() and not args.nocuda else 'cpu')

    img = utils.load_image(args.image, resize=args.resize) # (channel, height, width), [-1,1]
    x0 = torch.from_numpy(img).unsqueeze(0).to(device)  # (1, channel, height, width), torch

    args.img = img
    args.nc = img.shape[0]

    netG = models.choose_archG(args).to(device)
    netE = models.choose_archE(args).to(device)
    print(netE)
    print(netG)

    optimizer = optim.Adam(netG.parameters(), lr=args.lr, betas=(args.beta1, args.beta2), amsgrad=True)
    z = torch.randn(args.batch_size,args.nz,1,1).to(device)

    mmdrq = mmd.MMDrq(nu=args.nu, encoder=netE)
    loss_func = utils.Loss(x0, mmdrq, args.patch_size, args.npatch)

    log = logger.Logger(args, netG, netE)
    log.save_image(x0, 'ref.png')
    nstart, nend = log.nstart, log.nend

    start_time = time.time()
    for i in range(nstart, nend):
        optimizer.zero_grad()

        x1 = netG(z.normal_())
        loss = loss_func(x1).mean()
        ent = utils.sample_entropy(x1.view(x1.shape[0],-1))
        kl = loss - args.alpha*ent

        kl.backward()
        optimizer.step()

        # --- logging
        log.log(loss.item(), ent.item(), kl.item())
        if (i+1) % 500 == 0:
            print('This round took {0} secs'.format(time.time() - start_time))
            start_time = time.time()
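Example #3's generator objective, kl = loss - args.alpha*ent, trades the MMD fit against sample diversity. utils.sample_entropy is not shown here; a rough nearest-neighbour (Kozachenko-Leonenko style) estimate over the flattened batch would look like this (an assumed implementation, not the repository's):

def sample_entropy(x):
    # x: (batch, dim). The estimate grows with the mean log distance from
    # each sample to its nearest neighbour, scaled by the dimensionality.
    n, d = x.shape
    dist = torch.cdist(x, x)
    dist = dist + torch.eye(n, device=x.device) * 1e12  # mask self-distances
    nn_dist = dist.min(dim=1).values.clamp_min(1e-12)
    return d * nn_dist.log().mean()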
Example #4
def main(args):
    utils.seedme(args.seed)
    cudnn.benchmark = True
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.nocuda else 'cpu')

    os.system('mkdir -p {}'.format(args.outf))

    dataloader_train = utils.get_patchloader(args.image_train,
                                             resize=args.resize_train,
                                             patch_size=args.patch_size,
                                             batch_size=args.batch_size_train,
                                             fliplr=args.fliplr,
                                             flipud=args.flipud,
                                             rot90=args.rot90,
                                             smooth=args.smooth)
    if args.image_valid:
        dataloader_valid = utils.get_patchloader(
            args.image_valid,
            resize=args.resize_valid,
            patch_size=args.patch_size,
            batch_size=args.batch_size_valid,
            fliplr=args.fliplr,
            flipud=args.flipud,
            rot90=args.rot90,
            smooth=args.smooth)

    netG = models.DCGAN_G(image_size=args.patch_size,
                          nc=args.nc,
                          nz=args.ncode,
                          ngf=args.ngf).to(device)
    netE = models.Encoder(patch_size=args.patch_size,
                          nc=args.nc,
                          ncode=args.ncode,
                          ndf=args.ndf).to(device)

    print(netG)
    print(netE)

    optimizer = optim.Adam(list(netG.parameters()) + list(netE.parameters()),
                           lr=args.lr,
                           amsgrad=True)
    loss_func = nn.MSELoss()

    losses = []
    losses_valid = []
    best_loss = 1e16
    for i in range(args.niter):
        optimizer.zero_grad()
        x = next(dataloader_train).to(device)
        if args.sigma:
            x = utils.add_noise(x, args.sigma)
        y = netG(netE(x))
        loss = loss_func(y, x)
        loss.backward()
        optimizer.step()

        if args.image_valid:
            with torch.no_grad():
                netG.eval()
                netE.eval()
                x_ = next(dataloader_valid).to(device)
                if args.sigma:
                    x_ = utils.add_noise(x_, args.sigma)
                y_ = netG(netE(x_))
                loss_valid = loss_func(y_, x_)
                netG.train()
                netE.train()
                losses_valid.append(loss_valid.item())

        _loss = loss_valid.item() if args.image_valid else loss.item()
        if _loss + 1e-3 < best_loss:
            best_loss = _loss
            print "[{}/{}] best loss: {}".format(i + 1, args.niter, best_loss)
            if args.save_best:
                torch.save(netE.state_dict(),
                           '{}/netD_best.pth'.format(args.outf))

        losses.append(loss.item())
        if (i + 1) % args.nprint == 0:
            if args.image_valid:
                print('[{}/{}] train: {}, test: {}, best: {}'.format(
                    i + 1, args.niter, loss.item(), loss_valid.item(),
                    best_loss))
            else:
                print('[{}/{}] train: {}, best: {}'.format(
                    i + 1, args.niter, loss.item(), best_loss))
            logger.vutils.save_image(torch.cat([x, y], dim=0),
                                     '{}/train_{}.png'.format(
                                         args.outf, i + 1),
                                     normalize=True)
            fig, ax = plt.subplots()
            ax.semilogy(scipy.signal.medfilt(losses, 11)[5:-5], label='train')
            if args.image_valid:
                logger.vutils.save_image(torch.cat([x_, y_], dim=0),
                                         '{}/test_{}.png'.format(
                                             args.outf, i + 1),
                                         normalize=True,
                                         nrow=32)
                ax.semilogy(scipy.signal.medfilt(losses_valid, 11)[5:-5],
                            label='valid')
            fig.legend()
            fig.savefig('{}/loss.png'.format(args.outf))
            plt.close(fig)
            torch.save(netE.state_dict(),
                       '{}/netD_iter_{}.pth'.format(args.outf, i + 1))
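Example #4 treats the patch loaders as endless generators, calling next(dataloader_train) each step, like dataiterator in Example #1. utils.get_patchloader is not shown; a sketch of the assumed behaviour, yielding batches of randomly cropped and flipped patches from a single image (rot90 and smooth handling omitted for brevity):

def get_patchloader(image_path, resize=None, patch_size=64, batch_size=64,
                    fliplr=False, flipud=False, rot90=False, smooth=None):
    img = load_image(image_path, resize=resize)  # (nc, H, W), values in [-1, 1]
    nc, h, w = img.shape
    while True:
        batch = np.empty((batch_size, nc, patch_size, patch_size), dtype=np.float32)
        for b in range(batch_size):
            i = np.random.randint(h - patch_size + 1)
            j = np.random.randint(w - patch_size + 1)
            patch = img[:, i:i + patch_size, j:j + patch_size]
            if fliplr and np.random.rand() < 0.5:
                patch = patch[:, :, ::-1]  # random horizontal flip
            if flipud and np.random.rand() < 0.5:
                patch = patch[:, ::-1, :]  # random vertical flip
            batch[b] = patch
        yield torch.from_numpy(batch)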
Example #5
def main(args):
    utils.seedme(args.seed)
    cudnn.benchmark = True
    os.system('mkdir -p {}'.format(args.outf))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print "Using BCE loss: {}".format(not args.no_bce)

    images_train, images_test, masks_train, masks_test = utils.load_seismic_data(args.root_dir, test_size=.2, random_state=args.seed)
    images_train, masks_train = utils.concatenate_hflips(images_train, masks_train, shuffle=True, random_state=args.seed)
    images_test, masks_test = utils.concatenate_hflips(images_test, masks_test, shuffle=True, random_state=args.seed)

    # transform = transforms.Compose([utils.augment(), utils.ToTensor()])
    transform = transforms.Compose([utils.ToTensor()])
    dataset_train = utils.SegmentationDataset(images_train, masks_train, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=1)
    dataiter = utils.dataiterator(dataloader)

    netF = models.choiceF[args.archF](num_features=args.num_features_F, num_residuals=args.num_residuals, gated=args.gated, gate_param=args.gate_param).to(device)
    netD = models.choiceD[args.archD](num_features=args.num_features_D, nc=2, dropout=args.dropout).to(device)
    if args.netF:
        netF.load_state_dict(torch.load(args.netF))
    if args.netD:
        netD.load_state_dict(torch.load(args.netD))
    print(netF)
    print(netD)
    optimizerF = optim.Adam(netF.parameters(), betas=(0.5, 0.999), lr=args.lr, amsgrad=True)
    optimizerD = optim.Adam(netD.parameters(), betas=(0.5, 0.999), lr=args.lr, amsgrad=True)
    alpha = torch.tensor(args.alpha).to(device)
    loss_func = torch.nn.BCELoss()

    smooth_binary = utils.SmoothBinary(scale=args.smooth_noise)

    # images_test, masks_test = torch.from_numpy(images_test).to(device), torch.from_numpy(masks_test).to(device)
    log = logger.LoggerGAN(args.outf, netF, netD, torch.from_numpy(images_train), torch.from_numpy(masks_train), torch.from_numpy(images_test), torch.from_numpy(masks_test), bcefunc=loss_func, device=device)

    start_time = time.time()
    for i in range(args.niter):

        # --- train D
        for p in netD.parameters():
            p.requires_grad_(True)

        for _ in range(args.niterD):
            optimizerD.zero_grad()

            images_real, masks_real = next(dataiter)
            images_real, masks_real = images_real.to(device), masks_real.to(device)
            masks_fake = netF(images_real).detach()
            x_fake = torch.cat((images_real, masks_fake), dim=1)

            # images_real, masks_real = next(dataiter)
            # images_real, masks_real = images_real.to(device), masks_real.to(device)
            masks_real = smooth_binary(masks_real)
            x_real = torch.cat((images_real, masks_real), dim=1)

            x_real.requires_grad_()  # to compute gradD_real
            x_fake.requires_grad_()  # to compute gradD_fake

            y_real = netD(x_real)
            y_fake = netD(x_fake)
            lossE = y_real.mean() - y_fake.mean()

            # grad() does not broadcast so we compute for the sum, effect is the same
            gradD_real = torch.autograd.grad(y_real.sum(), x_real, create_graph=True)[0]
            gradD_fake = torch.autograd.grad(y_fake.sum(), x_fake, create_graph=True)[0]
            omega = 0.5*(gradD_real.view(gradD_real.size(0), -1).pow(2).sum(dim=1).mean() +
                         gradD_fake.view(gradD_fake.size(0), -1).pow(2).sum(dim=1).mean())

            loss = -lossE - alpha*(1.0 - omega) + 0.5*args.rho*(1.0 - omega).pow(2)
            loss.backward()
            optimizerD.step()
            alpha -= args.rho*(1.0 - omega.item())

        # --- train G
        for p in netD.parameters():
            p.requires_grad_(False)
        optimizerF.zero_grad()
        images_real, masks_real = next(dataiter)
        images_real, masks_real = images_real.to(device), masks_real.to(device)
        masks_fake = netF(images_real)
        x_fake = torch.cat((images_real, masks_fake), dim=1)
        y_fake = netD(x_fake)
        loss = -y_fake.mean()
        bceloss = loss_func(masks_fake, masks_real)
        if not args.no_bce:
            loss = loss + bceloss * args.bce_weight
        loss.backward()
        optimizerF.step()

        log.dump(i+1, lossE.item(), alpha.item(), omega.item())

        if (i+1) % args.nprint == 0:
            print('Time per loop: {} sec/loop'.format((time.time() - start_time)/args.nprint))

            print "[{}/{}] lossE: {:.3f}, bceloss: {:.3f}, alpha: {:.3f}, omega: {:.3f}".format((i+1), args.niter, lossE.item(), bceloss.item(), alpha.item(), omega.item())

            log.flush(i+1)

            # if (i+1) > 5000:
            torch.save(netF.state_dict(), '{}/netF_iter_{}.pth'.format(args.outf, i+1))
            torch.save(netD.state_dict(), '{}/netD_iter_{}.pth'.format(args.outf, i+1))

            start_time = time.time()
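The discriminator update in Example #5 deserves a note: omega is the mean squared gradient norm of D at real and fake inputs, and alpha acts as a Lagrange multiplier driving omega toward 1. Stripped of the GAN specifics, the alpha update is standard augmented-Lagrangian dual ascent on the constraint 1 - omega = 0; a toy illustration with hypothetical numbers:

rho = 0.1    # stands in for args.rho
alpha = 0.0  # stands in for args.alpha
for omega in (0.2, 0.5, 0.8, 0.95, 1.0):       # illustrative trajectory of omega
    c = 1.0 - omega                            # constraint residual
    penalty = -alpha * c + 0.5 * rho * c ** 2  # term added to the primal loss
    alpha -= rho * c                           # dual ascent on the multiplier
    print('omega={:.2f} penalty={:+.5f} alpha={:+.4f}'.format(omega, penalty, alpha))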
Example #6
parser.add_argument('--nz',
                    type=int,
                    default=30,
                    help='size of latent vector z')
# --- netI params
parser.add_argument('--archI', default='FC_selu')
parser.add_argument('--netI', default=None)
parser.add_argument('--hidden_layer_size', type=int, default=512)
parser.add_argument('--num_extra_layers', type=int, default=2)
parser.add_argument('--nw',
                    type=int,
                    default=30,
                    help='size of latent vector w')
args = parser.parse_args()

os.system('mkdir -p {0}'.format(args.outdir))
utils.config_logging(logging, args.outdir)
utils.seedme(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

netG = getattr(models, args.archG)(image_size=args.image_size,
                                   nz=args.nz,
                                   image_depth=args.image_depth,
                                   num_filters=args.num_filters).to(device)
netG.load_state_dict(torch.load(args.netG))
for p in netG.parameters():
    p.requires_grad_(False)
netG.eval()
print(netG)

netI = getattr(models,
               args.archI)(input_size=args.nw,
                           output_size=args.nz,