Example #1
0
def train(params, n_epochs, verbose=True):
    # init interpolation model
    timestamp = int(time.time())
    formatted_params = '_'.join(f'{k}={v}' for k, v in params.items())

    torch.manual_seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    random.seed(FLAGS.seed)

    if FLAGS.filename is None:
        G = SepConvNetExtended(kl_size=params['kl_size'],
                               kq_size=params['kq_size'],
                               kl_d_size=params['kl_d_size'],
                               kl_d_scale=params['kl_d_scale'],
                               kq_d_scale=params['kq_d_scale'],
                               kq_d_size=params['kq_d_size'],
                               input_frames=params['input_size'])

        if params['pretrain'] in [1, 2]:
            print('LOADING L1')
            G.load_weights('l1')
        name = f'{timestamp}_seed_{FLAGS.seed}_{formatted_params}'
        G = torch.nn.DataParallel(G).cuda()

        # optimizer = torch.optim.Adamax(G.parameters(), lr=params['lr'], betas=(.9, .999))
        if params['optimizer'] == 'ranger':
            optimizer = Ranger([{
                'params':
                [p for l, p in G.named_parameters() if 'moduleConv' not in l]
            }, {
                'params':
                [p for l, p in G.named_parameters() if 'moduleConv' in l],
                'lr':
                params['lr2']
            }],
                               lr=params['lr'],
                               betas=(.95, .999))

        elif params['optimizer'] == 'adamax':
            optimizer = torch.optim.Adamax([{
                'params':
                [p for l, p in G.named_parameters() if 'moduleConv' not in l]
            }, {
                'params':
                [p for l, p in G.named_parameters() if 'moduleConv' in l],
                'lr':
                params['lr2']
            }],
                                           lr=params['lr'],
                                           betas=(.9, .999))

        else:
            raise NotImplementedError()

        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=60 - FLAGS.warmup, T_mult=1, eta_min=1e-5)

        start_epoch = 0

    else:
        checkpoint = torch.load(FLAGS.filename)
        G = checkpoint['last_model'].cuda()
        start_epoch = checkpoint['epoch'] + 1
        name = checkpoint['name']

        optimizer = torch.optim.Adamax([{
            'params':
            [p for l, p in G.named_parameters() if 'moduleConv' not in l]
        }, {
            'params':
            [p for l, p in G.named_parameters() if 'moduleConv' in l],
            'lr':
            params['lr2']
        }],
                                       lr=params['lr'],
                                       betas=(.9, .999))

        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=n_epochs -
                                                               FLAGS.warmup,
                                                               eta_min=1e-5,
                                                               last_epoch=-1)

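        # the scheduler is rebuilt from scratch on resume, so fast-forward it
        # to the restored epoch (warmup epochs never step it)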
        for _ in range(start_epoch - FLAGS.warmup + 1):
            scheduler.step()

    print('SETTINGS:')
    print(params)
    print('NAME:')
    print(name)
    sys.stdout.flush()

    # loss_network = losses.LossNetwork(layers=[9,16,26]).cuda() #9, 16, 26
    # Perc_loss = losses.PerceptualLoss(loss_network, include_input=True)

    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

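    # a four-frame input presumably selects the quadratic motion model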
    quadratic = params['input_size'] == 4

    L1_loss = torch.nn.L1Loss().cuda()
    # Flow_loss = losses.FlowLoss(quadratic=quadratic).cuda()

    ds_train_lmd = dataloader.large_motion_dataset(quadratic=quadratic,
                                                   cropped=True,
                                                   fold='train',
                                                   min_flow=6)
    ds_valid_lmd = dataloader.large_motion_dataset(quadratic=quadratic,
                                                   cropped=True,
                                                   fold='valid')

    ds_vimeo_train = dataloader.vimeo90k_dataset(fold='train',
                                                 quadratic=quadratic)
    ds_vimeo_test = dataloader.vimeo90k_dataset(fold='test',
                                                quadratic=quadratic)

    # train, test_lmd = dataloader.split_data(ds_lmd, [.9, .1])
    train_vimeo, valid_vimeo = dataloader.split_data(ds_vimeo_train, [.9, .1])

    # torch.manual_seed(FLAGS.seed)
    # np.random.seed(FLAGS.seed)
    # random.seed(FLAGS.seed)

    train_settings = {
        'flip_probs': FLAGS.flip_probs,
        'normalize': True,
        'crop_size': (FLAGS.crop_size, FLAGS.crop_size),
        'jitter_prob': FLAGS.jitter_prob,
        'random_rescale_prob': FLAGS.random_rescale_prob
        # 'rescale_distr':(.8, 1.2),
    }

    valid_settings = {
        'flip_probs': 0,
        'random_rescale_prob': 0,
        'random_crop': False,
        'normalize': True
    }

    train_lmd = dataloader.TransformedDataset(ds_train_lmd, **train_settings)
    valid_lmd = dataloader.TransformedDataset(ds_valid_lmd, **valid_settings)

    train_vimeo = dataloader.TransformedDataset(train_vimeo, **train_settings)
    valid_vimeo = dataloader.TransformedDataset(valid_vimeo, **valid_settings)
    test_vimeo = dataloader.TransformedDataset(ds_vimeo_test, **valid_settings)

    train_data = torch.utils.data.ConcatDataset([train_lmd, train_vimeo])

    # displacement
    df = pd.read_csv('hardinstancesinfo/vimeo90k_test_flow.csv')
    test_disp = torch.utils.data.Subset(
        ds_vimeo_test,
        indices=df[df.mean_manh_flow >= df.quantile(.9).mean_manh_flow].index.
        tolist())
    test_disp = dataloader.TransformedDataset(test_disp, **valid_settings)
    test_disp = torch.utils.data.DataLoader(test_disp,
                                            batch_size=4,
                                            pin_memory=True)

    # nonlinearity
    df = pd.read_csv('hardinstancesinfo/Vimeo90K_test.csv')
    test_nonlin = torch.utils.data.Subset(
        ds_vimeo_test,
        indices=df[
            df.non_linearity >= df.quantile(.9).non_linearity].index.tolist())
    test_nonlin = dataloader.TransformedDataset(test_nonlin, **valid_settings)
    test_nonlin = torch.utils.data.DataLoader(test_nonlin,
                                              batch_size=4,
                                              pin_memory=True)

    # create weights for train sampler
    df_vim = pd.read_csv('hardinstancesinfo/vimeo90k_train_flow.csv')
    weights_vim = df_vim[df_vim.index.isin(
        train_vimeo.dataset.indices)].mean_manh_flow.tolist()
    weights_lmd = ds_train_lmd.weights
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
        weights_lmd + weights_vim, FLAGS.num_train_samples, replacement=False)

    train_dl = torch.utils.data.DataLoader(train_data,
                                           batch_size=FLAGS.batch_size,
                                           pin_memory=True,
                                           shuffle=False,
                                           sampler=train_sampler,
                                           num_workers=FLAGS.num_workers)
    valid_dl_vim = torch.utils.data.DataLoader(valid_vimeo,
                                               batch_size=4,
                                               pin_memory=True,
                                               num_workers=FLAGS.num_workers)
    valid_dl_lmd = torch.utils.data.DataLoader(valid_lmd,
                                               batch_size=4,
                                               pin_memory=True,
                                               num_workers=FLAGS.num_workers)
    test_dl_vim = torch.utils.data.DataLoader(test_vimeo,
                                              batch_size=4,
                                              pin_memory=True,
                                              num_workers=FLAGS.num_workers)

    # metrics
    writer = SummaryWriter(f'runs/final_exp/full_run_losses/{name}')

    results = ResultStore(writer=writer,
                          metrics=['psnr', 'ssim', 'ie', 'L1_loss', 'lf'],
                          folds=FOLDS)

    early_stopping_metric = 'L1_loss'
    early_stopping = EarlyStopping(results,
                                   patience=FLAGS.patience,
                                   metric=early_stopping_metric,
                                   fold='valid_vimeo')

    loss_network = losses.LossNetwork(layers=[26]).cuda()  #9, 16, 26
    Perc_loss = losses.PerceptualLoss(loss_network).cuda()

    def do_epoch(dataloader, fold, epoch, train=False):
        assert fold in FOLDS

        if verbose:
            pb = tqdm(desc=f'{fold} {epoch+1}/{n_epochs}',
                      total=len(dataloader),
                      leave=True,
                      position=0)

        for i, (X, y) in enumerate(dataloader):
            X = X.cuda()
            y = y.cuda()

            y_hat = G(X)

            l1_loss = L1_loss(y_hat, y)
            feature_loss = Perc_loss(y_hat, y)

            lf_loss = l1_loss + feature_loss

            if train:
                optimizer.zero_grad()
                lf_loss.backward()
                optimizer.step()

            # compute metrics
            y_hat = (y_hat * 255).clamp(0, 255)
            y = (y * 255).clamp(0, 255)

            psnr = metrics.psnr(y_hat, y)
            ssim = metrics.ssim(y_hat, y)
            ie = metrics.interpolation_error(y_hat, y)

            results.store(
                fold, epoch, {
                    'L1_loss': l1_loss.item(),
                    'psnr': psnr,
                    'ssim': ssim,
                    'ie': ie,
                    'lf': lf_loss.item()
                })

            if verbose: pb.update()

        if verbose: pb.close()

        # update tensorboard
        results.write_tensorboard(fold, epoch)
        sys.stdout.flush()

    start_time = time.time()
    for epoch in range(start_epoch, n_epochs):

        G.train()
        do_epoch(train_dl, 'train_fold', epoch, train=True)

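        # hold the learning rate flat during warmup; step the scheduler afterwards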
        if epoch >= FLAGS.warmup - 1:
            scheduler.step()

        G.eval()
        with torch.no_grad():
            do_epoch(valid_dl_vim, 'valid_vimeo', epoch)
            do_epoch(valid_dl_lmd, 'valid_lmd', epoch)

        if ((early_stopping.stop() and epoch >= FLAGS.min_epochs)
                or epoch % FLAGS.test_every == 0 or epoch + 1 == n_epochs):
            with torch.no_grad():
                do_epoch(test_disp, 'test_disp', epoch)
                do_epoch(test_nonlin, 'test_nonlin', epoch)

                do_epoch(test_dl_vim, 'test_vimeo', epoch)

            visual_evaluation(model=G,
                              quadratic=params['input_size'] == 4,
                              writer=writer,
                              epoch=epoch)

            visual_evaluation_vimeo(model=G,
                                    quadratic=params['input_size'] == 4,
                                    writer=writer,
                                    epoch=epoch)

        # save model if new best
        filepath_out = os.path.join(MODEL_FOLDER, '{0}_{1}')
        if early_stopping.new_best():
            torch.save(G, filepath_out.format('generator', name))

        # save last model state
        checkpoint = {
            'last_model': G,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'name': name,
            'scheduler': scheduler
        }
        torch.save(checkpoint, filepath_out.format('checkpoint', name))

        if early_stopping.stop() and epoch >= FLAGS.min_epochs:
            break

        torch.cuda.empty_cache()

    end_time = time.time()
    # free memory
    del G
    torch.cuda.empty_cache()
    time_elapsed = end_time - start_time
    print(f'Ran {n_epochs} epochs in {round(time_elapsed, 1)} seconds')

    return results
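The optimizer blocks above repeat one pattern three times: a separate, lower learning rate for the `moduleConv` layers via two parameter groups. A minimal self-contained sketch of that pattern (the toy model and the rates here are illustrative, not taken from the run above):

import torch

# toy model with a 'moduleConv' submodule, mirroring the name filter used above
model = torch.nn.Sequential()
model.add_module('moduleConv', torch.nn.Conv2d(3, 8, 3, padding=1))
model.add_module('head', torch.nn.Conv2d(8, 3, 1))

# one parameter group per learning rate; a group without an explicit 'lr'
# falls back to the optimizer-level default
optimizer = torch.optim.Adamax(
    [{'params': [p for n, p in model.named_parameters() if 'moduleConv' not in n]},
     {'params': [p for n, p in model.named_parameters() if 'moduleConv' in n],
      'lr': 1e-4}],
    lr=1e-3, betas=(.9, .999))

for group in optimizer.param_groups:
    print(group['lr'])  # 0.001 (default group), then 0.0001 (moduleConv group)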
Example #2
0
def train(epochs):
    imTransform = transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainLoader = helper.PreprocessInFocusOnlyData(imTransform, rootDirTrain,
                                                   batchSize, True)
    validLoader = helper.PreprocessInFocusOnlyData(imTransform, rootDirValid,
                                                   batchSize, True)
    slices = ('outputSlice1', 'outputSlice4', 'outputSlice7', 'outputSlice10')
    numSlices = len(slices)

    rootDirToSave = rootDirSave
    model1SaveDir = rootDirToSave + 'model1/'
    model2SaveDir = rootDirToSave + 'model2/'
    model3SaveDir = rootDirToSave + 'model3/'
    extension = '.pth'

    loadRootDir = rootDirSave
    model1LoadPath = loadRootDir + 'model1/'
    model2LoadPath = loadRootDir + 'model2/'
    model3LoadPath = loadRootDir + 'model3/'

    model1 = model.EncoderDecoder(3, 6 * numSlices).to(device)
    model2 = model.DepthToInFocus(3).to(device)
    model3 = model.EncoderDecoderv2(3, 1).to(device)

    model1.load_state_dict(
        torch.load(model1LoadPath + str(loadEpochNum) + extension))
    model2.load_state_dict(
        torch.load(model2LoadPath + str(loadEpochNum) + extension))
    model3.load_state_dict(
        torch.load(model3LoadPath + str(loadEpochNum) + extension))

    contentLoss = losses.PerceptualLoss()
    contentLoss.initialize(nn.MSELoss())

    trainLossArray = np.zeros(epochs)
    validLossArray = np.zeros(epochs)

    lr = learningRate
    minValidEpoch = 0
    minValidLoss = float('inf')

    # build the optimizer once so Adam's per-parameter moment estimates
    # persist across steps instead of being reset on every batch
    optimizer = optim.Adam(list(model1.parameters()) +
                           list(model2.parameters()) +
                           list(model3.parameters()),
                           lr=lr, betas=(beta1, beta2), eps=1e-08,
                           weight_decay=0, amsgrad=False)

    for epoch in range(epochs):
        start = time.perf_counter()

        epochLoss = np.zeros(2)
        epochCount = epoch + 1

        print("Epoch num : " + str(epochCount))
        if epoch >= 20 and epoch % 20 == 0:
            lr = lr / 2
            for group in optimizer.param_groups:
                group['lr'] = lr
            print('Learning rate changed to ' + str(lr))

        for step, sample in enumerate(trainLoader):

            loss = torch.zeros(1).to(device)
            loss.requires_grad = False

            data = sample.to(device)

            blurOut = model1(data)

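            # model1 appears to return one blur map per scale, each half the
            # resolution of the last (hence the 1/2**dispIndex resize below)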
            for dispIndex in range(len(blurOut)):

                scaleLoss = torch.zeros(1).to(device)
                scaleLoss.requires_grad = False

                dispOut = blurOut[dispIndex]

                inFocus = F.interpolate(data,
                                        size=None,
                                        scale_factor=1 / pow(2.0, dispIndex),
                                        mode='bilinear')

                blurMapSlices = helper.BlurMapsToSlices(dispOut, numSlices, 6)
                inFocusRegionSlices = []

                for sliceIdx in range(numSlices):

                    nearFocusOutSlice = model2(inFocus, blurMapSlices[sliceIdx])
                    binaryBlurMapSlice = model3(nearFocusOutSlice)

                    inFocusRegionSlice = torch.mul(nearFocusOutSlice,
                                                   binaryBlurMapSlice)
                    inFocusRegionSlices.append(inFocusRegionSlice)

                inFocusOutput = helper.CombineFocusRegionsToAIF(
                    inFocusRegionSlices, numSlices, device)

                perceptLoss = contentLoss.get_loss(inFocus, inFocusOutput)
                simiLoss = losses.similarityLoss(inFocus, inFocusOutput, alpha,
                                                 winSize)
                scaleLoss += perceptLossW * perceptLoss + simiLossW * simiLoss

                loss += scaleLoss

            epochLoss[0] += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 500 == 0:
                print("Step num :" + str(step) + " Train Loss :",
                      str(loss.item()))

        with torch.no_grad():
            for validStep, validSample in enumerate(validLoader):

                validLoss = torch.zeros(1).to(device)
                validLoss.requires_grad = False

                validData = validSample.to(device)

                validBlurOut = model1(validData)

                for validDispIndex in range(len(validBlurOut)):

                    validScaleLoss = torch.zeros(1).to(device)
                    validScaleLoss.requires_grad = False

                    validDispOut = validBlurOut[validDispIndex]

                    validInFocus = F.interpolate(validData,
                                                 size=None,
                                                 scale_factor=1 /
                                                 pow(2.0, validDispIndex),
                                                 mode='bilinear')

                    validBlurMapSlices = helper.BlurMapsToSlices(
                        validDispOut, numSlices, 6)
                    validInFocusRegionSlices = []

                    for validSlices in range(numSlices):

                        validNearFocusOutSlice = model2(
                            validInFocus, validBlurMapSlices[validSlices])
                        validBinaryBlurMapSlice = model3(
                            validNearFocusOutSlice)
                        validInFocusRegionSlice = torch.mul(
                            validNearFocusOutSlice, validBinaryBlurMapSlice)
                        validInFocusRegionSlices.append(
                            validInFocusRegionSlice)

                    validInFocusOutput = helper.CombineFocusRegionsToAIF(
                        validInFocusRegionSlices, numSlices, device)

                    validPerceptLoss = contentLoss.get_loss(
                        validInFocus, validInFocusOutput)
                    validSimiLoss = losses.similarityLoss(
                        validInFocus, validInFocusOutput, alpha, winSize)
                    validScaleLoss += perceptLossW * validPerceptLoss + simiLossW * validSimiLoss

                    validLoss += validScaleLoss

                epochLoss[1] += validLoss.item()

        epochLoss[0] /= (step + 1)
        epochLoss[1] /= (validStep + 1)
        epochFreq = epochSaveFreq
        print("Time taken for epoch " + str(epoch) + " is " +
              str(time.perf_counter() - start))

        if epochLoss[1] < minValidLoss:
            minValidLoss = epochLoss[1]
            minValidEpoch = epoch
            torch.save(model1.state_dict(),
                       model1SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model2.state_dict(),
                       model2SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model3.state_dict(),
                       model3SaveDir + str(epoch) + 'generic_' + extension)
            print("Saving latest model at epoch: " + str(epoch))

        if epochCount % epochFreq == 0:
            torch.save(model1.state_dict(),
                       model1SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model2.state_dict(),
                       model2SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model3.state_dict(),
                       model3SaveDir + str(epoch) + 'generic_' + extension)

        print("Training loss for epoch " + str(epochCount) + " : " +
              str(epochLoss[0]))
        print("Validation loss for epoch " + str(epochCount) + " : " +
              str(epochLoss[1]))

        trainLossArray[epoch] = epochLoss[0]
        validLossArray[epoch] = epochLoss[1]

    print('Best validation epoch: ' + str(minValidEpoch))
    return (trainLossArray, validLossArray)
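Adam is stateful: it tracks running first and second moments for every parameter, which is why the optimizer above is built once before the epoch loop instead of once per step. A minimal sketch of where that state lives (toy values, not from the code above):

import torch

w = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([w], lr=0.1)
print(len(opt.state))  # 0: no moment estimates exist yet

loss = (w - 1.0).pow(2).sum()
loss.backward()
opt.step()
print(len(opt.state))  # 1: first and second moments are now tracked for w
# a freshly constructed Adam starts again from empty state, discarding the
# adaptive step sizes accumulated so far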
Example #3
0
def create_computation_graph(x_in,
                             x_gt,
                             x_app=None,
                             arch_type='pggan',
                             use_appearance=True):
    """Create the models and the losses.

  Args:
    x_in: 4D tensor, batch of conditional input images in NHWC format.
    x_gt: 2D tensor, batch ground-truth images in NHWC format.
    x_app: 4D tensor, batch of input appearance images.

  Returns:
    Dictionary of placeholders and TF graph functions.
  """
    # ---------------------------------------------------------------------------
    # Build models/networks
    # ---------------------------------------------------------------------------

    rerenderer = networks.RenderingModel(arch_type, use_appearance)
    app_enc = rerenderer.get_appearance_encoder()
    discriminator = networks.MultiScaleDiscriminator('d_model',
                                                     opts.appearance_nc,
                                                     num_scales=3,
                                                     nf=64,
                                                     n_layers=3,
                                                     get_fmaps=False)

    # ---------------------------------------------------------------------------
    # Forward pass
    # ---------------------------------------------------------------------------

    if use_appearance:
        z_app, _, _ = app_enc(x_app)
    else:
        z_app = None

    y = rerenderer(x_in, z_app)

    # ---------------------------------------------------------------------------
    # Losses
    # ---------------------------------------------------------------------------

    w_loss_gan = opts.w_loss_gan
    w_loss_recon = opts.w_loss_vgg if opts.use_vgg_loss else opts.w_loss_l1

    # compute discriminator logits
    disc_real_featmaps = discriminator(x_gt, x_in)
    disc_fake_featmaps = discriminator(y, x_in)

    # discriminator loss
    loss_d_real = losses.multiscale_discriminator_loss(disc_real_featmaps,
                                                       True)
    loss_d_fake = losses.multiscale_discriminator_loss(disc_fake_featmaps,
                                                       False)
    loss_d = loss_d_real + loss_d_fake

    # generator loss
    loss_g_gan = losses.multiscale_discriminator_loss(disc_fake_featmaps, True)
    if opts.use_vgg_loss:
        vgg_layers = ['conv%d_2' % i
                      for i in range(1, 6)]  # conv1 through conv5
        vgg_layer_weights = [1. / 32, 1. / 16, 1. / 8, 1. / 4, 1.]
        vgg_loss = losses.PerceptualLoss(
            y, x_gt, [256, 256, 3], vgg_layers,
            vgg_layer_weights)  # NOTE: shouldn't hardcode image size!
        loss_g_recon = vgg_loss()
    else:
        loss_g_recon = losses.L1_loss(y, x_gt)
    loss_g = w_loss_gan * loss_g_gan + w_loss_recon * loss_g_recon

    # ---------------------------------------------------------------------------
    # Tensorboard visualizations
    # ---------------------------------------------------------------------------

    x_in_render = tf.slice(x_in, [0, 0, 0, 0], [-1, -1, -1, 3])
    if opts.use_semantic:
        x_in_semantic = tf.slice(x_in, [0, 0, 0, 4], [-1, -1, -1, 3])
        tb_visualization = tf.concat([x_in_render, x_in_semantic, y, x_gt],
                                     axis=2)
    else:
        tb_visualization = tf.concat([x_in_render, y, x_gt], axis=2)
    tf.summary.image('rendered-semantic-generated-gt tuple', tb_visualization)

    # Show input appearance images
    if opts.use_appearance:
        x_app_rgb = tf.slice(x_app, [0, 0, 0, 0], [-1, -1, -1, 3])
        x_app_sem = tf.slice(x_app, [0, 0, 0, 7], [-1, -1, -1, -1])
        tb_app_visualization = tf.concat([x_app_rgb, x_app_sem], axis=2)
        tf.summary.image('input appearance image', tb_app_visualization)

    # Loss summaries
    with tf.name_scope('Discriminator_Loss'):
        tf.summary.scalar('D_real_loss', loss_d_real)
        tf.summary.scalar('D_fake_loss', loss_d_fake)
        tf.summary.scalar('D_total_loss', loss_d)
    with tf.name_scope('Generator_Loss'):
        tf.summary.scalar('G_GAN_loss', w_loss_gan * loss_g_gan)
        tf.summary.scalar('G_reconstruction_loss', w_loss_recon * loss_g_recon)
        tf.summary.scalar('G_total_loss', loss_g)

    # ---------------------------------------------------------------------------
    # Optimizers
    # ---------------------------------------------------------------------------

    def get_optimizer(lr, loss, var_list):
        optimizer = tf.train.AdamOptimizer(lr, opts.adam_beta1,
                                           opts.adam_beta2)
        # optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
        return optimizer.minimize(loss, var_list=var_list)

    # Training ops.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        with tf.variable_scope('optimizers'):
            d_vars = utils.model_vars('d_model')[0]
            g_vars_all = utils.model_vars('g_model')[0]
            train_d = [get_optimizer(opts.d_lr, loss_d, d_vars)]
            train_g = [get_optimizer(opts.g_lr, loss_g, g_vars_all)]

            train_app_encoder = []
            if opts.train_app_encoder:
                lr_app = opts.ez_lr
                app_enc_vars = utils.model_vars('appearance_net')[0]
                train_app_encoder.append(
                    get_optimizer(lr_app, loss_g, app_enc_vars))

    ema = tf.train.ExponentialMovingAverage(decay=0.999)
    with tf.control_dependencies(train_g + train_app_encoder):
        inference_vars_all = g_vars_all
        if opts.use_appearance:
            app_enc_vars = utils.model_vars('appearance_net')[0]
            inference_vars_all += app_enc_vars
        ema_op = ema.apply(inference_vars_all)

    print('***************************************************')
    print('len(g_vars_all) = %d' % len(g_vars_all))
    for ii, v in enumerate(g_vars_all):
        print('%03d) %s' % (ii, str(v)))
    print('-------------------------------------------------------')
    print('len(d_vars) = %d' % len(d_vars))
    for ii, v in enumerate(d_vars):
        print('%03d) %s' % (ii, str(v)))
    if opts.train_app_encoder:
        print('-------------------------------------------------------')
        print('len(app_enc_vars) = %d' % len(app_enc_vars))
        for ii, v in enumerate(app_enc_vars):
            print('%03d) %s' % (ii, str(v)))
    print('***************************************************\n\n')

    return {
        'train_disc_op': tf.group(train_d),
        'train_renderer_op': ema_op,
        'total_loss_d': loss_d,
        'loss_d_real': loss_d_real,
        'loss_d_fake': loss_d_fake,
        'loss_g_gan': w_loss_gan * loss_g_gan,
        'loss_g_recon': w_loss_recon * loss_g_recon,
        'total_loss_g': loss_g
    }
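The EMA op above keeps shadow copies of the inference variables, nudged toward the live weights after every generator update with decay 0.999. A framework-free sketch of that update rule (the numbers are illustrative):

DECAY = 0.999

def ema_update(shadow, value, decay=DECAY):
    # shadow <- decay * shadow + (1 - decay) * value, the rule applied by
    # tf.train.ExponentialMovingAverage to each tracked variable
    return decay * shadow + (1.0 - decay) * value

shadow, weight = 0.0, 1.0
for _ in range(3):
    shadow = ema_update(shadow, weight)
print(shadow)  # ~0.002997: the shadow copy drifts slowly toward the weight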
Example #4
0
def train(epochs) :
	imTransform = transforms.Compose([
		torchvision.transforms.ToTensor(),
		torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
	])
	sliceNames = ('outputSlice1', 'outputSlice10')
	numSlices = len(sliceNames)

	trainLoaderLF = helper.PreprocessData(imTransform, sliceNames, rootDirTrainLF, 'input', 'segments', batchSize, True)
	validLoaderLF = helper.PreprocessData(imTransform, sliceNames, rootDirValidLF, 'input', 'segments', batchSize, True)

	trainLoaderReal = helper.PreprocessInFocusOnlyData(imTransform, rootDirTrainReal, 'inputs', 'masks', batchSize, False)
	validLoaderReal = helper.PreprocessInFocusOnlyData(imTransform, rootDirValidReal, 'inputs', 'masks', batchSize, False)



	rootDirToSave	= rootDirSave
	model1SaveDir 	= rootDirToSave + 'model1/'
	model2SaveDir 	= rootDirToSave + 'model2/'
	model3SaveDir 	= rootDirToSave + 'model3/'
	extension		= '.pth'

	model1 = model.EncoderDecoder(4, 6 * numSlices).to(device)
	model2 = model.DepthToInFocus(3).to(device)
	model3 = model.EncoderDecoderv2(3, 1).to(device)
	
	contentLoss = losses.PerceptualLoss()
	contentLoss.initialize(nn.MSELoss())


	trainLossArray = np.zeros(epochs)
	validLossArray = np.zeros(epochs)

	lr = learningRate
	minValidEpoch = 0
	minValidLoss = float('inf')

	trainLoaderRealList = list(trainLoaderReal)
	validLoaderRealList = list(validLoaderReal)
	numRealSteps = len(trainLoaderRealList)
	numValidRealSteps = len(validLoaderRealList)

	# build every optimizer once so Adam's per-parameter moment estimates
	# persist across steps; the real-data updates use a five-times smaller lr
	allParameters = list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters())
	optimizer = optim.Adam(allParameters, lr=lr, betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)
	optimizerReal = optim.Adam(allParameters, lr=lr / 5, betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)
	optimizerStage1 = optim.Adam(list(model1.parameters()) + list(model2.parameters()), lr=lr, betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)
	optimizerStage2 = optim.Adam(list(model3.parameters()), lr=lr, betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)

	for epoch in range(epochs) :
		start = time.perf_counter()

		realStep = 0

		epochLoss = np.zeros(2)
		epochCount = epoch + 1

		print ("Epoch num : " + str(epochCount))
		if epoch >= 20 and epoch % 20 == 0 :
			lr = lr / 2
			print ('Learning rate changed to ' + str(lr))
		if epoch < 40 :
			binarizationW = 0.0

		for step, sample in enumerate(trainLoaderLF) :

			if step % 3 == 1 and realStep < numRealSteps:
				realStepLoss = torch.zeros(1).to(device)
				realInData = trainLoaderRealList[realStep]['Input'].to(device)
				realSegMask = trainLoaderRealList[realStep]['Mask'].to(device)
				realBlurOut = model1(realInData, realSegMask)
				for realDispIndex in range(len(realBlurOut)) :
					realScaleLoss = torch.zeros(1).to(device)
					realBinarizationLoss = torch.zeros(1).to(device)
					realDispOut = realBlurOut[realDispIndex]
					realInFocus = F.interpolate(realInData, size = None, scale_factor = 1/pow(2.0, realDispIndex), mode = 'bilinear')
					realBlurMapSlices = helper.BlurMapsToSlices(realDispOut, numSlices, 6)
					realInFocusRegionSlices = []
					for realSlices in range(numSlices) :
						realNearFocusOutSlice = model2(realInFocus, realBlurMapSlices[realSlices])
						realBinaryBlurMapSlice = model3(realNearFocusOutSlice)
						realBinaryBlurMapSlice = torch.cat((realBinaryBlurMapSlice, realBinaryBlurMapSlice, realBinaryBlurMapSlice), dim = 1)
						binaryLoss = losses.binarizationLoss(realBinaryBlurMapSlice) * binaryW
						realBinarizationLoss += binaryLoss
						realInFocusRegionSlice = torch.mul(realNearFocusOutSlice, realBinaryBlurMapSlice)
						realInFocusRegionSlices.append(realInFocusRegionSlice)
					realInFocusOutput = helper.CombineFocusRegionsToAIF(realInFocusRegionSlices, numSlices, device)
					realPerceptLoss = contentLoss.get_loss(realInFocus, realInFocusOutput)
					realSimiLoss = losses.similarityLoss(realInFocus, realInFocusOutput, alpha, winSize)
					realScaleLoss += perceptLossW * realPerceptLoss + simiLossW * realSimiLoss + realBinarizationLoss
					realStepLoss += realScaleLoss
				optimizerReal.zero_grad()
				realStepLoss.backward()
				optimizerReal.step()
				realStep += 1

			stepLossStage1 = torch.zeros(1).to(device)
			stepLossStage1.requires_grad = False

			stepLossStage2 = torch.zeros(1).to(device)
			stepLossStage2.requires_grad = False

			loss = torch.zeros(1).to(device)
			loss.requires_grad = False
			
			inData 	= sample['Input'].to(device)
			outData = sample['Output'].to(device)
			inSegMask = sample['Mask'].to(device)

			blurOut = model1(inData, inSegMask)

			for dispIndex in range(len(blurOut)) :

				scaleLossStage1 = torch.zeros(1).to(device)
				scaleLossStage1.requires_grad = False

				scaleLossStage2 = torch.zeros(1).to(device)
				scaleLossStage2.requires_grad = False

				dispOut 	= blurOut[dispIndex]

				inFocus	= F.interpolate(inData, size = None, scale_factor = 1 / pow(2.0, dispIndex), mode = 'bilinear')
				nearFocusStackGT 	= F.interpolate(outData, size = None, scale_factor = 1 / pow(2.0, dispIndex) , mode = 'bilinear')

				nearFocusGTSlices = helper.FocalStackToSlices(nearFocusStackGT, numSlices)
				blurMapSlices = helper.BlurMapsToSlices(dispOut, numSlices, 6)
				inFocusRegionSlices = []

				
				for slices in range(numSlices) :

					sliceLoss = torch.zeros(1).to(device)
					sliceLoss.requires_grad = False
					#print (time.clock() - start)
				
					nearFocusOutSlice = model2(inFocus, blurMapSlices[slices])
					#print (time.clock() - start)

					perceptLoss = contentLoss.get_loss(nearFocusGTSlices[slices], nearFocusOutSlice)
					simiLoss = losses.similarityLoss(nearFocusGTSlices[slices], nearFocusOutSlice, alpha, winSize)

					sliceLoss = perceptLossW * perceptLoss + simiLossW * simiLoss

					binaryBlurMapSlice = model3(nearFocusOutSlice)
					binaryBlurMapSlice = torch.cat((binaryBlurMapSlice, binaryBlurMapSlice, binaryBlurMapSlice), dim = 1)
					binaryLoss = losses.binarizationLoss(binaryBlurMapSlice) * binaryW
					sliceLoss += binaryLoss
					inFocusRegionSlice = torch.mul(nearFocusOutSlice, binaryBlurMapSlice)
					inFocusRegionSlices.append(inFocusRegionSlice)

					scaleLossStage1 += sliceLoss

				inFocusOutput = helper.CombineFocusRegionsToAIF(inFocusRegionSlices, numSlices, device)

				perceptLossStage2 = contentLoss.get_loss(inFocus, inFocusOutput)
				simiLossStage2 = losses.similarityLoss(inFocus, inFocusOutput, alpha, winSize)
				scaleLossStage2 += perceptLossW * perceptLossStage2 + simiLossW * simiLossStage2

				stepLossStage1 += scaleLossStage1
				stepLossStage2 += scaleLossStage2

			# sum the two stage losses once per step, after all scales
			loss = stepLossStage1 + stepLossStage2

			epochLoss[0] += loss.item()

			if jointTraining :
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
			else :
				optimizerStage1.zero_grad()
				stepLossStage1.backward(retain_graph = True)
				optimizerStage1.step()

				optimizerStage2.zero_grad()
				stepLossStage2.backward()
				optimizerStage2.step()

			if step % 500 == 0 :
				print ("Step num :" + str(step) + " Train Loss :", str(loss.item()))

		with torch.no_grad() :
			validRealStep = 0
			for validStep, validSample in enumerate(validLoaderLF) :

				if validStep % 3 == 1 and validRealStep < numValidRealSteps:
					validRealStepLoss = torch.zeros(1).to(device)
					validRealInData = validLoaderRealList[validRealStep]['Input'].to(device)
					validRealSegMask = validLoaderRealList[validRealStep]['Mask'].to(device)
					validRealBlurOut = model1(validRealInData, validRealSegMask)
					for validRealDispIndex in range(len(validRealBlurOut)) :
						validRealScaleLoss = torch.zeros(1).to(device)
						validRealDispOut = validRealBlurOut[validRealDispIndex]
						validRealInFocus = F.interpolate(validRealInData, size = None, scale_factor = 1/pow(2.0, validRealDispIndex), mode = 'bilinear')
						validRealBlurMapSlices = helper.BlurMapsToSlices(validRealDispOut, numSlices, 6)
						validRealInFocusRegionSlices = []
						for validRealSlices in range(numSlices) :
							validRealNearFocusOutSlice = model2(validRealInFocus, validRealBlurMapSlices[validRealSlices])
							validRealBinaryBlurMapSlice = model3(validRealNearFocusOutSlice)
							validRealBinaryBlurMapSlice = torch.cat((validRealBinaryBlurMapSlice, validRealBinaryBlurMapSlice, validRealBinaryBlurMapSlice), dim = 1)
							validRealInFocusRegionSlice = torch.mul(validRealNearFocusOutSlice, validRealBinaryBlurMapSlice)
							validRealInFocusRegionSlices.append(validRealInFocusRegionSlice)
						validRealInFocusOutput = helper.CombineFocusRegionsToAIF(validRealInFocusRegionSlices, numSlices, device)
						validRealPerceptLoss = contentLoss.get_loss(validRealInFocus, validRealInFocusOutput)
						validRealSimiLoss = losses.similarityLoss(validRealInFocus, validRealInFocusOutput, alpha, winSize)
						validRealScaleLoss += perceptLossW * validRealPerceptLoss + simiLossW * validRealSimiLoss
						validRealStepLoss += validRealScaleLoss

					validRealStep += 1
					epochLoss[1] += validRealStepLoss.item()

				validStepLossStage1 = torch.zeros(1).to(device)
				validStepLossStage1.requires_grad = False

				validStepLossStage2 = torch.zeros(1).to(device)
				validStepLossStage2.requires_grad = False

				validLoss = torch.zeros(1).to(device)
				validLoss.requires_grad = False
			
				validInData 	= validSample['Input'].to(device)
				validOutData = validSample['Output'].to(device)
				validInSegMask = validSample['Mask'].to(device)

				validBlurOut = model1(validInData, validInSegMask)


				for validDispIndex in range(len(validBlurOut)) :

					validScaleLossStage1 = torch.zeros(1).to(device)
					validScaleLossStage1.requires_grad = False

					validScaleLossStage2 = torch.zeros(1).to(device)
					validScaleLossStage2.requires_grad = False

					validDispOut 	= validBlurOut[validDispIndex]

					validInFocus	= F.interpolate(validInData, size = None, scale_factor = 1 / pow(2.0, validDispIndex), mode = 'bilinear')
					validNearFocusStackGT 	= F.interpolate(validOutData, size = None, scale_factor = 1 / pow(2.0, validDispIndex) , mode = 'bilinear')

					validNearFocusGTSlices = helper.FocalStackToSlices(validNearFocusStackGT, numSlices)
					validBlurMapSlices = helper.BlurMapsToSlices(validDispOut, numSlices, 6)
					validInFocusRegionSlices = []

				
					for validSlices in range(numSlices) :

						validSliceLoss = torch.zeros(1).to(device)
						validSliceLoss.requires_grad = False

						validNearFocusOutSlice = model2(validInFocus, validBlurMapSlices[validSlices])

						validPerceptLoss = contentLoss.get_loss(validNearFocusGTSlices[validSlices], validNearFocusOutSlice)
						validSimiLoss = losses.similarityLoss(validNearFocusGTSlices[validSlices], validNearFocusOutSlice, alpha, winSize)

						validSliceLoss = perceptLossW * validPerceptLoss + simiLossW * validSimiLoss

						validBinaryBlurMapSlice = model3(validNearFocusOutSlice)
						validBinaryBlurMapSlice = torch.cat((validBinaryBlurMapSlice, validBinaryBlurMapSlice, validBinaryBlurMapSlice), dim = 1)
						validBinaryLoss = losses.binarizationLoss(validBinaryBlurMapSlice) * binaryW
						validSliceLoss += validBinaryLoss
						validInFocusRegionSlice = torch.mul(validNearFocusOutSlice, validBinaryBlurMapSlice)
						validInFocusRegionSlices.append(validInFocusRegionSlice)

						validScaleLossStage1 += validSliceLoss

					validInFocusOutput = helper.CombineFocusRegionsToAIF(validInFocusRegionSlices, numSlices, device)

					validPerceptLossStage2 = contentLoss.get_loss(validInFocus, validInFocusOutput)
					validSimiLossStage2 = losses.similarityLoss(validInFocus, validInFocusOutput, alpha, winSize)
					validScaleLossStage2 += perceptLossW * validPerceptLossStage2 + simiLossW * validSimiLossStage2

					validStepLossStage1 += validScaleLossStage1
					validStepLossStage2 += validScaleLossStage2

				validLoss = validStepLossStage1 + validStepLossStage2
				epochLoss[1] += validLoss.item()

		epochLoss[0] /= (step + 1)
		epochLoss[1] /= (validStep + 1)
		epochFreq = epochSaveFreq
		print ("Time taken for epoch " + str(epoch) + " is " + str(time.perf_counter() - start))
		
		if epochLoss[1] < minValidLoss :
			minValidLoss = epochLoss[1]
			minValidEpoch = epoch
			torch.save(model1.state_dict(), model1SaveDir + str(epoch) + extension)
			torch.save(model2.state_dict(), model2SaveDir + str(epoch) + extension)
			torch.save(model3.state_dict(), model3SaveDir + str(epoch) + extension)
			print ("Saving latest model at epoch: " + str(epoch))

		if epochCount % epochFreq == 0 :
			torch.save(model1.state_dict(), model1SaveDir + str(epoch) + extension)
			torch.save(model2.state_dict(), model2SaveDir + str(epoch) + extension)
			torch.save(model3.state_dict(), model3SaveDir + str(epoch) + extension)
					
		print ("Training loss for epoch " + str(epochCount) + " : " + str(epochLoss[0]))
		print ("Validation loss for epoch " + str(epochCount) + " : " + str(epochLoss[1]))

		trainLossArray[epoch] = epochLoss[0]
		validLossArray[epoch] = epochLoss[1]
	
	print ('Best validation epoch: ' + str(minValidEpoch))
	return (trainLossArray, validLossArray)
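The two-stage branch above calls backward twice on losses that share part of the graph, which is why the first call passes retain_graph=True. A minimal sketch of that mechanic (the values are illustrative):

import torch

x = torch.nn.Parameter(torch.tensor([2.0]))
shared = x * 3.0                     # subgraph used by both losses
loss1 = shared.pow(2).sum()          # stage-1 style loss
loss2 = (shared - 1.0).abs().sum()   # stage-2 style loss

loss1.backward(retain_graph=True)    # keep buffers alive for the second pass
loss2.backward()                     # would raise without retain_graph above
print(x.grad)                        # tensor([39.]) = 36 from loss1 + 3 from loss2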