    def __init__(self, config, data_loader):
        """Set parameters of neural network and its training."""
        self.ssim_loss = SSIM()
        self.generator = None
        self.discriminator = None
        self.distance_based_loss = None

        self.g_optimizer = None
        self.d_optimizer = None

        self.g_conv_dim = 128

        self.beta1 = 0.9
        self.beta2 = 0.999
        self.learning_rate = 0.0001
        self.image_size = config.image_size
        self.num_epochs = config.num_epochs
        self.distance_weight = config.distance_weight
        self.noise = config.noise
        self.residual = config.residual

        self.data_loader = data_loader
        self.generate_path = config.generate_path
        self.model_path = config.model_path
        self.tensorboard = config.tensorboard

        if self.tensorboard:
            self.tb_writer = tensorboardX.SummaryWriter(
                filename_suffix='_%s_%s' %
                (config.distance_weight, config.dataset))
            self.tb_graph_added = False

        self.build_model()
Example No. 2
def train(epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = SSIM()  # alternative: nn.MSELoss()

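    # relative weight of the model's latent loss term in the total objective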
    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

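    # per-batch loop: reconstruct, combine the SSIM reconstruction loss with the weighted latent loss, and update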
    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion((out*0.5 + 0.5), (img*0.5 + 0.5))
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        optimizer.step()
        # step the per-iteration LR scheduler after the optimizer update (the order PyTorch expects)
        if scheduler is not None:
            scheduler.step()

        mse_sum += recon_loss.item() * img.shape[0]
        mse_n += img.shape[0]

        print(recon_loss)

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; '
                f'latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; '
                f'lr: {lr:.5f}'
            )
        )

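        # every 100 batches, save a grid of inputs and reconstructions for visual inspection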
        if i % 100 == 0:
            model.eval()

            sample = img[:sample_size]

            with torch.no_grad():
                out, _ = model(sample)

            utils.save_image(
                torch.cat([sample, out], 0),
                f'sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png',
                nrow=sample_size,
                normalize=True,
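                # note: newer torchvision releases rename this argument to value_range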
                range=(-1, 1),
            )

            model.train()
Example No. 3
    def train(self):
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        lr = self.scheduler.get_lr()[0]
        self.args.Noisy = True
        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()
        criterion_ssim = SSIM(window_size=11, size_average=False)
        criterion_ssim = criterion_ssim.cuda()
        timer_data, timer_model = utility.timer(), utility.timer()
        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, idx_scale)
            loss = self.loss(sr, hr) #+ self.ssim*criterion_ssim(sr, hr)
            loss.backward()

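            # optional gradient clipping by value (enabled when gclip > 0)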
            if self.args.gclip > 0:
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.optimizer.step()

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
Example No. 4
    def __init__(self):
        super(Model, self).__init__()

        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss(reduction='mean')  # 'reduce'/'size_average' args are deprecated
        self.l1 = nn.L1Loss()
        self.SL1 = nn.SmoothL1Loss()
        self.ssim = SSIM(window_size=11)
        self.avg = nn.AdaptiveAvgPool2d(1)

        self.predict_net = PredictNet(64)
        self.device = next(self.predict_net.parameters()).device

        self.resnet = resnet50_backbone(pretrained=True)
        self.regression = nn.Sequential(
            nn.Conv2d(1 * 2, 1, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(inplace=True),
        )

        self.res_out = nn.Sequential(
            nn.Conv2d(2048, 1024, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(512, 64, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(inplace=True),
        )

        init_nets = [self.regression, self.predict_net, self.res_out]
        for net in init_nets:
            net.apply(weights_init_xavier)

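        # builds a closure mapping pixel differences to a log-scaled score (close to 1 for identical pixels)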
        def get_log_diff_fn(eps=0.2):
            log_255_sq = np.float32(2 * np.log(255.0))
            log_255_sq = log_255_sq.item()  # convert to a Python float
            max_val = np.float32(log_255_sq - np.log(eps))
            max_val = max_val.item()  # convert to a Python float
            log_255_sq = torch.from_numpy(np.array(log_255_sq)).float().to(
                self.device)
            max_val = torch.from_numpy(np.array(max_val)).float().to(
                self.device)

            def log_diff_fn(in_a, in_b):
                diff = 255.0 * (in_a - in_b)
                val = log_255_sq - torch.log(diff**2 + eps)
                return val / max_val

            return log_diff_fn

        self.log_diff_fn = get_log_diff_fn(1)

        self.downsample_filter = DownSampleFilter()
        self.upsample_filter = UpSampleFilter()
Example No. 5
    def compute_loss_ssim(self, img_pyramid, img_warped_pyramid,
                          occ_mask_list):
        loss_list = []
        for scale in range(self.num_scales):
            img, img_warped, occ_mask = img_pyramid[scale], img_warped_pyramid[
                scale], occ_mask_list[scale]
            divider = occ_mask.mean((1, 2, 3))
            occ_mask_pad = occ_mask.repeat(1, 3, 1, 1)
            ssim = SSIM(img * occ_mask_pad, img_warped * occ_mask_pad)
            loss_ssim = torch.clamp((1.0 - ssim) / 2.0, 0, 1).mean((1, 2, 3))
            loss_ssim = loss_ssim / (divider + 1e-12)
            loss_list.append(loss_ssim[:, None])
        loss = torch.cat(loss_list, 1).sum(1)
        return loss
Example No. 6
    def __init__(self,
                 model,
                 lr=1e-1,
                 n_epochs=10,
                 verbose=True,
                 dir_base='./output/checkpoints'):
        self.model = model
        self.lr = lr
        self.n_epochs = n_epochs
        self.verbose = verbose

        # Initializations
        self.dir_base = dir_base
        if not os.path.exists(self.dir_base):
            os.makedirs(self.dir_base)
        self.dir_log = f'{dir_base}/log.txt'
        self.best_summary_loss = 10**5
        self.epoch = 0

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)

        # Define the optimizer
        self.params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(self.params, lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=self.optimizer,
            mode='min',
            factor=0.5,
            patience=1,
            verbose=True,
            threshold=0.00005,
            threshold_mode='abs',
            cooldown=0,
            min_lr=1e-8,
            eps=1e-8)

        # Define the loss
        self.criterion = SSIM()
        self.log('====================================================')
        self.log(
            f'Fitter prepared | Time: {datetime.utcnow().isoformat()} | Device: {self.device}'
        )
Example No. 7
    def compute_pairwise_loss(self, tgt_img, ref_img, tgt_depth, ref_depth,
                              pose, intrinsic):
        ref_img_warped, valid_mask, projected_depth, computed_depth = inverse_warp2(
            ref_img, tgt_depth, ref_depth, pose, intrinsic, 'zeros')

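        # per-pixel photometric difference between the target image and the warped reference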
        diff_img = (tgt_img - ref_img_warped).abs()

        diff_depth = ((computed_depth - projected_depth).abs() /
                      (computed_depth + projected_depth).abs()).clamp(0, 1)

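        # blend SSIM dissimilarity with the L1 photometric term (0.85 / 0.15 weighting)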
        ssim_map = (0.5 * (1 - SSIM(tgt_img, ref_img_warped))).clamp(0, 1)
        diff_img = (0.15 * diff_img + 0.85 * ssim_map)

        # Modified in 01.19.2020
        #weight_mask = (1 - diff_depth)
        #diff_img = diff_img * weight_mask

        # compute loss
        reconstruction_loss = diff_img.mean()
        geometry_consistency_loss = diff_depth.mean()
        #reconstruction_loss = mean_on_mask(diff_img, valid_mask)
        #geometry_consistency_loss = mean_on_mask(diff_depth, valid_mask)

        return reconstruction_loss, geometry_consistency_loss
Example No. 8
def main():
    global args, best_loss
    global logger
    global device, kwargs

    if args.model_type == 'fcn':
        filter_list = [
            1,
            int(args.model_multiplier * 4),
            int(args.model_multiplier * 8),
            int(args.model_multiplier * 16),
            int(args.model_multiplier * 32),
            int(args.model_multiplier * 64), 10
        ]

        print('Model filter sizes list is {}'.format(filter_list))

        model = VAE(filters=filter_list,
                    dilations=[1, 1, 1, 1, 1, 1],
                    paddings=[0, 0, 0, 0, 0, 0],
                    strides=[1, 1, 2, 1, 2, 2],
                    decoder_kernels=[3, 4, 4, 4, 4, 4],
                    decoder_paddings=[1, 0, 0, 0, 0, 0],
                    decoder_strides=[1, 1, 1, 2, 2, 1],
                    split_filter=args.split_filter).to(device)

        print(model)
    elif args.model_type == 'fcns_1n':
        filter_list = [
            1,
            int(args.model_multiplier * 4),
            int(args.model_multiplier * 8),
            int(args.model_multiplier * 16),
        ]

        print('Model filter sizes list is {}'.format(filter_list))

        model = VAE1N(filters=filter_list,
                      dilations=[1, 1, 1],
                      paddings=[1, 1, 1],
                      strides=[2, 2, 2],
                      decoder_kernels=[4, 4, 3],
                      decoder_paddings=[1, 1, 1],
                      decoder_strides=[2, 2, 2],
                      latent_space_size=10).to(device)
    elif args.model_type == 'fcns':
        model = VAESimplifiedFC().to(device)
    elif args.model_type == 'fc':
        model = VAEBaseline(
            latent_space_size=args.latent_space_size).to(device)
    elif args.model_type == 'fc_conv':
        model = VAEBaselineConv(
            latent_space_size=args.latent_space_size).to(device)
    else:
        raise ValueError('Model type not supported')

    if args.optimizer.startswith('adam'):
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            # Only finetunable params
            lr=args.lr)
    elif args.optimizer.startswith('rmsprop'):
        optimizer = torch.optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            # Only finetunable params
            lr=args.lr)
    elif args.optimizer.startswith('sgd'):
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            # Only finetunable params
            lr=args.lr)
    else:
        raise ValueError('Optimizer not supported')

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.predict:
        pass
    elif args.evaluate:
        pass
    else:
        if args.dataset_type == 'fmnist':
            train_dataset = FMNISTDataset(mode='train',
                                          random_state=args.seed,
                                          use_augs=args.do_augs)

            val_dataset = FMNISTDataset(mode='val',
                                        random_state=args.seed,
                                        use_augs=False)

            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=False,
                **kwargs)

            val_loader = torch.utils.data.DataLoader(
                val_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=False,
                **kwargs)
        elif args.dataset_type == 'mnist':
            train_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data',
                               train=True,
                               download=True,
                               transform=transforms.ToTensor()),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

            val_loader = torch.utils.data.DataLoader(
                datasets.MNIST('../data',
                               train=False,
                               transform=transforms.ToTensor()),
                batch_size=args.batch_size,
                shuffle=True,
                **kwargs)

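        # combined VAE objective: weighted image reconstruction and KL terms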
        criterion = VAELoss(
            use_running_mean=args.do_running_mean,
            image_loss_type=args.image_loss_type,
            image_loss_weight=args.img_loss_weight,
            kl_loss_weight=args.kl_loss_weight,
            ssim_window_size=args.ssim_window_size,
            latent_space_size=args.latent_space_size).to(device)

        # criterion = loss_function

        ssim = SSIM(window_size=args.ssim_window_size,
                    size_average=True).to(device)

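        # decay the learning rate by 10x at the two milestone epochs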
        scheduler = MultiStepLR(optimizer,
                                milestones=[args.m1, args.m2],
                                gamma=0.1)

        for epoch in range(args.start_epoch, args.epochs):

            # train for one epoch
            train_loss, train_img_loss, train_kl_loss, train_ssim = train(
                train_loader, model, criterion, ssim, optimizer, epoch)

            # evaluate on validation set
            val_loss, val_img_loss, val_kl_loss, val_ssim = validate(
                val_loader, model, criterion, ssim)

            scheduler.step()

            #============ TensorBoard logging ============#
            # Log the scalar values
            if args.tensorboard:
                info = {
                    'eph_tr_loss': train_loss,
                    'eph_tr_ssim': train_ssim,
                    'eph_val_loss': val_loss,
                    'eph_val_ssim': val_ssim,
                    'eph_tr_img_loss': train_img_loss,
                    'eph_tr_kl_loss': train_kl_loss,
                    'eph_val_img_loss': val_img_loss,
                    'eph_val_kl_loss': val_kl_loss,
                }
                for tag, value in info.items():
                    logger.scalar_summary(tag, value, epoch + 1)

            # remember best prec@1 and save checkpoint
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'optimizer': optimizer.state_dict(),
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                }, is_best,
                'weights/{}_checkpoint.pth.tar'.format(str(args.lognumber)),
                'weights/{}_best.pth.tar'.format(str(args.lognumber)))
Example No. 9
    G = model.VGG_VAE(5)

    D.apply(weights_init)
    G.apply(weights_init)

    D.cuda()
    G.cuda()
    print(D)
    print(G)
    D_criterion = torch.nn.BCEWithLogitsLoss().cuda()
    D_optimizer = torch.optim.SGD(D.parameters(), lr=1e-3)

    G_criterion = torch.nn.BCEWithLogitsLoss().cuda()
    G_l1 = torch.nn.L1Loss().cuda()
    G_msssim = MSSSIM().cuda()
    G_ssim = SSIM().cuda()
    G_optimizer = torch.optim.Adam(G.parameters(), lr=1e-3)

    pathlib.Path(sample_output).mkdir(parents=True, exist_ok=True)
    pathlib.Path(os.path.join(sample_output, "images")).mkdir(parents=True, exist_ok=True)
    d_loss = 0
    g_loss = 0
    
    d_to_g_threshold = 0.5
    g_to_d_threshold = 0.3

    train_d = True
    train_g = True

    conditional_training = False
    _si = 1