示例#1
0
def train_model(train_loader, model, vgg, criterion, optimizer, epoch,
                tb_writer):
    """Run one training epoch of the inpainting model and log its losses.

    For every batch from ``train_loader`` the model is run on the holed image
    and a weighted sum of five terms is minimised:

    * hole / valid loss: ``criterion`` applied inside and outside the mask;
    * content loss: half the summed squared difference between the VGG
      feature maps at index 3 of the output and of the original image,
      averaged over the batch;
    * style loss: summed squared difference between the Gram matrices of
      every VGG feature level, averaged over the batch;
    * a TV-like term over the masked region (see the NOTE in the body).

    At the end of the epoch the average of every term (and of the per-layer
    style losses) is written to ``tb_writer``.

    Args:
        train_loader: iterable yielding dicts with ``'hole_img'``,
            ``'ori_img'`` and ``'mask'`` tensors (mask is 1 inside the hole
            and 0 elsewhere).
        model: the network being trained; switched to train mode here.
        vgg: feature extractor returning a sequence of feature maps; switched
            to eval mode here (presumably frozen — confirm its parameters are
            not in ``optimizer``).
        criterion: pixel loss used for the hole/valid terms (assumed L1).
        optimizer: optimizer that is stepped once per batch.
        epoch: epoch index, used for progress display and logging.
        tb_writer: TensorBoard writer exposing ``add_scalar``.
    """
    losses = AverageMeter()
    hole_losses = AverageMeter()
    valid_losses = AverageMeter()
    style_losses = AverageMeter()
    content_losses = AverageMeter()
    tv_losses = AverageMeter()

    # One meter per VGG feature level for the per-layer style loss.
    s1 = AverageMeter()
    s2 = AverageMeter()
    s3 = AverageMeter()
    s4 = AverageMeter()
    s5 = AverageMeter()

    # ensure model is in train mode
    model.train()
    # valid_model() switches vgg to eval mode and nothing switches it back, so
    # every epoch after the first already trained against vgg.eval(); set it
    # here as well so the first epoch uses the same VGG behaviour.
    vgg.eval()
    pbar = tqdm(train_loader)
    for i, data in enumerate(pbar):
        inputs = data['hole_img'].float()
        labels = data['ori_img'].float()
        ori_img = labels.clone()
        # mask: 1 for the hole and 0 for others
        masks = data['mask'].float()
        inputs = inputs.to(config.device)
        labels = labels.to(config.device)
        masks = masks.to(config.device)
        ori_img = ori_img.to(config.device)
        batch_size = inputs.size(0)

        # pass this batch through our model and get y_pred
        outputs = model(inputs)
        # use five different level features, each are extracted after down-sampling
        targets = vgg(ori_img)
        features = vgg(outputs)

        # get content and style loss
        content_loss = 0
        style_loss = 0

        # Per-layer style loss summed over the batch, kept detached: it is
        # only used for logging and must not hold on to the autograd graph.
        now_style_loss = [0.0, 0.0, 0.0, 0.0, 0.0]
        for k in range(batch_size):
            content_loss += torch.sum((features[3][k] - targets[3][k])**2) / 2
            targets_gram = [gram_matrix(f[k]) for f in targets]
            features_gram = [gram_matrix(f[k]) for f in features]

            for j in range(len(targets_gram)):
                layer_loss = torch.sum(
                    (features_gram[j] - targets_gram[j])**2)
                style_loss = style_loss + layer_loss
                now_style_loss[j] = now_style_loss[j] + layer_loss.detach()
        style_loss /= batch_size
        content_loss /= batch_size
        # Batch mean per layer, consistent with how style_loss is averaged.
        now_style_loss = [v / batch_size for v in now_style_loss]
        style_losses.update(style_loss.item(), batch_size)
        content_losses.update(content_loss.item(), batch_size)

        # update loss metric
        # suppose criterion is L1 loss
        hole_loss = criterion(outputs * masks, labels * masks)
        valid_loss = criterion(outputs * (1 - masks), labels * (1 - masks))
        hole_losses.update(hole_loss.item(), batch_size)
        valid_losses.update(valid_loss.item(), batch_size)

        write_avgs([s1, s2, s3, s4, s5], now_style_loss)

        # get total variation loss
        # NOTE(review): this differences the masked *output* against the
        # shifted masked *target* rather than against the output itself, so
        # it is not a standard total-variation smoothness term — confirm this
        # is intended.
        outputs_hole = outputs * masks
        targets_hole = labels * masks
        tv_loss = torch.sum(torch.abs(outputs_hole[:, :, :, 1:] - targets_hole[:, :, :, :-1])) \
                  + torch.sum(torch.abs(outputs_hole[:, :, 1:, :] - targets_hole[:, :, :-1, :]))
        tv_loss /= batch_size
        tv_losses.update(tv_loss.item(), batch_size)

        # total loss
        loss = hole_loss * rHole_Loss_weight + valid_loss * rValid_Loss_weight + \
               style_loss * rStyle_Loss_weight + content_loss * rContent_Loss_weight + \
               tv_loss * rTv_Loss_weight
        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description("EPOCH[{}][{}/{}]".format(epoch, i,
                                                       len(train_loader)))
        pbar.set_postfix(loss="LOSS:{:.4f}".format(losses.avg))

    # NOTE(review): the optimised loss uses the r*_Loss_weight constants while
    # the logged components below are scaled by *_Loss_weight — confirm the
    # two sets of weights are meant to differ.
    tb_writer.add_scalar('train/epoch_loss', losses.avg, epoch)
    tb_writer.add_scalar('train/hole_loss', hole_losses.avg * Hole_Loss_weight,
                         epoch)
    tb_writer.add_scalar('train/valid_loss',
                         valid_losses.avg * Valid_Loss_weight, epoch)
    tb_writer.add_scalar('train/style_loss',
                         style_losses.avg * Style_Loss_weight, epoch)
    tb_writer.add_scalar('train/content_loss',
                         content_losses.avg * Content_Loss_weight, epoch)
    tb_writer.add_scalar('train/tv_loss', tv_losses.avg * Tv_Loss_weight,
                         epoch)

    write_tensor(perceptual_style_name, [s1, s2, s3, s4, s5], epoch, tb_writer)

    torch.cuda.empty_cache()
    return
