def train_model(train_loader, model, vgg, criterion, optimizer, epoch, tb_writer):
    """Run one training epoch and log the epoch-average losses to TensorBoard.

    For every batch the model output is compared with the original image using
    a weighted sum of five terms:

    * hole / valid-region losses computed with ``criterion`` on the masked
      output and label,
    * a content loss on VGG feature map index 3,
    * a Gram-matrix style loss summed over all VGG feature maps,
    * a "tv" term (see the NOTE(review) at its computation).

    ``optimizer.step()`` is called after every batch. At the end of the epoch
    the averaged metrics are written under the ``train/`` prefix.

    Args:
        train_loader: iterable of batches. Each batch is a dict with
            ``'hole_img'``, ``'ori_img'`` and ``'mask'`` tensors. The mask is
            1 inside the hole and 0 elsewhere.
        model: inpainting network; put into train mode here.
        vgg: feature extractor returning an indexable sequence of feature
            maps. Index 3 is used for the content loss; every entry is used
            for the style loss.
        criterion: loss applied to the masked output/label pairs
            (presumably L1 -- confirm against the caller).
        optimizer: optimizer whose ``zero_grad``/``step`` are called around
            each ``loss.backward()``.
        epoch: epoch index, used for the progress bar and as the
            TensorBoard step.
        tb_writer: TensorBoard writer (``add_scalar`` is used here; it is
            also forwarded to ``write_tensor``).
    """
    losses = AverageMeter()
    hole_losses = AverageMeter()
    valid_losses = AverageMeter()
    style_losses = AverageMeter()
    content_losses = AverageMeter()
    tv_losses = AverageMeter()
    # Per-layer style-loss meters, filled through write_avgs().
    s1 = AverageMeter()
    s2 = AverageMeter()
    s3 = AverageMeter()
    s4 = AverageMeter()
    s5 = AverageMeter()

    # ensure model is in train mode
    model.train()

    pbar = tqdm(train_loader)
    for i, data in enumerate(pbar):
        inputs = data['hole_img'].float()
        labels = data['ori_img'].float()
        # Separate copy of the ground truth that is fed to vgg.
        ori_img = labels.clone()
        # mask: 1 for the hole and 0 for others
        masks = data['mask'].float()

        inputs = inputs.to(config.device)
        labels = labels.to(config.device)
        masks = masks.to(config.device)
        ori_img = ori_img.to(config.device)
        batch_size = inputs.size(0)

        # pass this batch through our model and get y_pred
        outputs = model(inputs)

        # use five different level features, each are extracted after down-sampling
        targets = vgg(ori_img)
        features = vgg(outputs)

        # Content and style losses, accumulated per sample, then batch-averaged.
        content_loss = 0
        style_loss = 0
        # Detached per-layer style terms, summed over the whole batch (logging
        # only). Previously only the last sample of the batch was recorded.
        layer_style_loss = [0.0, 0.0, 0.0, 0.0, 0.0]
        for k in range(batch_size):
            content_loss += torch.sum((features[3][k] - targets[3][k]) ** 2) / 2
            targets_gram = [gram_matrix(f[k]) for f in targets]
            features_gram = [gram_matrix(f[k]) for f in features]
            for j, (feat_gram, targ_gram) in enumerate(zip(features_gram, targets_gram)):
                term = torch.sum((feat_gram - targ_gram) ** 2)
                style_loss = style_loss + term
                layer_style_loss[j] = layer_style_loss[j] + term.detach()
        style_loss /= batch_size
        content_loss /= batch_size
        # Batch means, so the per-layer values sum to the logged style loss.
        layer_style_loss = [value / batch_size for value in layer_style_loss]
        style_losses.update(style_loss.item(), batch_size)
        content_losses.update(content_loss.item(), batch_size)

        # update loss metric
        # suppose criterion is L1 loss
        hole_loss = criterion(outputs * masks, labels * masks)
        valid_loss = criterion(outputs * (1 - masks), labels * (1 - masks))
        hole_losses.update(hole_loss.item(), batch_size)
        valid_losses.update(valid_loss.item(), batch_size)
        write_avgs([s1, s2, s3, s4, s5], layer_style_loss)

        # "Total variation" term over the hole region.
        # NOTE(review): despite the name, this differences each OUTPUT pixel
        # against its left/upper neighbour in the TARGET, not in the output.
        # A conventional TV term would difference outputs_hole with itself.
        # Left unchanged because changing it alters the training objective;
        # confirm intent.
        outputs_hole = outputs * masks
        targets_hole = labels * masks
        tv_loss = torch.sum(torch.abs(outputs_hole[:, :, :, 1:] - targets_hole[:, :, :, :-1])) \
            + torch.sum(torch.abs(outputs_hole[:, :, 1:, :] - targets_hole[:, :, :-1, :]))
        tv_loss /= batch_size
        tv_losses.update(tv_loss.item(), batch_size)

        # total loss
        # NOTE(review): the objective uses the r*_Loss_weight globals, while
        # the scalars logged below are scaled by *_Loss_weight. Confirm the
        # two sets are meant to differ.
        loss = hole_loss * rHole_Loss_weight + valid_loss * rValid_Loss_weight + \
            style_loss * rStyle_Loss_weight + content_loss * rContent_Loss_weight + \
            tv_loss * rTv_Loss_weight
        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description("EPOCH[{}][{}/{}]".format(epoch, i, len(train_loader)))
        pbar.set_postfix(loss="LOSS:{:.4f}".format(losses.avg))

    tb_writer.add_scalar('train/epoch_loss', losses.avg, epoch)
    tb_writer.add_scalar('train/hole_loss', hole_losses.avg * Hole_Loss_weight, epoch)
    tb_writer.add_scalar('train/valid_loss', valid_losses.avg * Valid_Loss_weight, epoch)
    tb_writer.add_scalar('train/style_loss', style_losses.avg * Style_Loss_weight, epoch)
    tb_writer.add_scalar('train/content_loss', content_losses.avg * Content_Loss_weight, epoch)
    tb_writer.add_scalar('train/tv_loss', tv_losses.avg * Tv_Loss_weight, epoch)
    write_tensor(perceptual_style_name, [s1, s2, s3, s4, s5], epoch, tb_writer)

    torch.cuda.empty_cache()
    return
