Ejemplo n.º 1
0
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Build the networks, optimizers, schedulers and one-hot tensors.

        Args:
            hyperparameters: mapping read for the keys 'lr', 'input_dim',
                'n_datasets', 'gen', 'dis', 'beta1', 'beta2',
                'weight_decay', 'init' and 'batch_size'; it is also handed
                to ``get_scheduler`` and ``self.resume``.
            resume_epoch: any value other than -1 triggers ``self.resume``.
            snapshot_dir: forwarded to ``self.resume`` when resuming.
        """
        super(UNIT_Trainer, self).__init__()

        learning_rate = hyperparameters['lr']
        input_dim = hyperparameters['input_dim']
        n_datasets = hyperparameters['n_datasets']

        # Generator and discriminator both take the image channels plus one
        # extra channel per dataset.
        conditioned_dim = input_dim + n_datasets
        self.gen = VAEGen(conditioned_dim, hyperparameters['gen'], n_datasets)
        self.dis = MsImageDis(conditioned_dim, hyperparameters['dis'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # Supervised UNet with two output classes, moved to the GPU.
        self.sup = UNet(input_channels=input_dim, num_classes=2).cuda()

        # Optimizers: the generator optimizer also owns the UNet parameters.
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']

        def build_adam(candidates):
            # Only parameters that currently require gradients are optimized.
            return torch.optim.Adam(
                [p for p in candidates if p.requires_grad],
                lr=learning_rate,
                betas=(beta1, beta2),
                weight_decay=hyperparameters['weight_decay'])

        self.dis_opt = build_adam(list(self.dis.parameters()))
        self.gen_opt = build_adam(
            list(self.gen.parameters()) + list(self.sup.parameters()))
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Initialise every submodule with the configured scheme, then
        # re-initialise the discriminator with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        self.dis.apply(weights_init('gaussian'))

        # Pre-allocated one-hot tensors: slice [i] has channel i set to 1.
        batch_size = hyperparameters['batch_size']
        self.one_hot_img = torch.zeros(n_datasets, batch_size, n_datasets,
                                       256, 256).cuda()
        self.one_hot_h = torch.zeros(n_datasets, batch_size, n_datasets,
                                     64, 64).cuda()
        for domain in range(n_datasets):
            for one_hot in (self.one_hot_img, self.one_hot_h):
                one_hot[domain, :, domain, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters)
Ejemplo n.º 2
0
    def build_model(self):
        """Create the network named by ``self.model_type`` and its optimizer.

        If ``self.model_type`` matches none of the known names, ``self.unet``
        is left as it was. An Adam optimizer over ``self.unet``'s parameters
        is then built and the network is moved to ``self.device``.
        """
        # Ordered (name, factory) pairs; factories are lazy so that only the
        # selected architecture is ever constructed.
        candidates = (
            ('UNet', lambda: UNet(n_channels=1, n_classes=1)),
            # TODO: changed for green image channel
            ('R2U_Net', lambda: R2U_Net(img_ch=1, output_ch=1, t=self.t)),
            ('AttU_Net', lambda: AttU_Net(img_ch=1, output_ch=1)),
            ('R2AttU_Net',
             lambda: R2AttU_Net(img_ch=3, output_ch=1, t=self.t)),
            ('Iternet', lambda: Iternet(n_channels=1, n_classes=1)),
            ('AttUIternet', lambda: AttUIternet(n_channels=1, n_classes=1)),
            ('R2UIternet', lambda: R2UIternet(n_channels=3, n_classes=1)),
            ('NestedUNet', lambda: NestedUNet(in_ch=1, out_ch=1)),
        )
        for type_name, make_network in candidates:
            if self.model_type == type_name:
                self.unet = make_network()
                break

        network_params = list(self.unet.parameters())
        self.optimizer = optim.Adam(network_params,
                                    self.lr,
                                    betas=tuple(self.beta_list))
        self.unet.to(self.device)
Ejemplo n.º 3
0
def main():
    """Train the network against ``PhysicalLoss``, or only run the test.

    Configuration comes from the module-level ``opt`` and ``dtype``. When
    ``opt.test`` is set, ``runTest`` is called and the function returns.
    Otherwise the network is trained for ``opt.epochs`` epochs of
    ``opt.epoch_size`` steps, sample plots are produced after every epoch,
    a checkpoint is written every 50 epochs and once more at the end.
    """
    net = UNet(dtype, image_size=opt.image_size).type(dtype)
    print(net)

    if opt.test:
        runTest(net)
        return

    physical_loss = PhysicalLoss()
    optimizer = optim.Adam(net.parameters(), lr=opt.learning_rate)

    (fixed_sample_0, fixed_solution_0,
     fixed_sample_1, fixed_solution_1) = makeSamples(opt.image_size)

    # Only the four borders of ``data`` are ever written; the interior
    # stays at zero for the whole run.
    data = torch.zeros(opt.batch_size, 1, opt.image_size, opt.image_size)
    print("Training Started")
    for epoch in range(opt.epochs):
        mean_loss = 0
        for _ in range(opt.epoch_size):
            # One scalar per border, shared by the whole batch.
            # NOTE(review): the single positional argument is ``low``
            # (``high`` keeps its default of 1.0) -- confirm that is the
            # intended sampling range.
            data[:, :, :, 0] = np.random.uniform(100)
            data[:, :, 0, :] = np.random.uniform(100)
            data[:, :, :, -1] = np.random.uniform(100)
            data[:, :, -1, :] = np.random.uniform(100)
            img = Variable(data).type(dtype)
            output = net(img)
            loss = physical_loss(output)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # NOTE(review): indexing ``loss.data`` needs a loss with at
            # least one dimension; on PyTorch >= 0.4 a 0-dim loss requires
            # ``loss.item()`` instead.
            mean_loss += loss.data[0]
        mean_loss /= opt.epoch_size
        print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, opt.epochs,
                                                  mean_loss))

        plotSamples(fixed_solution_0, net(fixed_sample_0), fixed_solution_1,
                    net(fixed_sample_1), opt.experiment, epoch)

        # Checkpoint every 50 epochs. The parentheses matter: without them
        # ``%`` binds tighter than ``+`` and the test is never true.
        if (epoch + 1) % 50 == 0:
            torch.save(net.state_dict(),
                       '%s/net_epoch_%d.pth' % (opt.experiment, epoch + 1))

    print("Training Complete")
    # ``opt.epochs`` equals ``epoch + 1`` after a full run and, unlike the
    # loop variable, is also defined when zero epochs were requested.
    torch.save(net.state_dict(),
               '%s/net_epoch_%d.pth' % (opt.experiment, opt.epochs))
    print("Network Weights Saved in %s" % opt.experiment)
    def __init__(self, maxiter=300, tol=1e-6, restart=50):
        """Store the solver settings and create the learnable components.

        Args:
            maxiter: stored unchanged as ``self.maxiter``.
            tol: stored unchanged as ``self.tol``.
            restart: stored unchanged as ``self.restart``.
        """
        super(Wiener_KPN_SA, self).__init__()

        # Solver configuration, kept exactly as passed in.
        self.maxiter, self.tol, self.restart = maxiter, tol, restart
        self.cg_solver = ConugateGradient_Function
        self.info = None

        # Learnable scalar starting at zero, registered before the UNet
        # submodule with nine output classes.
        initial_weight = torch.Tensor([0.])
        self.reg_weight = nn.Parameter(initial_weight)
        self.model = UNet(mode='none', n_channels=1, n_classes=9)
Ejemplo n.º 5
0
    def __init__(self):
        """Create the learnable pieces of the WienerUNet deconvolution module.

        The regularization term does not have the form of a Tikhonov
        regularizer, so the "Wiener" name is a slight abuse of notation.
        The module is built around the gradient-descent update

            x_{k+1} = x_k - lamb * [K^T (K x_k - y) + exp(alpha) * reg(x_k)]

        Attributes created here:
            regularizer: neural network parametrizing the prior applied to
                each iterate x_k (a ``UNet`` built with mode 'instance').
            alpha: exponent of the trade-off coefficient, initialised to 0.0.
            lamb: step of the gradient-descent scheme, initialised to 0.3.
        """
        super(WienerUNet, self).__init__()
        self.regularizer = UNet(mode='instance')

        alpha_init, step_init = [0.0], [0.3]
        self.alpha = nn.Parameter(torch.FloatTensor(alpha_init))
        self.lamb = nn.Parameter(torch.FloatTensor(step_init))
Ejemplo n.º 6
0
                    self.bad += 1
                    print("bad ++")
            except ValueError:
                torch.save(self.net.state_dict(), str(self.save_weight_path))
            self.val_losses.append(val_loss)
        else:
            print("loss is too large. Continue train")
            self.val_losses.append(val_loss)
        print("bad = {}".format(self.bad))
        self.epoch_loss = 0


if __name__ == "__main__":
    args = parse_args()

    # Each dataset location becomes a single-element list holding a Path.
    for attr in ("train_path", "val_path"):
        setattr(args, attr, [Path(getattr(args, attr))])
    # Destination of the saved weights.
    args.weight_path = Path(args.weight_path)

    # Single-channel input, single-class output network; it is handed to
    # the trainer through ``args``.
    net = UNet(n_channels=1, n_classes=1)
    if args.gpu:
        net.cuda()
    args.net = net

    TrainNet(args).main()
Ejemplo n.º 7
0
mask_dirs1 = ["../MedTest_Masked"]

# Training images are only converted to tensors; no augmentation is applied.
train_transforms = transforms.Compose([transforms.ToTensor()])

data = SegDataset("train_data", transform=train_transforms)
#print("THe size of the complete dataset is {}".format(data.__len__()))

# NOTE(review): valid_size is presumably the held-out fraction -- confirm
# against load_train_test.
train_loader, test_loader = load_train_test(data, valid_size=0.01)

#print("The size of training examples is : {}".format(len(train_loader.dataset)))
#print("The size of testing examples is : {}".format(len(test_loader.dataset)))

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = UNet(num_class).to(device)
#summary(model,(1,512,512))
# d = open('sample.json', 'w+')

optimizer = optim.Adam(model.parameters(), lr=0.0001)

#def weights_init(m):
#    if isinstance(m, nn.Conv2d):
#        xavier_uniform(m.weight.data)
#        xavier_uniform(m.bias.data)

#model.apply(weights_init)

# map_location keeps the CPU fallback above usable: without it, a
# checkpoint written on a GPU cannot be deserialised on a CPU-only machine.
model.load_state_dict(
    torch.load("../outputs12/checkpoints/ckpt_0_0.pth", map_location=device))
transforms1 = transforms.Compose([transforms.ToTensor()])
Ejemplo n.º 8
0
# Training images are only converted to tensors; no augmentation is applied.
train_transforms = transforms.Compose([transforms.ToTensor()])

# image_dirs / mask_dirs are defined outside this excerpt.
data = SegDataset(image_dirs, mask_dir=mask_dirs, transform=train_transforms)
print("THe size of the complete dataset is {}".format(data.__len__()))

# NOTE(review): valid_size is presumably the held-out fraction -- confirm
# against load_train_test.
train_loader, test_loader = load_train_test(data, valid_size=0.01)

print("The size of training examples is : {}".format(len(
    train_loader.dataset)))
print("The size of testing examples is : {}".format(len(test_loader.dataset)))

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model = UNet(num_class).to(device)
summary(model, (1, 512, 512))

# 'w+' truncates sample.json and opens it for reading and writing. The
# handle is not closed anywhere in this excerpt -- whoever uses ``d`` later
# is responsible for closing it.
d = open('sample.json', 'w+')

# Observe that all parameters are being optimized
optimizer = optim.Adam(model.parameters(), lr=4e-4)

#model.load_state_dict(torch.load("../outputs5/checkpoints/ckpt_0_31.pth"))
# Test-time pipeline: resize to 512x512, then convert to a tensor.
transforms1 = transforms.Compose(
    [transforms.Resize((512, 512)),
     transforms.ToTensor()])

test_data = SegDataset1(test_dirs,
                        mask_dir=mask_dirs1,
                        transform=transforms1,
Ejemplo n.º 9
0
def test_unet(root,
              psf_path,
              method,
              scale,
              model_path,
              visual,
              use_gpu,
              b_size=1):
    """Evaluate a trained UNet on the Poisson-noise test set and print metrics.

    Every batch is normalised per sample by the maximum of its ground truth,
    passed through the model, and scored with PSNR and SSIM (for both the
    network output and the unprocessed input). The means over all batches
    are printed at the end.

    Args:
        root: forwarded to ``CellDataset``.
        psf_path: forwarded to ``CellDataset``.
        method: model name prefix; the checkpoint file is
            ``method + '_poisson'`` inside ``model_path``.
        scale: forwarded to ``CellDataset``; ``int(scale)`` is also part of
            the output directory and file names.
        model_path: directory that contains the checkpoint file.
        visual: when equal to 1, the first output of every batch is saved
            as a PNG under ``./Results/``.
        use_gpu: when equal to 1, model and data are moved to the GPU.
        b_size: batch size of the test loader.

    Returns:
        None. Results are printed (and optionally written as images).
    """
    model_name = method + '_poisson'
    save_images_path = './Results/' + model_name + '_peak_' + str(
        int(scale)) + '/'

    test_dataset = CellDataset(root, psf_path, 'poisson', scale, 0.0)
    test_loader = DataLoader(test_dataset,
                             batch_size=b_size,
                             shuffle=False,
                             num_workers=1)

    model = UNet(mode='batch')
    # Deserialise onto the CPU so that checkpoints written on a GPU also
    # load on CPU-only machines; the model is moved to the GPU below.
    checkpoint = torch.load(os.path.join(model_path, model_name),
                            map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    if use_gpu == 1:
        model.cuda()

    model.eval()

    if visual == 1:
        # Created once up front instead of being re-checked for every batch.
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []

    distorted_psnr_test = []
    distorted_ssim_test = []

    with torch.no_grad():
        for (gt, image), psf, index, image_name, peak, _ in tqdm(test_loader):
            # -1 instead of b_size: the last batch may be smaller than
            # b_size when the dataset size is not a multiple of it.
            image = image.reshape((-1, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((-1, 1, gt.shape[-2], gt.shape[-1]))

            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()

            # Scale each sample (input and target alike) by the maximum of
            # its ground truth.
            for l in range(gt.shape[0]):
                gt_max = gt[l].max()
                image[l] = image[l] / gt_max
                gt[l] /= gt_max

            output = model(image)

            distorted_psnr = calc_psnr(image.clamp(0, 1), gt)
            distorted_ssim = ssim(image.clamp(0, 1), gt)

            psnr_test = calc_psnr(output.clamp(0, 1), gt)
            s_sim_test = ssim(output.clamp(0, 1), gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())

            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            # Save image: only the first sample of each batch is written.
            if visual == 1:
                file_name = ('output_' + str(image_name[0][:-4]) + '_' +
                             str(model_name) + '_' + str(int(scale)) + '.png')
                pixels = output[0][0].detach().cpu().numpy().clip(0, 1)
                io.imsave(os.path.join(save_images_path, file_name),
                          np.uint8(pixels * 255.))

    if not psnr_values_test:
        # Without any batch there is nothing to average and ``peak`` was
        # never bound by the loop.
        print('No test batches were produced; nothing to report.')
        return

    print('Test on Poisson noise with peak %d: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f' % (
        peak,
        np.array(psnr_values_test).mean(),
        np.array(ssim_values_test).mean(),
        np.array(distorted_psnr_test).mean(),
        np.array(distorted_ssim_test).mean()))
Ejemplo n.º 10
0
            self.cal_tp_fp_fn(ori, gt_img, pre_img, i)
        if self.tps == 0:
            f_measure = 0
        else:
            recall = self.tps / (self.tps + self.fns)
            precision = self.tps / (self.tps + self.fps)
            f_measure = (2 * recall * precision) / (recall + precision)

        print(precision, recall, f_measure)
        with self.save_txt_path.open(mode="a") as f:
            f.write("%f,%f,%f\n" % (precision, recall, f_measure))


if __name__ == "__main__":
    args = parse_args()

    # Turn both path arguments into Path objects.
    for attr in ("input_path", "output_path"):
        setattr(args, attr, Path(getattr(args, attr)))

    # Weights are deserialised on the CPU and the network is moved to the
    # GPU afterwards, if requested.
    net = UNet(n_channels=1, n_classes=1)
    weights = torch.load(args.weight_path, map_location="cpu")
    net.load_state_dict(weights)
    if args.gpu:
        net.cuda()
    args.net = net

    PredictFmeasure(args).main()
Ejemplo n.º 11
0
def train(params, args, world_rank):
    """Train the UNet; rank 0 evaluates, logs and checkpoints every epoch.

    Args:
        params: run configuration. This function reads ``weight_init``,
            ``lr``, ``distributed``, ``checkpoint_path`` and ``num_epochs``
            and forwards the object to the data-loader factories,
            ``adjust_LR`` and ``UNet.loss_func``.
        args: runtime arguments; ``resuming``, ``local_rank`` and
            ``tboard_writer`` are read here.
        world_rank: rank of this process. Only rank 0 runs the evaluation,
            writes to TensorBoard and saves checkpoints.
    """
    logging.info('rank %d, begin data loader init' % world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank)
    test_data_loader = get_data_loader_distributed_test(params, world_rank)
    logging.info('rank %d, data loader initialized' % world_rank)
    model = UNet.UNet(params).cuda()
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    # For automatic mixed precision, wrap model and optimizer here with
    # amp.initialize(model, optimizer, opt_level="O1") and scale the loss
    # in the training step below.
    if params.distributed:
        model = DistributedDataParallel(model)

    iters = 0
    startEpoch = 0
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path,
                                map_location='cuda:{}'.format(args.local_rank))
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    device = torch.cuda.current_device()
    # Stays None until the first training step so the scalar logging below
    # cannot hit an unbound name when the training loader yields nothing.
    loss = None
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        tr_time = 0.
        log_time = 0.

        for data in train_data_loader:
            iters += 1
            adjust_LR(optimizer, params, iters)
            inp, tar = (x.to(device) for x in data)
            tr_start = time.time()

            model.zero_grad()
            gen = model(inp)
            loss = UNet.loss_func(gen, tar, params)

            loss.backward()  # fixed precision
            optimizer.step()

            tr_time += time.time() - tr_start

        # Output training stats
        if world_rank == 0:
            log_start = time.time()
            gens = []
            tars = []
            with torch.no_grad():
                # Evaluate on at most the first 50 test batches.
                for i, data in enumerate(test_data_loader, 0):
                    if i >= 50:
                        break
                    inp, tar = (x.to(device) for x in data)
                    gen = model(inp)
                    gens.append(gen.detach().cpu().numpy())
                    tars.append(tar.detach().cpu().numpy())
            gens = np.concatenate(gens, axis=0)
            tars = np.concatenate(tars, axis=0)

            # Scalars: loss of the last training step of this epoch.
            if loss is not None:
                args.tboard_writer.add_scalar('G_loss', loss.item(), iters)

            # Plots, tagged with the iteration count so successive epochs
            # are kept as separate steps.
            fig = plot_gens_tars(gens, tars)
            for figiter in range(5):
                figtag = 'test' + str(figiter)
                args.tboard_writer.add_figure(tag=figtag,
                                              figure=fig[figiter],
                                              global_step=iters,
                                              close=True)
            # Account for the evaluation/plotting time; previously this was
            # never accumulated and the logged value was always 0.
            log_time += time.time() - log_start

            # Save checkpoint
            torch.save(
                {
                    'iters': iters,
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info('train step time={}, logging time={}'.format(
                tr_time, log_time))
Ejemplo n.º 12
0
from pathlib import Path
from detection import TrainNet
from networks import UNet
from propagation import GuideCall

if __name__ == "__main__":
    # All CUDA work in this script runs on GPU index 1.
    torch.cuda.set_device(1)

    # NOTE(review): date, gpu, key and val_path are not used in the visible
    # part of this script -- confirm before removing them.
    date = datetime.now().date()
    gpu = True
    key = 2

    weight_path = "./weight/best.pth"
    # image_path
    train_path = Path("./images/train")
    val_path = Path("./images/val")
    guided_input_path = sorted(train_path.joinpath("ori").glob("*.tif"))

    # guided output
    output_path = Path("output")

    # define model
    net = UNet(n_channels=1, n_classes=1)
    net.cuda()

    # Deserialise on the CPU: load_state_dict then copies the values into
    # the parameters already living on the selected GPU. This works no
    # matter which device the checkpoint was saved from, whereas the former
    # {"cuda:2": "cuda:0"} mapping only handled checkpoints from cuda:2.
    net.load_state_dict(torch.load(weight_path, map_location="cpu"))

    bp = GuideCall(guided_input_path, output_path, net)
    bp.main()
Ejemplo n.º 13
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Build networks, optimizers, schedulers and fixed helper tensors.

        Args:
            hyperparameters: mapping read for the keys 'lr', 'input_dim',
                'gen' (including its 'style_dim' entry), 'beta1', 'beta2',
                'weight_decay', 'init', 'n_datasets' and 'batch_size'; it is
                also handed to ``get_scheduler`` and ``self.resume``.
            resume_epoch: any value other than -1 triggers ``self.resume``.
            snapshot_dir: forwarded to ``self.resume`` when resuming.
        """
        super(MUNIT_Trainer, self).__init__()

        learning_rate = hyperparameters['lr']

        # Networks: one generator, one discriminator on images and a second
        # four-layer discriminator with three times the input channels.
        self.gen = AdaINGen2(hyperparameters['input_dim'],
                             hyperparameters['gen'])
        self.dis = NLayerDiscriminator(hyperparameters['input_dim'])
        self.dis2 = NLayerDiscriminator(3 * hyperparameters['input_dim'],
                                        n_layers=4)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']
        self.beta1 = hyperparameters['beta1']
        self.beta2 = hyperparameters['beta2']
        self.weight_decay = hyperparameters['weight_decay']

        # Supervised UNet with three output classes, moved to the GPU.
        self.sup = UNet(input_channels=hyperparameters['input_dim'],
                        num_classes=3).cuda()

        # Style codes drawn once and reused by ``forward``.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        def build_adam(candidates):
            # Only parameters that currently require gradients are optimized.
            return torch.optim.Adam(
                [p for p in candidates if p.requires_grad],
                lr=learning_rate,
                betas=(self.beta1, self.beta2),
                weight_decay=hyperparameters['weight_decay'])

        # The generator optimizer also owns the UNet parameters.
        self.dis_opt = build_adam(list(self.dis.parameters()))
        self.dis2_opt = build_adam(list(self.dis2.parameters()))
        self.gen_opt = build_adam(
            list(self.gen.parameters()) + list(self.sup.parameters()))

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.dis2_scheduler = get_scheduler(self.dis2_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Initialise every submodule with the configured scheme, then
        # re-initialise both discriminators with the 'gaussian' scheme.
        self.apply(weights_init(hyperparameters['init']))
        for discriminator in (self.dis, self.dis2):
            discriminator.apply(weights_init('gaussian'))

        # Pre-allocated one-hot tensors: slice [i] has channel i set to 1.
        n_datasets = hyperparameters['n_datasets']
        batch_size = hyperparameters['batch_size']
        self.one_hot_img = torch.zeros(n_datasets, batch_size, n_datasets,
                                       256, 256).cuda()
        self.one_hot_c = torch.zeros(n_datasets, batch_size, n_datasets,
                                     64, 64).cuda()
        for domain in range(n_datasets):
            for one_hot in (self.one_hot_img, self.one_hot_c):
                one_hot[domain, :, domain, :, :].fill_(1)

        if resume_epoch != -1:
            self.resume(snapshot_dir, hyperparameters, resume_epoch)

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between ``input`` and ``target``."""
        absolute_error = (input - target).abs()
        return absolute_error.mean()

    def semi_criterion(self, input, target):
        """Return ``CrossEntropyLoss2d`` of ``input`` against ``target``.

        A new criterion (``size_average=True``) is created and moved to the
        GPU on every call.
        """
        criterion = CrossEntropyLoss2d(size_average=True)
        criterion = criterion.cuda()
        return criterion(input, target)

    def forward(self, x_a, x_b):

        self.eval()

        x_a.volatile = True
        x_b.volatile = True

        s_a = Variable(self.s_a, volatile=True)
        s_b = Variable(self.s_b, volatile=True)

        c_a, s_a_fake = self.gen.encode(x_a)
        c_b, s_b_fake = self.gen.encode(x_b)

        x_ba = self.gen.decode(c_b, s_a)
        x_ab = self.gen.decode(c_a, s_b)

        self.train()

        return x_ab, x_ba

    def set_gen_trainable(self, train_bool):

        if train_bool:
            self.gen.train()
            for param in self.gen.parameters():
                param.requires_grad = True

        else:
            self.gen.eval()
            for param in self.gen.parameters():
                param.requires_grad = True

    def set_sup_trainable(self, train_bool):

        if train_bool:
            self.sup.train()
            for param in self.sup.parameters():
                param.requires_grad = True
        else:
            self.sup.eval()
            for param in self.sup.parameters():
                param.requires_grad = True
##################################################################################
# Mainly adapted from https://github.com/hugo-oliveira/CoDAGANs ##################
##################################################################################

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, ep, hyperparameters):
        """Run one supervised optimisation step and return its two loss parts.

        Both batches are encoded, translated across domains and re-encoded.
        The four content codes are detached before they enter ``self.sup``,
        so this loss does not back-propagate into the generator through
        them.

        Args:
            x_a, x_b: input batches of the two domains.
            y_a: labels for ``x_a``; indexed as ``y_a[use_a, :, :]``.
            y_b, d_index_a, d_index_b, use_b: accepted but not used here.
            use_a: passed to every ``self.sup`` call and used to index
                ``y_a``.
            ep: epoch counter; the ``dis2`` generator term is only active
                once ``ep + 1 > 10``.
            hyperparameters: mapping providing 'recon_x_w' and
                'weight_temp'.

        Returns:
            Tuple ``(seg_loss, seg_gen_loss)`` of Python numbers: the
            weighted supervised loss and the weighted ``dis2`` term.
        """
        self.gen_opt.zero_grad()

        # Fresh random style codes, one per sample of each batch.
        style_a = Variable(
            torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        style_b = Variable(
            torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        content_a, _ = self.gen.encode(x_a)
        content_b, _ = self.gen.encode(x_b)

        # Cross-domain translations and their re-encoded content codes.
        fake_ba = self.gen.decode(content_b, style_a)
        fake_ab = self.gen.decode(content_a, style_b)
        content_b_cyc, _ = self.gen.encode(fake_ba)
        content_a_cyc, _ = self.gen.encode(fake_ab)

        # Cut the graph so only the supervised network receives gradients.
        content_a = content_a.detach()
        content_b = content_b.detach()
        content_b_cyc = content_b_cyc.detach()
        content_a_cyc = content_a_cyc.detach()

        # Forwarding through supervised model.
        # NOTE(review): ``use_a`` is passed for the domain-b codes as well
        # -- confirm this is intended.
        pred_a = self.sup(content_a, use_a, True)
        pred_a_cyc = self.sup(content_a_cyc, use_a, True)
        pred_b = self.sup(content_b, use_a, True)
        pred_b_cyc = self.sup(content_b_cyc, use_a, True)

        labels_a = y_a[use_a, :, :]
        loss_semi_a = (self.semi_criterion(pred_a, labels_a) +
                       self.semi_criterion(pred_a_cyc, labels_a))

        if (ep + 1) > 10:
            loss_gen_b = (self.dis2.calc_gen_loss(pred_b) +
                          self.dis2.calc_gen_loss(pred_b_cyc))
        else:
            # Constant zero placeholder during the first epochs.
            loss_gen_b = Variable(torch.tensor(0).cuda(), requires_grad=False)

        weight_temp = hyperparameters['weight_temp']
        seg_loss = hyperparameters['recon_x_w'] * loss_semi_a
        seg_gen_loss = weight_temp * loss_gen_b
        self.loss_gen_total = seg_loss + seg_gen_loss

        self.loss_gen_total.backward()
        self.gen_opt.step()
        return seg_loss.item(), seg_gen_loss.item()

    def sup_forward(self, x, y, d_index, hyperparameters):
        """Predict labels for ``x`` and score them against ``y``.

        Puts ``self.sup`` in evaluation mode, encodes ``x``, runs the
        supervised network on the content code and takes the arg-max over
        dimension 1 as the prediction.

        Args:
            x: input batch passed to ``self.gen.encode``.
            y: reference labels; moved to the CPU and squeezed on dim 0.
            d_index, hyperparameters: accepted but not used here.

        Returns:
            Tuple ``(jacc, jacc_cup, pred, content)``: the two values
            returned by ``jaccard``, the predicted label array (NumPy) and
            the content code.
        """
        self.sup.eval()

        # Encoding content image; the style code is not needed.
        content, _style = self.gen.encode(x)

        # Forwarding on supervised model.
        y_pred = self.sup(content, only_prediction=True)

        # Index of the highest-scoring class along dimension 1.
        _, labels = y_pred.data.max(1)
        pred = labels.squeeze_(1).squeeze_(0).cpu().numpy()

        # Computing metrics.
        reference = y.cpu().squeeze(0).numpy()
        jacc, jacc_cup = jaccard(pred, reference)

        return jacc, jacc_cup, pred, content

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one generator optimisation step and return the total loss.

        Every individual loss term is stored on ``self`` (``loss_gen_*``)
        before the weighted total is back-propagated.

        Args:
            x_a, x_b: input batches of the two domains.
            d_index_a, d_index_b: accepted but not used here.
            hyperparameters: mapping providing the weights 'gan_w',
                'recon_x_w', 'recon_s_w', 'recon_c_w' and 'recon_x_cyc_w'.

        Returns:
            The total generator loss as a Python number.
        """
        self.gen_opt.zero_grad()

        # Random style codes, one per sample of each batch.
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        encode, decode = self.gen.encode, self.gen.decode

        # Encode.
        c_a, s_a_prime = encode(x_a)
        c_b, s_b_prime = encode(x_b)

        # Decode (within domain).
        x_a_recon = decode(c_a, s_a_prime)
        x_b_recon = decode(c_b, s_b_prime)

        # Decode (cross domain).
        x_ba = decode(c_b, s_a)
        x_ab = decode(c_a, s_b)

        # Encode again.
        c_b_recon, s_a_recon = encode(x_ba)
        c_a_recon, s_b_recon = encode(x_ab)

        # Decode again for the cycle reconstruction.
        x_aba = decode(c_a_recon, s_a_prime)
        x_bab = decode(c_b_recon, s_b_prime)

        # Reconstruction losses (image, style, content).
        l1 = self.recon_criterion
        self.loss_gen_recon_x_a = l1(x_a_recon, x_a)
        self.loss_gen_recon_x_b = l1(x_b_recon, x_b)
        self.loss_gen_recon_s_a = l1(s_a_recon, s_a)
        self.loss_gen_recon_s_b = l1(s_b_recon, s_b)
        self.loss_gen_recon_c_a = l1(c_a_recon, c_a)
        self.loss_gen_recon_c_b = l1(c_b_recon, c_b)

        # Cycle reconstruction losses.
        self.loss_gen_cycrecon_x_a = l1(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = l1(x_bab, x_b)

        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(x_ab)

        # Total loss: weighted terms are added strictly in this order.
        weighted_terms = (
            ('gan_w', self.loss_gen_adv_a),
            ('gan_w', self.loss_gen_adv_b),
            ('recon_x_w', self.loss_gen_recon_x_a),
            ('recon_s_w', self.loss_gen_recon_s_a),
            ('recon_c_w', self.loss_gen_recon_c_a),
            ('recon_x_w', self.loss_gen_recon_x_b),
            ('recon_s_w', self.loss_gen_recon_s_b),
            ('recon_c_w', self.loss_gen_recon_c_b),
            ('recon_x_cyc_w', self.loss_gen_cycrecon_x_a),
            ('recon_x_cyc_w', self.loss_gen_cycrecon_x_b),
        )
        total = None
        for weight_key, term in weighted_terms:
            weighted = hyperparameters[weight_key] * term
            total = weighted if total is None else total + weighted
        self.loss_gen_total = total

        self.loss_gen_total.backward()
        self.gen_opt.step()
        return self.loss_gen_total.item()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean squared difference of instance-normalised features.

        Both ``img`` and ``target`` go through ``vgg_preprocess`` and then
        ``vgg``; the resulting features are normalised with
        ``self.instancenorm`` before being compared.
        """
        # Preprocess both inputs first, then extract features, image first.
        preprocessed = [vgg_preprocess(batch) for batch in (img, target)]
        img_fea, target_fea = [vgg(batch) for batch in preprocessed]

        difference = self.instancenorm(img_fea) - self.instancenorm(target_fea)
        return torch.mean(difference**2)

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one optimisation step of the image discriminator.

        Args:
            x_a, x_b: input batches of the two domains.
            d_index_a, d_index_b: accepted but not used here.
            hyperparameters: mapping providing the weight 'gan_w'.

        Returns:
            The total discriminator loss as a Python number.
        """
        self.dis_opt.zero_grad()

        # Random style codes, one per sample of each batch.
        style_a = Variable(
            torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        style_b = Variable(
            torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # Encode; only the content codes are used.
        content_a, _ = self.gen.encode(x_a)
        content_b, _ = self.gen.encode(x_b)

        # Decode (cross domain).
        fake_ba = self.gen.decode(content_b, style_a)
        fake_ab = self.gen.decode(content_a, style_b)

        # D loss: translations are detached so that the generator gets no
        # gradient from this step.
        self.loss_dis_a = self.dis.calc_dis_loss(fake_ba.detach(), x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(fake_ab.detach(), x_b)
        weighted_a = hyperparameters['gan_w'] * self.loss_dis_a
        weighted_b = hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total = weighted_a + weighted_b

        self.loss_dis_total.backward()
        self.dis_opt.step()
        return self.loss_dis_total.item()

    def dis2_update(self, x_a, x_b, d_index_a, d_index_b, use_a, use_b,
                    hyperparameters):
        """Run one optimisation step of the second discriminator.

        ``self.dis2`` is trained on the outputs of ``self.sup``: those for
        the domain-b content codes are passed as the first argument of
        ``calc_dis_loss`` and those for domain a as the second. All
        predictions are detached first.

        Args:
            x_a, x_b: input batches of the two domains.
            d_index_a, d_index_b, use_b: accepted but not used here.
            use_a: passed to every ``self.sup`` call.
            hyperparameters: mapping providing the weight 'gan_w'.

        Returns:
            The total ``dis2`` loss as a Python number.
        """
        self.dis2_opt.zero_grad()

        # Random style codes, one per sample of each batch.
        style_a = Variable(
            torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        style_b = Variable(
            torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # Encode.
        content_a, style_a_enc = self.gen.encode(x_a)
        content_b, style_b_enc = self.gen.encode(x_b)

        # Decode (within domain). The results are not used afterwards; the
        # calls are kept so the sequence of generator forward passes stays
        # the same.
        self.gen.decode(content_a, style_a_enc)
        self.gen.decode(content_b, style_b_enc)

        # Decode (cross domain).
        fake_ba = self.gen.decode(content_b, style_a)
        fake_ab = self.gen.decode(content_a, style_b)

        # Encode again.
        content_b_cyc, _ = self.gen.encode(fake_ba)
        content_a_cyc, _ = self.gen.encode(fake_ab)

        # NOTE(review): ``use_a`` is passed for the domain-b codes as well
        # -- confirm this is intended.
        pred_b = self.sup(content_b, use_a, True)
        pred_b_cyc = self.sup(content_b_cyc, use_a, True)
        pred_a = self.sup(content_a, use_a, True)
        pred_a_cyc = self.sup(content_a_cyc, use_a, True)

        direct_term = self.dis2.calc_dis_loss(pred_b.detach(),
                                              pred_a.detach())
        cycle_term = self.dis2.calc_dis_loss(pred_b_cyc.detach(),
                                             pred_a_cyc.detach())
        self.loss_dis2_b = direct_term + cycle_term
        self.loss_dis2_total = hyperparameters['gan_w'] * self.loss_dis2_b

        self.loss_dis2_total.backward()
        self.dis2_opt.step()
        return self.loss_dis2_total.item()

    def update_learning_rate(self):

        if self.dis_scheduler is not None:
            self.dis_scheduler.step()
        if self.dis2_scheduler is not None:
            self.dis2_scheduler.step()
        if self.gen_scheduler is not None:
            self.gen_scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters, resume_epoch):
        """Restore the generator and the supervised network from checkpoints.

        Only ``self.gen`` and ``self.sup`` are loaded.  Restoring the
        discriminators, optimizers and schedulers is currently disabled (the
        code is kept below, commented out), so this method returns ``None``.

        Args:
            checkpoint_dir: Directory passed to ``get_model_list``; it is
                also concatenated to a string for the log line, so it must
                be a ``str``.
            hyperparameters: Only referenced by the disabled scheduler code.
            resume_epoch: Forwarded to ``get_model_list``.
        """

        print("--> " + checkpoint_dir)

        # Load generator.
        last_model_name = get_model_list(checkpoint_dir, "gen", resume_epoch)
        # print('\n',last_model_name)
        state_dict = torch.load(last_model_name)
        self.gen.load_state_dict(state_dict)
        # Takes the 8 characters that precede the final 3 characters of the
        # path (presumably the zero-padded counter before ".pt").  The value
        # is unused while the ``return epochs`` below stays commented out.
        epochs = int(last_model_name[-11:-3])

        # Load supervised model.
        last_model_name = get_model_list(checkpoint_dir, "sup", resume_epoch)
        state_dict = torch.load(last_model_name)
        self.sup.load_state_dict(state_dict)

        # Load discriminator.
        # last_model_name = get_model_list(checkpoint_dir, "dis", resume_epoch)
        # state_dict = torch.load(last_model_name)
        # self.dis.load_state_dict(state_dict)

        # # Load discriminator2.
        # last_model_name = get_model_list(checkpoint_dir, "dis2", resume_epoch)
        # state_dict = torch.load(last_model_name)
        # self.dis2.load_state_dict(state_dict)

        # # Load optimizers.
        # last_model_name = get_model_list(checkpoint_dir, "opt", resume_epoch)
        # state_dict = torch.load(last_model_name)
        # self.dis_opt.load_state_dict(state_dict['dis'])
        # self.dis2_opt.load_state_dict(state_dict['dis2'])
        # self.gen_opt.load_state_dict(state_dict['gen'])

        # for state in self.dis_opt.state.values():
        #     for k, v in state.items():
        #         if isinstance(v, torch.Tensor):
        #             state[k] = v.cuda()

        # for state in self.dis2_opt.state.values():
        #     for k, v in state.items():
        #         if isinstance(v, torch.Tensor):
        #             state[k] = v.cuda()

        # for state in self.gen_opt.state.values():
        #     for k, v in state.items():
        #         if isinstance(v, torch.Tensor):
        #             state[k] = v.cuda()

        # # Reinitialize schedulers.
        # self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, epochs)
        # self.dis2_scheduler = get_scheduler(self.dis2_opt, hyperparameters, epochs)
        # self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, epochs)

        # print('Resume from epoch %d' % epochs)
        # return epochs

    def save(self, snapshot_dir, epoch):
        """Write network and optimizer state dicts into ``snapshot_dir``.

        Five files are produced, named ``gen_``, ``dis_``, ``dis2_``,
        ``sup_`` and ``opt_`` followed by the zero-padded ``epoch`` and
        ``.pt``.  The ``opt_`` file holds a dict with the keys ``'gen'``,
        ``'dis'`` and ``'dis2'``.

        Args:
            snapshot_dir: Target directory.
            epoch: Integer used in the file names.
        """
        network_files = (
            ('gen', 'gen_%08d.pt'),
            ('dis', 'dis_%08d.pt'),
            ('dis2', 'dis2_%08d.pt'),
            ('sup', 'sup_%08d.pt'),
        )
        for attr, pattern in network_files:
            target = os.path.join(snapshot_dir, pattern % epoch)
            torch.save(getattr(self, attr).state_dict(), target)

        optimizer_states = {
            'gen': self.gen_opt.state_dict(),
            'dis': self.dis_opt.state_dict(),
            'dis2': self.dis2_opt.state_dict()
        }
        torch.save(optimizer_states,
                   os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch))
Ejemplo n.º 14
0
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Create networks, optimizers, schedulers and one-hot buffers.

        Args:
            hyperparameters: Mapping read for ``'lr'``, ``'input_dim'``,
                ``'gen'`` (including ``'style_dim'``), ``'beta1'``,
                ``'beta2'``, ``'weight_decay'``, ``'init'``,
                ``'n_datasets'`` and ``'batch_size'``.  It is also passed
                on to ``get_scheduler`` and ``self.resume``.
            resume_epoch: ``-1`` skips resuming; any other value is handed
                to ``self.resume``.
            snapshot_dir: Checkpoint directory handed to ``self.resume``.
        """
        super(MUNIT_Trainer, self).__init__()

        hp = hyperparameters
        lr = hp['lr']
        in_channels = hp['input_dim']
        n_domains = hp['n_datasets']
        batch = hp['batch_size']

        # Networks: generator, image discriminator, second discriminator.
        self.gen = AdaINGen2(in_channels, hp['gen'])
        self.dis = NLayerDiscriminator(in_channels)
        self.dis2 = NLayerDiscriminator(3 * in_channels, n_layers=4)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']
        self.beta1 = hp['beta1']
        self.beta2 = hp['beta2']
        self.weight_decay = hp['weight_decay']

        # Supervised UNet, moved to the GPU right away.
        self.sup = UNet(input_channels=in_channels, num_classes=3).cuda()

        # Fixed style noise, drawn once.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        def build_adam(params):
            # Adam over the trainable subset of ``params``.
            return torch.optim.Adam(
                [p for p in params if p.requires_grad],
                lr=lr,
                betas=(self.beta1, self.beta2),
                weight_decay=self.weight_decay)

        self.dis_opt = build_adam(list(self.dis.parameters()))
        self.dis2_opt = build_adam(list(self.dis2.parameters()))
        # Generator and supervised network share one optimizer.
        self.gen_opt = build_adam(
            list(self.gen.parameters()) + list(self.sup.parameters()))

        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.dis2_scheduler = get_scheduler(self.dis2_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Weight initialization: whole trainer first, then both
        # discriminators are re-initialized with the 'gaussian' scheme.
        self.apply(weights_init(hp['init']))
        self.dis.apply(weights_init('gaussian'))
        self.dis2.apply(weights_init('gaussian'))

        # One-hot planes: entry [i] has channel i set to one, at image
        # resolution (256) and at code resolution (64).
        self.one_hot_img = torch.zeros(n_domains, batch, n_domains, 256,
                                       256).cuda()
        self.one_hot_c = torch.zeros(n_domains, batch, n_domains, 64,
                                     64).cuda()
        for domain in range(n_domains):
            self.one_hot_img[domain, :, domain, :, :].fill_(1)
            self.one_hot_c[domain, :, domain, :, :].fill_(1)

        if resume_epoch == -1:
            return
        self.resume(snapshot_dir, hp, resume_epoch)
Ejemplo n.º 15
0
class UNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Create networks, optimizers, schedulers and one-hot buffers.

        Args:
            hyperparameters: Mapping read for ``'lr'``, ``'input_dim'``,
                ``'n_datasets'``, ``'gen'``, ``'dis'``, ``'beta1'``,
                ``'beta2'``, ``'weight_decay'``, ``'init'`` and
                ``'batch_size'``.  It is also passed on to
                ``get_scheduler`` and ``self.resume``.
            resume_epoch: ``-1`` skips resuming; any other value triggers
                ``self.resume`` (the value itself is not forwarded).
            snapshot_dir: Checkpoint directory handed to ``self.resume``.
        """
        super(UNIT_Trainer, self).__init__()

        hp = hyperparameters
        lr = hp['lr']
        n_domains = hp['n_datasets']
        batch = hp['batch_size']
        # Images are concatenated with one plane per dataset before they
        # reach the generator or the discriminator.
        tagged_channels = hp['input_dim'] + n_domains

        self.gen = VAEGen(tagged_channels, hp['gen'], n_domains)
        self.dis = MsImageDis(tagged_channels, hp['dis'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        self.sup = UNet(input_channels=hp['input_dim'],
                        num_classes=2).cuda()

        adam_betas = (hp['beta1'], hp['beta2'])

        def build_adam(params):
            # Adam over the trainable subset of ``params``.
            return torch.optim.Adam(
                [p for p in params if p.requires_grad],
                lr=lr,
                betas=adam_betas,
                weight_decay=hp['weight_decay'])

        self.dis_opt = build_adam(list(self.dis.parameters()))
        # Generator and supervised network share one optimizer.
        self.gen_opt = build_adam(
            list(self.gen.parameters()) + list(self.sup.parameters()))
        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Weight initialization: whole trainer first, then the
        # discriminator is re-initialized with the 'gaussian' scheme.
        self.apply(weights_init(hp['init']))
        self.dis.apply(weights_init('gaussian'))

        # One-hot planes: entry [i] has channel i set to one, at image
        # resolution (256) and at hidden-code resolution (64).
        self.one_hot_img = torch.zeros(n_domains, batch, n_domains, 256,
                                       256).cuda()
        self.one_hot_h = torch.zeros(n_domains, batch, n_domains, 64,
                                     64).cuda()
        for domain in range(n_domains):
            self.one_hot_img[domain, :, domain, :, :].fill_(1)
            self.one_hot_h[domain, :, domain, :, :].fill_(1)

        if resume_epoch == -1:
            return
        self.resume(snapshot_dir, hp)

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between two tensors."""
        difference = input - target
        return difference.abs().mean()

    def semi_criterion(self, input, target):
        """Apply a CUDA ``CrossEntropyLoss2d(size_average=False)`` criterion.

        A new criterion object is built on every call.
        """
        criterion = CrossEntropyLoss2d(size_average=False)
        return criterion.cuda()(input, target)

    def forward(self, x_a, x_b, d_index_a=None, d_index_b=None):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a.

        The previous body used ``self.gen_a`` / ``self.gen_b``, which this
        trainer never creates, and the ``volatile`` flag, which is no
        longer honoured by PyTorch.  It now follows the same one-hot
        scheme as the training methods of this class and therefore needs
        the domain indices.

        Args:
            x_a: Image batch from domain a.
            x_b: Image batch from domain b.
            d_index_a: Index into the one-hot buffers for domain a.
            d_index_b: Index into the one-hot buffers for domain b.

        Returns:
            Tuple ``(x_ab, x_ba)``.

        Raises:
            ValueError: If either domain index is omitted.
        """
        if d_index_a is None or d_index_b is None:
            raise ValueError(
                'forward() requires d_index_a and d_index_b to select the '
                'one-hot domain planes')

        self.eval()
        with torch.no_grad():
            one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
            one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
            h_a, _ = self.gen.encode(one_hot_x_a)
            h_b, _ = self.gen.encode(one_hot_x_b)

            # As before, the noise returned by the encoder is not added.
            one_hot_h_ba = torch.cat([h_b, self.one_hot_h[d_index_a]], 1)
            one_hot_h_ab = torch.cat([h_a, self.one_hot_h[d_index_b]], 1)
            x_ba = self.gen.decode(one_hot_h_ba)
            x_ab = self.gen.decode(one_hot_h_ab)
        # Unchanged: the trainer is always left in training mode.
        self.train()
        return x_ab, x_ba

    def __compute_kl(self, mu):

        # def _compute_kl(self, mu, sd):
        # mu_2 = torch.pow(mu, 2)
        # sd_2 = torch.pow(sd, 2)
        # encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
        # return encoding_loss

        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def set_gen_trainable(self, train_bool):
        """Switch the generator between trainable and frozen.

        The previous ``else`` branch set ``requires_grad = True`` as well,
        so the generator parameters were never actually frozen.

        Args:
            train_bool: Truthy puts ``self.gen`` in training mode with
                gradients enabled; falsy puts it in eval mode and disables
                gradients for all of its parameters.
        """
        trainable = bool(train_bool)
        self.gen.train(trainable)
        for param in self.gen.parameters():
            param.requires_grad = trainable

    def set_sup_trainable(self, train_bool):
        """Switch the supervised network between trainable and frozen.

        The previous ``else`` branch set ``requires_grad = True`` as well,
        so the supervised parameters were never actually frozen.

        Args:
            train_bool: Truthy puts ``self.sup`` in training mode with
                gradients enabled; falsy puts it in eval mode and disables
                gradients for all of its parameters.
        """
        trainable = bool(train_bool)
        self.sup.train(trainable)
        for param in self.sup.parameters():
            param.requires_grad = trainable

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):
        """Run one supervised step on the generator/supervised optimizer.

        Each batch is encoded, translated to the other domain and encoded
        again.  ``self.sup`` is evaluated on the direct and on the
        re-encoded hidden code of every domain for which ``use_*`` selects
        at least one sample, and the ``semi_criterion`` losses are summed.
        If neither domain has a selected sample, no optimizer step is made
        and ``self.loss_gen_total`` is ``None``.

        Fixes relative to the previous version: the domain-b "recon"
        prediction is now computed from ``h_b_recon`` (it used ``h_b``
        twice), and reconstructions that did not feed the loss are no
        longer decoded.

        Args:
            x_a, x_b: Image batches of domain a / b.
            y_a, y_b: Label tensors, indexed as ``y[use, :, :]``.
            d_index_a, d_index_b: Indices into the one-hot buffers.
            use_a, use_b: Sample selectors applied to the first axis.
            hyperparameters: Not read by this method.
        """
        self.gen_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, _ = self.gen.encode(one_hot_x_ba)
        h_a_recon, _ = self.gen.encode(one_hot_x_ab)

        # Forwarding through supervised model.
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (h_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(h_a, use_a, True)
            p_a_recon = self.sup(h_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                          self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (h_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(h_b, use_b, True)
            # Fixed: use the re-encoded code, mirroring the domain-a branch.
            p_b_recon = self.sup(h_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                          self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        for loss in (loss_semi_a, loss_semi_b):
            if loss is None:
                continue
            if self.loss_gen_total is None:
                self.loss_gen_total = loss
            else:
                self.loss_gen_total = self.loss_gen_total + loss

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        """Predict labels for one image and score them with ``jaccard``.

        ``self.sup`` is switched to eval mode and left there.

        Args:
            x: Image tensor; concatenated on axis 1 with the one-hot plane
                of ``d_index`` taken from batch slot 0.
            y: Label tensor; moved to the CPU and squeezed on axis 0.
            d_index: Index into ``self.one_hot_img``.
            hyperparameters: Not read by this method.

        Returns:
            Tuple ``(jacc, pred, hidden)``: the ``jaccard`` result, the
            predicted label array and the encoder's hidden code.
        """
        self.sup.eval()

        # Tag the image with the plane of its domain and encode it.
        domain_plane = self.one_hot_img[d_index, 0].unsqueeze(0)
        hidden, _ = self.gen.encode(torch.cat([x, domain_plane], 1))

        # Prediction from the supervised model.
        y_pred = self.sup(hidden, only_prediction=True)

        # Arg-max over axis 1, then to a NumPy array.
        _, labels = y_pred.data.max(1)
        pred = labels.squeeze_(1).squeeze_(0).cpu().numpy()
        target = y.cpu().squeeze(0).numpy()

        jacc = jaccard(pred, target)

        return jacc, pred, hidden

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one generator step (reconstruction, KL, cycle and GAN terms).

        Fix relative to the previous version: when
        ``hyperparameters['recon_x_cyc_w']`` is not positive the cycle
        images were ``None`` but were still passed to
        ``recon_criterion``, which raised ``TypeError``.  In that case the
        two cycle losses are now zero tensors and the cycle decode is
        skipped.

        Args:
            x_a, x_b: Image batches of domain a / b.
            d_index_a, d_index_b: Indices into the one-hot buffers.
            hyperparameters: Mapping providing ``'gan_w'``,
                ``'recon_x_w'``, ``'recon_kl_w'``, ``'recon_x_cyc_w'``
                and ``'recon_kl_cyc_w'``.
        """
        self.gen_opt.zero_grad()

        # Encode.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(one_hot_x_a)
        h_b, n_b = self.gen.encode(one_hot_x_b)

        # Decode (within domain).
        one_hot_h_a = torch.cat([h_a + n_a, self.one_hot_h[d_index_a]], 1)
        one_hot_h_b = torch.cat([h_b + n_b, self.one_hot_h[d_index_b]], 1)
        x_a_recon = self.gen.decode(one_hot_h_a)
        x_b_recon = self.gen.decode(one_hot_h_b)

        # Decode (cross domain).
        one_hot_h_ab = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        one_hot_h_ba = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_h_ba)
        x_ab = self.gen.decode(one_hot_h_ab)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        h_b_recon, n_b_recon = self.gen.encode(one_hot_x_ba)
        h_a_recon, n_a_recon = self.gen.encode(one_hot_x_ab)

        # Cycle reconstruction, only when it carries weight.
        if hyperparameters['recon_x_cyc_w'] > 0:
            one_hot_h_a_recon = torch.cat(
                [h_a_recon + n_a_recon, self.one_hot_h[d_index_a]], 1)
            one_hot_h_b_recon = torch.cat(
                [h_b_recon + n_b_recon, self.one_hot_h[d_index_b]], 1)
            x_aba = self.gen.decode(one_hot_h_a_recon)
            x_bab = self.gen.decode(one_hot_h_b_recon)
            self.loss_gen_cyc_x_a = self.recon_criterion(x_aba, x_a)
            self.loss_gen_cyc_x_b = self.recon_criterion(x_bab, x_b)
        else:
            # Zero tensors keep the loss attributes tensor-typed.
            self.loss_gen_cyc_x_a = x_a.new_zeros(())
            self.loss_gen_cyc_x_b = x_b.new_zeros(())

        # Reconstruction loss.
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_kl_a = self.__compute_kl(h_a)
        self.loss_gen_recon_kl_b = self.__compute_kl(h_b)
        self.loss_gen_recon_kl_cyc_aba = self.__compute_kl(h_a_recon)
        self.loss_gen_recon_kl_cyc_bab = self.__compute_kl(h_b_recon)

        # GAN loss.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)

        # Total loss.
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_kl_w'] * self.loss_gen_recon_kl_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_a + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_aba + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cyc_x_b + \
                              hyperparameters['recon_kl_cyc_w'] * self.loss_gen_recon_kl_cyc_bab

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def sample(self, x_a, x_b, d_index_a=None, d_index_b=None):
        """Reconstruct and translate the given images one sample at a time.

        The previous body used ``self.gen_a`` / ``self.gen_b``, which this
        trainer never creates, and the ``volatile`` flag, which is no
        longer honoured by PyTorch.  It now follows the one-hot scheme of
        the other methods of this class and therefore needs the domain
        indices.

        Args:
            x_a: Image batch from domain a.
            x_b: Image batch from domain b; must hold at least as many
                samples as ``x_a``.
            d_index_a: Index into the one-hot buffers for domain a.
            d_index_b: Index into the one-hot buffers for domain b.

        Returns:
            Tuple ``(x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba)``.

        Raises:
            ValueError: If either domain index is omitted.
        """
        if d_index_a is None or d_index_b is None:
            raise ValueError(
                'sample() requires d_index_a and d_index_b to select the '
                'one-hot domain planes')

        self.eval()
        x_a_recon, x_b_recon, x_ba, x_ab = [], [], [], []
        with torch.no_grad():
            # Single-sample planes (batch slot 0), as used by sup_forward.
            img_a = self.one_hot_img[d_index_a, 0].unsqueeze(0)
            img_b = self.one_hot_img[d_index_b, 0].unsqueeze(0)
            hid_a = self.one_hot_h[d_index_a, 0].unsqueeze(0)
            hid_b = self.one_hot_h[d_index_b, 0].unsqueeze(0)
            for i in range(x_a.size(0)):
                h_a, _ = self.gen.encode(
                    torch.cat([x_a[i].unsqueeze(0), img_a], 1))
                h_b, _ = self.gen.encode(
                    torch.cat([x_b[i].unsqueeze(0), img_b], 1))
                x_a_recon.append(self.gen.decode(torch.cat([h_a, hid_a], 1)))
                x_b_recon.append(self.gen.decode(torch.cat([h_b, hid_b], 1)))
                x_ba.append(self.gen.decode(torch.cat([h_b, hid_a], 1)))
                x_ab.append(self.gen.decode(torch.cat([h_a, hid_b], 1)))
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba = torch.cat(x_ba)
        x_ab = torch.cat(x_ab)
        # Unchanged: the trainer is always left in training mode.
        self.train()
        return x_a, x_a_recon, x_ab, x_b, x_b_recon, x_ba

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one optimizer step for the image discriminator.

        Each batch is encoded and decoded into the other domain; the
        translated images (detached) and the input images, both tagged
        with their one-hot planes, are passed to
        ``self.dis.calc_dis_loss``.

        Args:
            x_a, x_b: Image batches of domain a / b.
            d_index_a, d_index_b: Indices into the one-hot buffers.
            hyperparameters: Mapping providing ``'gan_w'``.
        """
        self.dis_opt.zero_grad()

        # Tag the inputs with their domain planes and encode them.
        tagged_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        tagged_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
        h_a, n_a = self.gen.encode(tagged_a)
        h_b, n_b = self.gen.encode(tagged_b)

        # Cross-domain decode: code of one domain, plane of the other.
        code_a_to_b = torch.cat([h_a + n_a, self.one_hot_h[d_index_b]], 1)
        code_b_to_a = torch.cat([h_b + n_b, self.one_hot_h[d_index_a]], 1)
        x_ba = self.gen.decode(code_b_to_a)
        x_ab = self.gen.decode(code_a_to_b)

        # Discriminator losses on tagged translations vs. tagged inputs.
        tagged_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        tagged_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        self.loss_dis_a = self.dis.calc_dis_loss(tagged_ba.detach(), tagged_a)
        self.loss_dis_b = self.dis.calc_dis_loss(tagged_ab.detach(), tagged_b)

        weighted_a = hyperparameters['gan_w'] * self.loss_dis_a
        weighted_b = hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total = weighted_a + weighted_b

        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step the discriminator and generator schedulers when present.

        A scheduler attribute that is ``None`` is skipped.
        """
        for attr in ('dis_scheduler', 'gen_scheduler'):
            scheduler = getattr(self, attr)
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from checkpoints.

        Files are located through ``get_model_list`` for the keys
        ``"gen"``, ``"dis"``, ``"sup"`` and ``"opt"``.  Optimizer state
        tensors are moved to the GPU and both schedulers are rebuilt.

        Args:
            checkpoint_dir: Directory handed to ``get_model_list``.
            hyperparameters: Passed on to ``get_scheduler``.

        Returns:
            The integer parsed from the 8 characters that precede the last
            3 characters of the generator checkpoint path.
        """
        def load_latest(key):
            # Resolve the checkpoint path for ``key`` and load it.
            path = get_model_list(checkpoint_dir, key)
            return path, torch.load(path)

        # Generator first; its file name also supplies the counter.
        gen_path, gen_state = load_latest("gen")
        self.gen.load_state_dict(gen_state)
        epochs = int(gen_path[-11:-3])

        # Discriminator.
        _, dis_state = load_latest("dis")
        self.dis.load_state_dict(dis_state)

        # Supervised model.
        _, sup_state = load_latest("sup")
        self.sup.load_state_dict(sup_state)

        # Optimizers share a single file.
        _, opt_state = load_latest("opt")
        self.dis_opt.load_state_dict(opt_state['dis'])
        self.gen_opt.load_state_dict(opt_state['gen'])

        # Move every tensor in the optimizer state onto the GPU.
        for optimizer in (self.dis_opt, self.gen_opt):
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()

        # Rebuild the schedulers from the restored counter.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)

        print('Resume from iteration %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        """Write network and optimizer state dicts into ``snapshot_dir``.

        Four files are produced, named ``gen_``, ``dis_``, ``sup_`` and
        ``opt_`` followed by the zero-padded ``epoch`` and ``.pt``.  The
        ``opt_`` file holds a dict with the keys ``'dis'`` and ``'gen'``.

        Args:
            snapshot_dir: Target directory.
            epoch: Integer used in the file names.
        """
        network_files = (
            ('gen', 'gen_%08d.pt'),
            ('dis', 'dis_%08d.pt'),
            ('sup', 'sup_%08d.pt'),
        )
        for attr, pattern in network_files:
            target = os.path.join(snapshot_dir, pattern % epoch)
            torch.save(getattr(self, attr).state_dict(), target)

        optimizer_states = {
            'dis': self.dis_opt.state_dict(),
            'gen': self.gen_opt.state_dict()
        }
        torch.save(optimizer_states,
                   os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch))
Ejemplo n.º 16
0
class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters, resume_epoch=-1, snapshot_dir=None):
        """Create networks, optimizers, schedulers and one-hot buffers.

        Args:
            hyperparameters: Mapping read for ``'lr'``, ``'input_dim'``,
                ``'n_datasets'``, ``'gen'`` (including ``'style_dim'``),
                ``'dis'``, ``'beta1'``, ``'beta2'``, ``'weight_decay'``,
                ``'init'`` and ``'batch_size'``.  It is also passed on to
                ``get_scheduler`` and ``self.resume``.
            resume_epoch: ``-1`` skips resuming; any other value triggers
                ``self.resume`` (the value itself is not forwarded).
            snapshot_dir: Checkpoint directory handed to ``self.resume``.
        """
        super(MUNIT_Trainer, self).__init__()

        hp = hyperparameters
        lr = hp['lr']
        n_domains = hp['n_datasets']
        batch = hp['batch_size']
        # Images are concatenated with one plane per dataset before they
        # reach the generator or the discriminator.
        tagged_channels = hp['input_dim'] + n_domains

        self.gen = AdaINGen(tagged_channels, hp['gen'], n_domains)
        self.dis = MsImageDis(tagged_channels, hp['dis'])

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hp['gen']['style_dim']
        self.beta1 = hp['beta1']
        self.beta2 = hp['beta2']
        self.weight_decay = hp['weight_decay']

        # Supervised UNet, moved to the GPU right away.
        self.sup = UNet(input_channels=hp['input_dim'],
                        num_classes=2).cuda()

        # Fixed style noise, drawn once.
        self.s_a = torch.randn(8, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(8, self.style_dim, 1, 1).cuda()

        def build_adam(params):
            # Adam over the trainable subset of ``params``.
            return torch.optim.Adam(
                [p for p in params if p.requires_grad],
                lr=lr,
                betas=(self.beta1, self.beta2),
                weight_decay=self.weight_decay)

        self.dis_opt = build_adam(list(self.dis.parameters()))
        # Generator and supervised network share one optimizer.
        self.gen_opt = build_adam(
            list(self.gen.parameters()) + list(self.sup.parameters()))

        self.dis_scheduler = get_scheduler(self.dis_opt, hp)
        self.gen_scheduler = get_scheduler(self.gen_opt, hp)

        # Weight initialization: whole trainer first, then the
        # discriminator is re-initialized with the 'gaussian' scheme.
        self.apply(weights_init(hp['init']))
        self.dis.apply(weights_init('gaussian'))

        # One-hot planes: entry [i] has channel i set to one, at image
        # resolution (256) and at content-code resolution (64).
        self.one_hot_img = torch.zeros(n_domains, batch, n_domains, 256,
                                       256).cuda()
        self.one_hot_c = torch.zeros(n_domains, batch, n_domains, 64,
                                     64).cuda()
        for domain in range(n_domains):
            self.one_hot_img[domain, :, domain, :, :].fill_(1)
            self.one_hot_c[domain, :, domain, :, :].fill_(1)

        if resume_epoch == -1:
            return
        self.resume(snapshot_dir, hp)

    def recon_criterion(self, input, target):
        """Return the mean absolute difference between two tensors."""
        absolute_error = (input - target).abs()
        return absolute_error.mean()

    def semi_criterion(self, input, target):
        """Apply a CUDA ``CrossEntropyLoss2d(size_average=False)`` criterion.

        A new criterion object is built on every call.
        """
        criterion = CrossEntropyLoss2d(size_average=False)
        return criterion.cuda()(input, target)

    def forward(self, x_a, x_b, d_index_a=None, d_index_b=None):
        """Translate ``x_a`` to domain b and ``x_b`` to domain a.

        The previous body read ``d_index_a`` / ``d_index_b`` without
        receiving them (a ``NameError`` unless such globals existed) and
        relied on the ``volatile`` flag, which is no longer honoured by
        PyTorch.  The indices are now optional parameters and the pass
        runs under ``torch.no_grad()``.

        The fixed style codes ``self.s_a`` / ``self.s_b`` are used for
        decoding.

        Args:
            x_a: Image batch from domain a.
            x_b: Image batch from domain b.
            d_index_a: Index into the one-hot buffers for domain a.
            d_index_b: Index into the one-hot buffers for domain b.

        Returns:
            Tuple ``(x_ab, x_ba)``.

        Raises:
            ValueError: If either domain index is omitted.
        """
        if d_index_a is None or d_index_b is None:
            raise ValueError(
                'forward() requires d_index_a and d_index_b to select the '
                'one-hot domain planes')

        self.eval()

        with torch.no_grad():
            one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
            one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)
            # The encoder's own style outputs are not used here.
            c_a, _ = self.gen.encode(one_hot_x_a)
            c_b, _ = self.gen.encode(one_hot_x_b)

            one_hot_c_b = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
            one_hot_c_a = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
            x_ba = self.gen.decode(one_hot_c_b, self.s_a)
            x_ab = self.gen.decode(one_hot_c_a, self.s_b)

        # Unchanged: the trainer is always left in training mode.
        self.train()

        return x_ab, x_ba

    def set_gen_trainable(self, train_bool):
        """Switch the generator between trainable and frozen.

        The previous ``else`` branch set ``requires_grad = True`` as well,
        so the generator parameters were never actually frozen.

        Args:
            train_bool: Truthy puts ``self.gen`` in training mode with
                gradients enabled; falsy puts it in eval mode and disables
                gradients for all of its parameters.
        """
        trainable = bool(train_bool)
        self.gen.train(trainable)
        for param in self.gen.parameters():
            param.requires_grad = trainable

    def set_sup_trainable(self, train_bool):
        """Switch the supervised network between trainable and frozen.

        The previous ``else`` branch set ``requires_grad = True`` as well,
        so the supervised parameters were never actually frozen.

        Args:
            train_bool: Truthy puts ``self.sup`` in training mode with
                gradients enabled; falsy puts it in eval mode and disables
                gradients for all of its parameters.
        """
        trainable = bool(train_bool)
        self.sup.train(trainable)
        for param in self.sup.parameters():
            param.requires_grad = trainable

    def sup_update(self, x_a, x_b, y_a, y_b, d_index_a, d_index_b, use_a,
                   use_b, hyperparameters):
        """Run one supervised step on the generator/supervised optimizer.

        Each batch is encoded, translated to the other domain with freshly
        sampled style noise and encoded again.  ``self.sup`` is evaluated
        on the direct and on the re-encoded content code of every domain
        for which ``use_*`` selects at least one sample, and the
        ``semi_criterion`` losses are summed.  If neither domain has a
        selected sample, no optimizer step is made and
        ``self.loss_gen_total`` is ``None``.

        Fixes relative to the previous version: the domain-b "recon"
        prediction is now computed from ``c_b_recon`` (it used ``c_b``
        twice), the within-domain reconstructions that did not feed the
        loss are no longer decoded, and the style noise is created on the
        device of the inputs instead of being hard-wired to CUDA.

        Args:
            x_a, x_b: Image batches of domain a / b.
            y_a, y_b: Label tensors, indexed as ``y[use, :, :]``.
            d_index_a, d_index_b: Indices into the one-hot buffers.
            use_a, use_b: Sample selectors applied to the first axis.
            hyperparameters: Not read by this method.
        """
        self.gen_opt.zero_grad()

        s_a = torch.randn(x_a.size(0), self.style_dim, 1, 1,
                          device=x_a.device)
        s_b = torch.randn(x_b.size(0), self.style_dim, 1, 1,
                          device=x_b.device)

        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)

        # Encode.
        c_a, _ = self.gen.encode(one_hot_x_a)
        c_b, _ = self.gen.encode(one_hot_x_b)

        # Decode (cross domain).
        one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)
        one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
        x_ba = self.gen.decode(one_hot_c_ba, s_a)
        x_ab = self.gen.decode(one_hot_c_ab, s_b)

        # Encode again.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)
        c_b_recon, _ = self.gen.encode(one_hot_x_ba)
        c_a_recon, _ = self.gen.encode(one_hot_x_ab)

        # Forwarding through supervised model.
        loss_semi_a = None
        loss_semi_b = None

        has_a_label = (c_a[use_a, :, :, :].size(0) != 0)
        if has_a_label:
            p_a = self.sup(c_a, use_a, True)
            p_a_recon = self.sup(c_a_recon, use_a, True)
            loss_semi_a = self.semi_criterion(p_a, y_a[use_a, :, :]) + \
                          self.semi_criterion(p_a_recon, y_a[use_a, :, :])

        has_b_label = (c_b[use_b, :, :, :].size(0) != 0)
        if has_b_label:
            p_b = self.sup(c_b, use_b, True)
            # Fixed: use the re-encoded code, mirroring the domain-a branch.
            p_b_recon = self.sup(c_b_recon, use_b, True)
            loss_semi_b = self.semi_criterion(p_b, y_b[use_b, :, :]) + \
                          self.semi_criterion(p_b_recon, y_b[use_b, :, :])

        self.loss_gen_total = None
        for loss in (loss_semi_a, loss_semi_b):
            if loss is None:
                continue
            if self.loss_gen_total is None:
                self.loss_gen_total = loss
            else:
                self.loss_gen_total = self.loss_gen_total + loss

        if self.loss_gen_total is not None:
            self.loss_gen_total.backward()
            self.gen_opt.step()

    def sup_forward(self, x, y, d_index, hyperparameters):
        """Predict labels for one image and score them with ``jaccard``.

        ``self.sup`` is switched to eval mode and left there.

        Args:
            x: Image tensor; concatenated on axis 1 with the one-hot plane
                of ``d_index`` taken from batch slot 0.
            y: Label tensor; moved to the CPU and squeezed on axis 0.
            d_index: Index into ``self.one_hot_img``.
            hyperparameters: Not read by this method.

        Returns:
            Tuple ``(jacc, pred, content)``: the ``jaccard`` result, the
            predicted label array and the encoder's content code.
        """
        self.sup.eval()

        # Tag the image with the plane of its domain and encode it.
        domain_plane = self.one_hot_img[d_index, 0].unsqueeze(0)
        content, _ = self.gen.encode(torch.cat([x, domain_plane], 1))

        # Prediction from the supervised model.
        y_pred = self.sup(content, only_prediction=True)

        # Arg-max over axis 1, then to a NumPy array.
        _, labels = y_pred.data.max(1)
        pred = labels.squeeze_(1).squeeze_(0).cpu().numpy()
        target = y.cpu().squeeze(0).numpy()

        jacc = jaccard(pred, target)

        return jacc, pred, content

    def gen_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one generator step (reconstruction, cycle and GAN terms).

        The batches are encoded, decoded within and across domains,
        re-encoded and decoded once more with their own style codes.  Image,
        style, content and cycle reconstruction losses plus the
        discriminator's generator loss are combined using the weights in
        ``hyperparameters``.

        Args:
            x_a, x_b: Image batches of domain a / b.
            d_index_a, d_index_b: Indices into the one-hot buffers.
            hyperparameters: Mapping providing ``'gan_w'``,
                ``'recon_x_w'``, ``'recon_s_w'``, ``'recon_c_w'`` and
                ``'recon_x_cyc_w'``.
        """
        def tag_image(image, d_index):
            # Append the image-resolution planes of domain ``d_index``.
            return torch.cat([image, self.one_hot_img[d_index]], 1)

        def tag_content(code, d_index):
            # Append the code-resolution planes of domain ``d_index``.
            return torch.cat([code, self.one_hot_c[d_index]], 1)

        self.gen_opt.zero_grad()

        # Random style codes for the cross-domain translations.
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # Encode.
        c_a, s_a_prime = self.gen.encode(tag_image(x_a, d_index_a))
        c_b, s_b_prime = self.gen.encode(tag_image(x_b, d_index_b))

        # Decode (within domain).
        x_a_recon = self.gen.decode(tag_content(c_a, d_index_a), s_a_prime)
        x_b_recon = self.gen.decode(tag_content(c_b, d_index_b), s_b_prime)

        # Decode (cross domain).
        x_ba = self.gen.decode(tag_content(c_b, d_index_a), s_a)
        x_ab = self.gen.decode(tag_content(c_a, d_index_b), s_b)

        # Encode again.
        one_hot_x_ba = tag_image(x_ba, d_index_a)
        one_hot_x_ab = tag_image(x_ab, d_index_b)
        c_b_recon, s_a_recon = self.gen.encode(one_hot_x_ba)
        c_a_recon, s_b_recon = self.gen.encode(one_hot_x_ab)

        # Decode again with the styles extracted from the inputs.
        x_aba = self.gen.decode(tag_content(c_a_recon, d_index_a), s_a_prime)
        x_bab = self.gen.decode(tag_content(c_b_recon, d_index_b), s_b_prime)

        # Reconstruction losses.
        l1 = self.recon_criterion
        self.loss_gen_recon_x_a = l1(x_a_recon, x_a)
        self.loss_gen_recon_x_b = l1(x_b_recon, x_b)
        self.loss_gen_recon_s_a = l1(s_a_recon, s_a)
        self.loss_gen_recon_s_b = l1(s_b_recon, s_b)
        self.loss_gen_recon_c_a = l1(c_a_recon, c_a)
        self.loss_gen_recon_c_b = l1(c_b_recon, c_b)

        self.loss_gen_cycrecon_x_a = l1(x_aba, x_a)
        self.loss_gen_cycrecon_x_b = l1(x_bab, x_b)

        # GAN losses.
        self.loss_gen_adv_a = self.dis.calc_gen_loss(one_hot_x_ba)
        self.loss_gen_adv_b = self.dis.calc_gen_loss(one_hot_x_ab)

        # Weighted sum, accumulated left to right in the listed order.
        weighted_terms = (
            ('gan_w', self.loss_gen_adv_a),
            ('gan_w', self.loss_gen_adv_b),
            ('recon_x_w', self.loss_gen_recon_x_a),
            ('recon_s_w', self.loss_gen_recon_s_a),
            ('recon_c_w', self.loss_gen_recon_c_a),
            ('recon_x_w', self.loss_gen_recon_x_b),
            ('recon_s_w', self.loss_gen_recon_s_b),
            ('recon_c_w', self.loss_gen_recon_c_b),
            ('recon_x_cyc_w', self.loss_gen_cycrecon_x_a),
            ('recon_x_cyc_w', self.loss_gen_cycrecon_x_b),
        )
        total = None
        for weight_key, term in weighted_terms:
            contribution = hyperparameters[weight_key] * term
            total = contribution if total is None else total + contribution
        self.loss_gen_total = total

        self.loss_gen_total.backward()
        self.gen_opt.step()

    def compute_vgg_loss(self, vgg, img, target):
        """Return the mean squared difference between the instance-normalised
        ``vgg`` features of ``img`` and of ``target``.

        Both inputs go through ``vgg_preprocess`` before being fed to ``vgg``.
        """
        preprocessed = [vgg_preprocess(batch) for batch in (img, target)]
        img_fea, target_fea = [vgg(batch) for batch in preprocessed]
        difference = self.instancenorm(img_fea) - self.instancenorm(target_fea)
        return torch.mean(difference**2)

    def sample(self, x_a, x_b):
        """Produce reconstructions and translations of two batches for display.

        For every index ``i`` of ``x_a`` the method encodes ``x_a[i]`` and
        ``x_b[i]``, reconstructs each with its own encoded style, and
        translates each with a stored style (``self.s_a`` / ``self.s_b``)
        and with a freshly drawn random style.

        The forward passes run under ``torch.no_grad()``.  The previous
        implementation relied on ``volatile=True``, which has had no effect
        since PyTorch 0.4, so autograd graphs were being built during
        sampling.  The module is put into eval mode for the duration of the
        call and is always switched back to train mode afterwards, even if
        a forward pass raises.

        Args:
            x_a, x_b: input batches; ``x_b`` must hold at least as many
                samples as ``x_a`` because both are indexed with the same
                ``i``.

        Returns:
            Tuple ``(x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1,
            x_ba2)``.
        """
        # NOTE(review): unlike gen_update/dis_update, the decode calls below
        # receive the raw content code without the self.one_hot_c planes, and
        # self.s_a / self.s_b / self.one_hot_img_a / self.one_hot_img_b are
        # not assigned in the visible part of __init__ -- confirm this method
        # is still in sync with the rest of the class.
        self.eval()
        try:
            with torch.no_grad():
                s_a1, s_b1 = self.s_a, self.s_b
                s_a2 = torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda()
                s_b2 = torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda()

                x_a_recon, x_b_recon = [], []
                x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], []
                for i in range(x_a.size(0)):
                    one_hot_x_a = torch.cat(
                        [x_a[i].unsqueeze(0),
                         self.one_hot_img_a[i].unsqueeze(0)], 1)
                    one_hot_x_b = torch.cat(
                        [x_b[i].unsqueeze(0),
                         self.one_hot_img_b[i].unsqueeze(0)], 1)

                    c_a, s_a_fake = self.gen.encode(one_hot_x_a)
                    c_b, s_b_fake = self.gen.encode(one_hot_x_b)
                    x_a_recon.append(self.gen.decode(c_a, s_a_fake))
                    x_b_recon.append(self.gen.decode(c_b, s_b_fake))
                    x_ba1.append(self.gen.decode(c_b, s_a1[i].unsqueeze(0)))
                    x_ba2.append(self.gen.decode(c_b, s_a2[i].unsqueeze(0)))
                    x_ab1.append(self.gen.decode(c_a, s_b1[i].unsqueeze(0)))
                    x_ab2.append(self.gen.decode(c_a, s_b2[i].unsqueeze(0)))

                x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
                x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
                x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        finally:
            # Restore training mode no matter how sampling ended.
            self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

    def dis_update(self, x_a, x_b, d_index_a, d_index_b, hyperparameters):
        """Run one optimiser step for the discriminator (``self.dis_opt``).

        Both batches are translated to the other domain with a random style;
        the discriminator loss is then computed between each translated batch
        and the real batch carrying the same domain label.  The two losses
        are stored in ``self.loss_dis_a`` / ``self.loss_dis_b``, weighted by
        ``hyperparameters['gan_w']`` and back-propagated.

        The generator forward pass runs under ``torch.no_grad()``: only the
        discriminator is optimised here, so building a graph through the
        generator just spent memory and made ``backward()`` also fill the
        generator's gradients, which ``gen_update`` zeroes before using.

        Args:
            x_a, x_b: input batches of the two domains.
            d_index_a, d_index_b: indices used to pick the one-hot planes
                from ``self.one_hot_img`` and ``self.one_hot_c``.
            hyperparameters: mapping providing the loss weight ``gan_w``.
        """
        self.dis_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # Real inputs with their domain label planes attached.
        one_hot_x_a = torch.cat([x_a, self.one_hot_img[d_index_a]], 1)
        one_hot_x_b = torch.cat([x_b, self.one_hot_img[d_index_b]], 1)

        # Generate the fakes without tracking gradients through the generator.
        with torch.no_grad():
            # Encode.
            c_a, _ = self.gen.encode(one_hot_x_a)
            c_b, _ = self.gen.encode(one_hot_x_b)

            one_hot_c_ba = torch.cat([c_b, self.one_hot_c[d_index_a]], 1)
            one_hot_c_ab = torch.cat([c_a, self.one_hot_c[d_index_b]], 1)

            # Decode (cross domain).
            x_ba = self.gen.decode(one_hot_c_ba, s_a)
            x_ab = self.gen.decode(one_hot_c_ab, s_b)

        # D loss.
        one_hot_x_ba = torch.cat([x_ba, self.one_hot_img[d_index_a]], 1)
        one_hot_x_ab = torch.cat([x_ab, self.one_hot_img[d_index_b]], 1)

        self.loss_dis_a = self.dis.calc_dis_loss(one_hot_x_ba, one_hot_x_a)
        self.loss_dis_b = self.dis.calc_dis_loss(one_hot_x_ab, one_hot_x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + \
                              hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Advance the discriminator scheduler, then the generator scheduler.

        A scheduler that is ``None`` is skipped.
        """
        for scheduler in (self.dis_scheduler, self.gen_scheduler):
            if scheduler is None:
                continue
            scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimisers and schedulers from ``checkpoint_dir``.

        For each of the prefixes ``gen``, ``sup``, ``dis`` and ``opt`` the
        file returned by ``get_model_list`` is loaded.  The epoch number is
        parsed from the generator file name, the optimiser state tensors are
        moved to CUDA, and both schedulers are rebuilt starting from that
        epoch.

        Args:
            checkpoint_dir: directory holding the snapshot files.
            hyperparameters: configuration forwarded to ``get_scheduler``.

        Returns:
            The epoch number parsed from the generator snapshot name.
        """
        print("--> " + checkpoint_dir)

        def load_latest(prefix):
            # Resolve the snapshot for `prefix` and load it.
            path = get_model_list(checkpoint_dir, prefix)
            return path, torch.load(path)

        # Generator; its file name also carries the epoch counter: the eight
        # characters in front of the three-character ".pt" suffix.
        gen_path, gen_state = load_latest("gen")
        self.gen.load_state_dict(gen_state)
        epochs = int(gen_path[-11:-3])

        # Supervised model.
        _, sup_state = load_latest("sup")
        self.sup.load_state_dict(sup_state)

        # Discriminator.
        _, dis_state = load_latest("dis")
        self.dis.load_state_dict(dis_state)

        # Optimisers share a single snapshot file.
        _, opt_state = load_latest("opt")
        self.dis_opt.load_state_dict(opt_state['dis'])
        self.gen_opt.load_state_dict(opt_state['gen'])

        # Move every tensor held in the optimiser state onto the GPU.
        for optimizer in (self.dis_opt, self.gen_opt):
            for slot in optimizer.state.values():
                for key, value in slot.items():
                    if isinstance(value, torch.Tensor):
                        slot[key] = value.cuda()

        # Reinitialize schedulers so they continue from the restored epoch.
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters,
                                           epochs)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters,
                                           epochs)

        print('Resume from epoch %d' % epochs)
        return epochs

    def save(self, snapshot_dir, epoch):
        """Write generator, discriminator, supervised-model and optimiser
        snapshots for ``epoch`` into ``snapshot_dir``.

        Files are named ``gen_``, ``dis_``, ``sup_`` and ``opt_`` followed by
        the zero-padded eight-digit epoch and ``.pt``.  The optimiser file
        holds a dict with the keys ``'gen'`` and ``'dis'``.
        """
        networks = (
            ('gen_%08d.pt', self.gen),
            ('dis_%08d.pt', self.dis),
            ('sup_%08d.pt', self.sup),
        )
        for pattern, network in networks:
            target = os.path.join(snapshot_dir, pattern % epoch)
            torch.save(network.state_dict(), target)

        optimizer_states = {
            'gen': self.gen_opt.state_dict(),
            'dis': self.dis_opt.state_dict(),
        }
        torch.save(optimizer_states,
                   os.path.join(snapshot_dir, 'opt_%08d.pt' % epoch))
    # Fix the random seeds (NumPy and PyTorch) so runs are repeatable.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    #======================================================================
    # Load or generate the dataset
    # Data preprocessing
    #======================================================================
    # The "val" split is loaded and iterated in order (shuffle=False).
    ds_test = Map2AerialDataset( args.dataset_dir, "val", args.image_size, args.image_size, args.debug )
    dloader_test = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size, shuffle=False )

    #======================================================================
    # Define the model architecture.
    #======================================================================
    model = UNet( 
        n_in_channels = 3, n_out_channels = 3,
        n_fmaps = args.n_fmaps,
    ).to( device )
        
    if( args.debug ):
        print( "model :\n", model )

    # Load the model weights from "model_final.pth" when a checkpoint
    # directory was given and exists.
    # NOTE(review): init_step is only bound inside this branch and is not
    # used in the visible part of the script -- confirm before relying on it.
    if not args.load_checkpoints_dir == '' and os.path.exists(args.load_checkpoints_dir):
        init_step = load_checkpoint(model, device, os.path.join(args.load_checkpoints_dir, "model_final.pth") )
    
    #======================================================================
    # Model inference
    #======================================================================
    print("Starting Test Loop...")
    n_print = 1
    model.eval()
Ejemplo n.º 18
0
def test_unet(root,
              psf_path,
              method,
              std,
              model_path,
              visual,
              use_gpu,
              b_size=1):
    """Evaluate a saved UNet on the Gaussian-noise test set.

    For every batch the PSNR and SSIM of both the distorted input and the
    network output are measured against the ground truth (each clamped to
    the ground-truth value range).  The averages are printed at the end.

    Args:
        root: dataset root, forwarded to ``CellDataset``.
        psf_path: PSF location, forwarded to ``CellDataset``.
        method: model identifier; the checkpoint file is expected at
            ``model_path/<method>_gaussian``.
        std: noise standard deviation, forwarded to ``CellDataset`` and used
            in the result directory name and in the printed summary.
        model_path: directory containing the checkpoint.
        visual: when ``1``, the first output of every batch is written as a
            PNG under ``./Results/<method>_gaussian_std_<std>/``.
        use_gpu: when ``1``, model and data are moved to CUDA.
        b_size: batch size of the data loader.

    Returns:
        None.
    """
    model_name = method + '_gaussian'
    save_images_path = './Results/' + model_name + '_std_' + str(std).replace(
        '.', '') + '/'

    test_dataset = CellDataset(root, psf_path, 'gaussian', 1.0, std)
    test_loader = DataLoader(test_dataset,
                             batch_size=b_size,
                             shuffle=False,
                             num_workers=1)

    model = UNet(mode='batch')

    # When not running on the GPU, map the checkpoint to the CPU so that a
    # CUDA-saved file can still be loaded on a machine without CUDA.
    checkpoint = torch.load(os.path.join(model_path, model_name),
                            map_location=None if use_gpu == 1 else 'cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    if use_gpu == 1:
        model.cuda()

    if visual == 1:
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []

    distorted_psnr_test = []
    distorted_ssim_test = []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # The per-batch std is bound to its own name so that it no longer
        # shadows the `std` argument used in the final summary.
        for (gt, image), _psf, _index, image_name, _, batch_std in tqdm(
                test_loader):

            # Use the actual batch length: the last batch may hold fewer
            # than b_size samples.
            n_samples = image.shape[0]
            image = image.reshape(
                (n_samples, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((n_samples, 1, gt.shape[-2], gt.shape[-1]))

            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()

            gt_min, gt_max = gt.min(), gt.max()

            distorted_psnr = calc_psnr(image.clamp(gt_min, gt_max), gt)
            distorted_ssim = ssim(image.clamp(gt_min, gt_max), gt)

            output = model(image)

            psnr_test = calc_psnr(output.clamp(gt_min, gt_max), gt)
            s_sim_test = ssim(output.clamp(gt_min, gt_max), gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())

            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            # Save the first image of the batch.
            if visual == 1:
                # First std of the batch, matching image_name[0]/output[0].
                std_tag = str(batch_std.reshape(-1)[0].item()).replace(
                    '.', '')
                file_name = ('output_' + str(image_name[0][:-4]) + '_' +
                             str(model_name) + '_' + std_tag + '.png')
                io.imsave(
                    os.path.join(save_images_path, file_name),
                    np.uint8(
                        output[0][0].detach().cpu().numpy().clip(0, 1) *
                        255.))

    print('Test on Gaussian noise with %.3f std: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f' % (std, np.array(psnr_values_test).mean(), \
                                                                                              np.array(ssim_values_test).mean(), \
                                                                                              np.array(distorted_psnr_test).mean(), \
                                                                                              np.array(distorted_ssim_test).mean()))
Ejemplo n.º 19
0
    # Read image and mask stacks, append a trailing axis of length one and
    # scale the values by 1/255.
    imgs = tiff.imread(trn_imgs_path)[:, :, :, np.newaxis] / 255.
    msks = tiff.imread(trn_msks_path)[:, :, :, np.newaxis] / 255.

    # Normalize images (in place) using the mean and standard deviation of
    # the whole stack; the masks are left as they are.
    imgs -= np.mean(imgs)
    imgs /= np.std(imgs)

    # Train/val split: the first 20 entries train, the remainder validate.
    imgs_trn, msks_trn = imgs[:20, ...], msks[:20, ...]
    imgs_val, msks_val = imgs[20:, ...], msks[20:, ...]
    data = (imgs_trn, msks_trn, imgs_val, msks_val)

    # Network and training parameters.
    input_shape = (128, 128, 1)
    total_iters = 10000
    iters_trn = 500
    iters_val = 100
    batch = 8
    # Number of epochs follows from the iteration budget (10000 // 500 = 20).
    epochs = total_iters // iters_trn
    # NOTE(review): alpha0/alpha1 and alpha_switch_epoch are only forwarded
    # to train_adversarial; presumably a weight that switches value at that
    # epoch -- confirm against train_adversarial.
    alpha0 = 0.1
    alpha1 = 1
    alpha_switch_epoch = 3

    # Networks: S is the UNet, D is a classifier built on S's output shape
    # without its first dimension.
    S = UNet(input_shape)
    D = ConvNetClassifier(S.output_shape[1:])

    # Training.
    train_adversarial(S, D, *data, iters_trn, iters_val, epochs, batch, alpha0,
                      alpha1, alpha_switch_epoch)
Ejemplo n.º 20
0
def train(params, args, world_rank, local_rank):
    """Train the UNet on DALI-fed data and log per-epoch timing on rank 0.

    Sets up the DALI pipeline and iterator, the model (optionally wrapped in
    DDP), a FusedAdam optimiser, the CosmoLoss criterion and, when
    ``args.enable_amp`` is set, a gradient scaler.  When ``args.resuming``
    is set, model/optimiser state and the iteration/epoch counters are
    restored from ``params.checkpoint_path``.  It then runs
    ``params.num_epochs`` epochs, annotating every step with NVTX ranges and
    accumulating wall-clock timings.

    Args:
        params: run configuration; the attributes read here are
            ``num_data_workers``, ``Nsamples``, ``weight_init``, ``lr``,
            ``distributed``, ``LAMBDA_2``, ``checkpoint_path`` and
            ``num_epochs``.
        args: run-time flags; the attributes read here are ``resuming``,
            ``enable_amp``, ``global_timing`` and ``io_only``.
        world_rank: rank used to decide which process logs.
        local_rank: index of the CUDA device used by this process.

    Returns:
        None.  No checkpoint is written by this function.
    """

    #logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params,
                           num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples,
                                            auto_reset=True)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params).to(device)

    # Fresh runs get their weights initialised; resumed runs load them below.
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model,
                    device_ids=[device.index],
                    output_device=device.index)

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # amp stuff
    # gscaler only exists when AMP is enabled; it is only used under the
    # same flag further down.
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        # NOTE(review): when params.distributed is set, `model` is already
        # the DDP wrapper here, so the checkpoint keys must match the
        # wrapper's state_dict -- confirm against the code that saves it.
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    with torch.autograd.profiler.emit_nvtx():
        for epoch in range(startEpoch, startEpoch + params.num_epochs):

            # Optional barrier so all ranks start timing the epoch together.
            if args.global_timing:
                dist.barrier()

            start = time.time()
            epoch_step = 0
            tr_time = 0.
            fw_time = 0.
            bw_time = 0.
            log_time = 0.

            model.train()
            for data in train_data_loader:
                # Outer NVTX range covering the whole step; it is popped at
                # the end of the loop body.
                torch.cuda.nvtx.range_push("cosmo3D:step {}".format(iters))
                tr_start = time.time()
                adjust_LR(optimizer, params, iters)

                # fetch data
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                # With io_only set, only data loading is exercised.
                if not args.io_only:
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:forward {}".format(iters))
                    # fw pass
                    # Timers accumulate via "-= start" / "+= end".
                    # NOTE(review): these are host wall-clock times taken
                    # without an explicit CUDA synchronisation, so they may
                    # not reflect GPU execution time -- confirm.
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:backward {}".format(iters))
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                iters += 1
                epoch_step += 1

                # step done
                tr_end = time.time()
                tr_time += tr_end - tr_start
                torch.cuda.nvtx.range_pop()

            # epoch done
            if args.global_timing:
                dist.barrier()

            end = time.time()
            epoch_time = end - start
            # Turn the accumulated totals into per-step averages.
            # NOTE(review): raises ZeroDivisionError if the loader yielded
            # no batch in this epoch (epoch_step == 0).
            step_time = epoch_time / float(epoch_step)
            tr_time /= float(epoch_step)
            fw_time /= float(epoch_step)
            bw_time /= float(epoch_step)
            # NOTE(review): io_time is computed but not reported below, and
            # log_time is never increased from 0.
            io_time = max([step_time - fw_time - bw_time, 0])
            iters_per_sec = 1. / step_time
            fw_per_sec = 1. / tr_time

            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, epoch_time))
                logging.info(
                    'train step time = {} ({} steps), logging time = {}'.
                    format(tr_time, epoch_step, log_time))
                # NOTE(review): the value logged as "train samples/sec" is
                # iters_per_sec, i.e. steps per second -- confirm the label.
                logging.info('train samples/sec = {} fw steps/sec = {}'.format(
                    iters_per_sec, fw_per_sec))
Ejemplo n.º 21
0
def train(params, args, world_rank, local_rank):
    """Train the UNet on DALI-fed data and log per-epoch timing on rank 0.

    Sets up the DALI pipeline and iterator, the model (optionally wrapped in
    DDP), a FusedAdam optimiser, the CosmoLoss criterion and, when
    ``args.enable_amp`` is set, a gradient scaler.  When ``args.resuming``
    is set, model/optimiser state and the iteration/epoch counters are
    restored from ``params.checkpoint_path``.  It then runs
    ``params.num_epochs`` epochs with NVTX ranges around the forward and
    backward passes and logs averaged timings.

    Args:
        params: run configuration; the attributes read here are
            ``num_data_workers``, ``Nsamples``, ``weight_init``, ``lr``,
            ``distributed``, ``LAMBDA_2``, ``checkpoint_path`` and
            ``num_epochs``.
        args: run-time flags; the attributes read here are ``resuming``,
            ``enable_amp`` and ``io_only``.
        world_rank: rank used to decide which process logs.
        local_rank: index of the CUDA device used by this process.

    Returns:
        None.  The statistics and checkpoint code is commented out, so
        nothing is written to disk by this function.
    """

    #logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params,
                           num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples,
                                            auto_reset=True)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params)
    model.to(device)
    # Fresh runs get their weights initialised; resumed runs load them below.
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[local_rank])

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # amp stuff
    # gscaler only exists when AMP is enabled; it is only used under the
    # same flag further down.
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        # NOTE(review): when params.distributed is set, `model` is already
        # the DDP wrapper here, so the checkpoint keys must match the
        # wrapper's state_dict -- confirm against the code that saves it.
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        nsteps = 0
        fw_time = 0.
        bw_time = 0.
        log_time = 0.

        model.train()
        # step_time first holds the start timestamp and is turned into the
        # average time per step once the epoch is done.
        step_time = time.time()
        #for i, data in enumerate(train_data_loader, 0):
        with torch.autograd.profiler.emit_nvtx():
            for data in train_data_loader:
                iters += 1

                # The learning-rate adjustment is disabled in this version.
                #adjust_LR(optimizer, params, iters)
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                # With io_only set, only data loading is exercised.
                if not args.io_only:
                    torch.cuda.nvtx.range_push("cosmo3D:forward")
                    # fw pass
                    # Timers accumulate via "-= start" / "+= end".
                    # NOTE(review): these are host wall-clock times taken
                    # without an explicit CUDA synchronisation, so they may
                    # not reflect GPU execution time -- confirm.
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push("cosmo3D:backward")
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                nsteps += 1

            # epoch done
            # NOTE(review): this barrier (and the one at the end of the
            # epoch body) runs regardless of params.distributed -- confirm
            # the process group is always initialised by the caller.
            dist.barrier()
            # NOTE(review): raises ZeroDivisionError if the loader yielded
            # no batch in this epoch (nsteps == 0).
            step_time = (time.time() - step_time) / float(nsteps)
            fw_time /= float(nsteps)
            bw_time /= float(nsteps)
            # Time per step not covered by the forward/backward timers,
            # floored at zero.
            io_time = max([step_time - fw_time - bw_time, 0])
            iters_per_sec = 1. / step_time

            end = time.time()
            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, end - start))
                # log_time is never increased from 0 because the statistics
                # block below is commented out.
                logging.info(
                    'total time / step = {}, fw time / step = {}, bw time / step = {}, exposed io time / step = {}, iters/s = {}, logging time = {}'
                    .format(step_time, fw_time, bw_time, io_time,
                            iters_per_sec, log_time))

        # Disabled: evaluation statistics, TensorBoard output and checkpoint
        # saving.  Kept commented out exactly as found.
        ## Output training stats
        #model.eval()
        #if world_rank==0:
        #  log_start = time.time()
        #  gens = []
        #  tars = []
        #  with torch.no_grad():
        #    for i, data in enumerate(train_data_loader, 0):
        #      if i>=16:
        #        break
        #      #inp, tar = map(lambda x: x.to(device), data)
        #      inp, tar = data
        #      gen = model(inp)
        #      gens.append(gen.detach().cpu().numpy())
        #      tars.append(tar.detach().cpu().numpy())
        #  gens = np.concatenate(gens, axis=0)
        #  tars = np.concatenate(tars, axis=0)
        #
        #  # Scalars
        #  args.tboard_writer.add_scalar('G_loss', loss.item(), iters)
        #
        #  # Plots
        #  fig, chi, L1score = meanL1(gens, tars)
        #  args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
        #  args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
        #  args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
        #  args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
        #  args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
        #  args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
        #  args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
        #
        #  fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
        #  args.tboard_writer.add_figure('genimg', fig, iters, close=True)
        #  log_end = time.time()
        #  log_time += log_end - log_start

        #  # Save checkpoint
        #  torch.save({'iters': iters, 'epoch':epoch, 'model_state': model.state_dict(),
        #              'optimizer_state_dict': optimizer.state_dict()}, params.checkpoint_path)

        #end = time.time()
        #if world_rank==0:
        #    logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, end-start))
        #    logging.info('total time / step = {}, fw time / step = {}, bw time / step = {}, exposed io time / step = {}, iters/s = {}, logging time = {}'
        #                 .format(step_time, fw_time, bw_time, io_time, iters_per_sec, log_time))

        # finalize
        dist.barrier()
Ejemplo n.º 22
0
# Input data and output locations, both relative to the parent directory.
BASE_DIR = os.path.join(os.pardir, "data", "test")
SAVE_DIR = os.path.join(os.pardir, "results")
test_dataset = SimulationDataset(base_dir=BASE_DIR)

# NOTE(review): tqdm wraps enumerate(), which has no length, so the progress
# bar cannot display a total.
for i, data_sample in tqdm(enumerate(test_dataset)):

    # Only the first 21 samples (i = 0..20) are processed.
    if i > 20:
        break

    geometry_array = data_sample["geometry"]
    flow_array = data_sample["flow"]

    # Load model and make prediction based on geometry
    # NOTE(review): the model is rebuilt and its checkpoint re-read on every
    # iteration although neither depends on the sample; this could be done
    # once before the loop.
    path_to_model = os.path.join(os.pardir, "models", "model_checkpoint.pt")
    model = UNet()
    model.load_state_dict(torch.load(path_to_model))
    model.eval()

    # Add a leading batch dimension before calling the model.
    geometry = torch.from_numpy(geometry_array)
    geometry = geometry.unsqueeze(0)
    prediction = model(geometry)

    # Postprocessing
    # Move the first axis of each array to the end (axis order 1, 2, 0).
    geometry = np.transpose(geometry_array, [1, 2, 0])

    # Drop the batch dimension, reorder the axes the same way and convert
    # the prediction to a NumPy array.
    prediction = prediction.squeeze(0)
    prediction = prediction.permute(1, 2, 0)
    prediction = prediction.detach().numpy()
    flow_array = np.transpose(flow_array, (1, 2, 0))