示例#2
0
def valid_model(valid_loader, model, vgg, criterion, optimizer, epoch,
                tb_writer):
    """Evaluate the inpainting model for one epoch and log losses and images.

    Computes the same loss terms as ``train_model`` (hole, valid, content,
    style and the masked TV-like term) under ``torch.no_grad()`` with
    ``model`` and ``vgg`` in eval mode. It logs the epoch averages to
    ``tb_writer``, together with up to three (original, holed, output) image
    triples from the first batch.

    Args:
        valid_loader: iterable yielding dicts with ``'hole_img'``,
            ``'ori_img'`` and ``'mask'`` tensors (mask is 1 inside the hole
            and 0 elsewhere).
        model: the inpainting network; switched to eval mode here.
        vgg: feature extractor returning a sequence of feature maps; switched
            to eval mode here.
        criterion: pixel loss used for the hole/valid terms (assumed L1).
        optimizer: unused; kept so the signature mirrors ``train_model``.
        epoch: epoch index, used for progress display and logging.
        tb_writer: TensorBoard writer exposing ``add_scalar``/``add_image``.

    Returns:
        dict with key ``'epoch_loss'``: the average total loss as tracked by
        ``AverageMeter`` (updated with the batch size as weight).
    """
    losses = AverageMeter()
    hole_losses = AverageMeter()
    valid_losses = AverageMeter()
    style_losses = AverageMeter()
    content_losses = AverageMeter()
    tv_losses = AverageMeter()

    # One meter per VGG feature level for the per-layer style loss.
    s1 = AverageMeter()
    s2 = AverageMeter()
    s3 = AverageMeter()
    s4 = AverageMeter()
    s5 = AverageMeter()
    # ensure model (and the feature extractor) are in eval mode
    model.eval()
    vgg.eval()
    pbar = tqdm(valid_loader)
    for i, data in enumerate(pbar):
        inputs = data['hole_img'].float()
        labels = data['ori_img'].float()
        ori_img = labels.clone()
        # mask: 1 for the hole and 0 for others
        masks = data['mask'].float()
        inputs = inputs.to(config.device)
        labels = labels.to(config.device)
        masks = masks.to(config.device)
        ori_img = ori_img.to(config.device)
        batch_size = inputs.size(0)

        with torch.no_grad():
            # pass this batch through our model and get y_pred
            outputs = model(inputs)
            targets = vgg(ori_img)
            features = vgg(outputs)

            # get content and style loss
            content_loss = 0
            style_loss = 0

            # Per-layer style loss summed over the batch (logged as a mean).
            now_style_loss = [0.0, 0.0, 0.0, 0.0, 0.0]
            for k in range(batch_size):
                content_loss += torch.sum(
                    (features[3][k] - targets[3][k])**2) / 2
                targets_gram = [gram_matrix(f[k]) for f in targets]
                features_gram = [gram_matrix(f[k]) for f in features]

                for j in range(len(targets_gram)):
                    layer_loss = torch.sum(
                        (features_gram[j] - targets_gram[j])**2)
                    style_loss = style_loss + layer_loss
                    now_style_loss[j] = now_style_loss[j] + layer_loss

            style_loss /= batch_size
            content_loss /= batch_size
            # Batch mean per layer, consistent with how style_loss is averaged.
            now_style_loss = [v / batch_size for v in now_style_loss]
            style_losses.update(style_loss.item(), batch_size)
            content_losses.update(content_loss.item(), batch_size)

            # update loss metric
            # suppose criterion is L1 loss
            hole_loss = criterion(outputs * masks, labels * masks)
            valid_loss = criterion(outputs * (1 - masks), labels * (1 - masks))
            hole_losses.update(hole_loss.item(), batch_size)
            valid_losses.update(valid_loss.item(), batch_size)

            # get total variation loss
            # NOTE(review): differences the masked output against the shifted
            # masked target, not against itself — not a standard TV term;
            # confirm this is intended.
            outputs_hole = outputs * masks
            targets_hole = labels * masks
            tv_loss = torch.sum(torch.abs(outputs_hole[:, :, :, 1:] - targets_hole[:, :, :, :-1])) \
                      + torch.sum(torch.abs(outputs_hole[:, :, 1:, :] - targets_hole[:, :, :-1, :]))
            tv_loss /= batch_size
            tv_losses.update(tv_loss.item(), batch_size)

            # total loss
            loss = hole_loss * rHole_Loss_weight + valid_loss * rValid_Loss_weight + \
                   style_loss * rStyle_Loss_weight + content_loss * rContent_Loss_weight + \
                   tv_loss * rTv_Loss_weight
            losses.update(loss.item(), batch_size)

            write_avgs([s1, s2, s3, s4, s5], now_style_loss)
            if i == 0:
                # Log a few sample triples from the first batch only.
                for j in range(min(batch_size, 3)):
                    sample_hole = data['hole_img'][j]
                    sample_ori = data['ori_img'][j]
                    out_img = outputs[j].detach()
                    # Min-max normalise to [0, 1] for display: shift by the
                    # minimum before scaling. A constant image (zero range)
                    # is left as zeros instead of dividing by zero.
                    out_img = out_img - torch.min(out_img)
                    value_range = torch.max(out_img)
                    if value_range > 0:
                        out_img = out_img / value_range
                    tb_writer.add_image('valid/ori_img{}'.format(j),
                                        sample_ori, epoch)
                    tb_writer.add_image('valid/hole_img{}'.format(j),
                                        sample_hole, epoch)
                    tb_writer.add_image('valid/out_img{}'.format(j), out_img,
                                        epoch)

        pbar.set_description("EPOCH[{}][{}/{}]".format(epoch, i,
                                                       len(valid_loader)))
        pbar.set_postfix(loss="LOSS:{:.4f}".format(losses.avg))

    # NOTE(review): the loss above uses the r*_Loss_weight constants while the
    # logged components are scaled by *_Loss_weight — confirm they differ on
    # purpose.
    tb_writer.add_scalar('valid/epoch_loss', losses.avg, epoch)
    tb_writer.add_scalar('valid/hole_loss', hole_losses.avg * Hole_Loss_weight,
                         epoch)
    tb_writer.add_scalar('valid/valid_loss',
                         valid_losses.avg * Valid_Loss_weight, epoch)
    tb_writer.add_scalar('valid/style_loss',
                         style_losses.avg * Style_Loss_weight, epoch)
    tb_writer.add_scalar('valid/content_loss',
                         content_losses.avg * Content_Loss_weight, epoch)
    tb_writer.add_scalar('valid/tv_loss', tv_losses.avg * Tv_Loss_weight,
                         epoch)

    write_tensor(t_perceptual_style_name, [s1, s2, s3, s4, s5], epoch,
                 tb_writer)

    torch.cuda.empty_cache()
    outspects = {
        'epoch_loss': losses.avg,
    }
    return outspects