def valid_model(valid_loader, model, vgg, criterion, optimizer, epoch, tb_writer):
    """Evaluate the model for one epoch and log losses and sample images.

    Computes the same weighted loss terms as the training loop (hole/valid
    losses via ``criterion``, VGG content loss, Gram-matrix style loss and the
    "tv" term) without gradients. On the first batch, up to three
    (original, holed, output) image triplets are logged. Averaged metrics are
    written under the ``valid/`` prefix.

    Args:
        valid_loader: iterable of batches. Each batch is a dict with
            ``'hole_img'``, ``'ori_img'`` and ``'mask'`` tensors. The mask is
            1 inside the hole and 0 elsewhere.
        model: inpainting network; put into eval mode here.
        vgg: feature extractor returning an indexable sequence of feature
            maps; put into eval mode here.
        criterion: loss applied to the masked output/label pairs
            (presumably L1 -- confirm against the caller).
        optimizer: unused; accepted so the signature matches ``train_model``.
        epoch: epoch index, used for the progress bar and as the
            TensorBoard step.
        tb_writer: TensorBoard writer (``add_scalar`` and ``add_image`` are
            used here; it is also forwarded to ``write_tensor``).

    Returns:
        dict: ``{'epoch_loss': <average weighted total loss>}``.
    """
    losses = AverageMeter()
    hole_losses = AverageMeter()
    valid_losses = AverageMeter()
    style_losses = AverageMeter()
    content_losses = AverageMeter()
    tv_losses = AverageMeter()
    # Per-layer style-loss meters, filled through write_avgs().
    s1 = AverageMeter()
    s2 = AverageMeter()
    s3 = AverageMeter()
    s4 = AverageMeter()
    s5 = AverageMeter()

    # ensure model and feature extractor are in eval mode
    model.eval()
    vgg.eval()

    pbar = tqdm(valid_loader)
    for i, data in enumerate(pbar):
        inputs = data['hole_img'].float()
        labels = data['ori_img'].float()
        ori_img = labels.clone()
        # mask: 1 for the hole and 0 for others
        masks = data['mask'].float()

        inputs = inputs.to(config.device)
        labels = labels.to(config.device)
        masks = masks.to(config.device)
        ori_img = ori_img.to(config.device)
        batch_size = inputs.size(0)

        with torch.no_grad():
            # pass this batch through our model and get y_pred
            outputs = model(inputs)
            targets = vgg(ori_img)
            features = vgg(outputs)

            # Content and style losses, accumulated per sample, then batch-averaged.
            content_loss = 0
            style_loss = 0
            # Per-layer style terms summed over the whole batch (logging
            # only). Previously only the last sample of the batch was recorded.
            layer_style_loss = [0.0, 0.0, 0.0, 0.0, 0.0]
            for k in range(batch_size):
                content_loss += torch.sum((features[3][k] - targets[3][k]) ** 2) / 2
                targets_gram = [gram_matrix(f[k]) for f in targets]
                features_gram = [gram_matrix(f[k]) for f in features]
                for j, (feat_gram, targ_gram) in enumerate(zip(features_gram, targets_gram)):
                    term = torch.sum((feat_gram - targ_gram) ** 2)
                    style_loss = style_loss + term
                    layer_style_loss[j] = layer_style_loss[j] + term.detach()
            style_loss /= batch_size
            content_loss /= batch_size
            # Batch means, so the per-layer values sum to the logged style loss.
            layer_style_loss = [value / batch_size for value in layer_style_loss]
            style_losses.update(style_loss.item(), batch_size)
            content_losses.update(content_loss.item(), batch_size)

            # update loss metric
            # suppose criterion is L1 loss
            hole_loss = criterion(outputs * masks, labels * masks)
            valid_loss = criterion(outputs * (1 - masks), labels * (1 - masks))
            hole_losses.update(hole_loss.item(), batch_size)
            valid_losses.update(valid_loss.item(), batch_size)

            # "Total variation" term over the hole region.
            # NOTE(review): this differences OUTPUT pixels against shifted
            # TARGET pixels rather than against neighbouring output pixels.
            # Kept identical to train_model so the metrics stay comparable.
            outputs_hole = outputs * masks
            targets_hole = labels * masks
            tv_loss = torch.sum(torch.abs(outputs_hole[:, :, :, 1:] - targets_hole[:, :, :, :-1])) \
                + torch.sum(torch.abs(outputs_hole[:, :, 1:, :] - targets_hole[:, :, :-1, :]))
            tv_loss /= batch_size
            tv_losses.update(tv_loss.item(), batch_size)

            # total loss
            loss = hole_loss * rHole_Loss_weight + valid_loss * rValid_Loss_weight + \
                style_loss * rStyle_Loss_weight + content_loss * rContent_Loss_weight + \
                tv_loss * rTv_Loss_weight
            losses.update(loss.item(), batch_size)

        write_avgs([s1, s2, s3, s4, s5], layer_style_loss)

        if i == 0:
            # Log up to three example triplets from the first batch.
            for j in range(min(batch_size, 3)):
                hole_sample = data['hole_img'][j]
                ori_sample = data['ori_img'][j]
                out_img = outputs[j].detach()
                # Min-max scale to [0, 1] for display. The old code divided by
                # (max - min) without subtracting min, and produced inf/NaN
                # for a constant image.
                lo = torch.min(out_img)
                hi = torch.max(out_img)
                out_img = (out_img - lo) / torch.clamp(hi - lo, min=1e-8)
                tb_writer.add_image('valid/ori_img{}'.format(j), ori_sample, epoch)
                tb_writer.add_image('valid/hole_img{}'.format(j), hole_sample, epoch)
                tb_writer.add_image('valid/out_img{}'.format(j), out_img, epoch)

        pbar.set_description("EPOCH[{}][{}/{}]".format(epoch, i, len(valid_loader)))
        pbar.set_postfix(loss="LOSS:{:.4f}".format(losses.avg))

    tb_writer.add_scalar('valid/epoch_loss', losses.avg, epoch)
    tb_writer.add_scalar('valid/hole_loss', hole_losses.avg * Hole_Loss_weight, epoch)
    tb_writer.add_scalar('valid/valid_loss', valid_losses.avg * Valid_Loss_weight, epoch)
    tb_writer.add_scalar('valid/style_loss', style_losses.avg * Style_Loss_weight, epoch)
    tb_writer.add_scalar('valid/content_loss', content_losses.avg * Content_Loss_weight, epoch)
    tb_writer.add_scalar('valid/tv_loss', tv_losses.avg * Tv_Loss_weight, epoch)
    write_tensor(t_perceptual_style_name, [s1, s2, s3, s4, s5], epoch, tb_writer)

    torch.cuda.empty_cache()
    outspects = {
        'epoch_loss': losses.avg,
    }
    return outspects