class Trainer: def __init__(self, config, dataloader): self.batch_size = config.batch_size self.config = config self.lr = config.lr self.epoch = config.epoch self.num_epoch = config.num_epoch self.checkpoint_dir = config.checkpoint_dir self.model_path = config.checkpoint_dir self.device = torch.device( "cuda:0" if torch.cuda.is_available() else "cpu") self.data_loader = dataloader self.image_len = len(dataloader) self.num_classes = config.num_classes self.eps = config.eps self.rho = config.rho self.decay = config.decay self.sample_step = config.sample_step self.sample_dir = config.sample_dir self.gradient_loss_weight = config.gradient_loss_weight self.decay_batch_size = config.decay_batch_size self.build_model() self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=self.eps, rho=self.rho, weight_decay=self.decay) self.lr_scheduler_discriminator = torch.optim.lr_scheduler.LambdaLR( self.optimizer, LambdaLR(self.num_epoch, self.epoch, len(self.data_loader), self.decay_batch_size).step) def build_model(self): self.net = MobileHairNet().to(self.device) self.load_model() def load_model(self): print("[*] Load checkpoint in ", str(self.model_path)) if not os.path.exists(self.model_path): os.makedirs(self.model_path) if not os.listdir(self.model_path): print("[!] No checkpoint in ", str(self.model_path)) return model_path = os.path.join(self.model_path, f"MobileHairNet_epoch-{self.epoch-1}.pth") model = glob(model_path) model.sort() if not model: print(f"[!] No Checkpoint in {model_path}") return self.net.load_state_dict( torch.load(model[-1], map_location=self.device)) print(f"[*] Load Model from {model[-1]}: ") def train(self): bce_losses = AverageMeter() image_gradient_losses = AverageMeter() image_gradient_criterion = ImageGradientLoss().to(self.device) bce_criterion = nn.CrossEntropyLoss().to(self.device) for epoch in range(self.epoch, self.num_epoch): bce_losses.reset() image_gradient_losses.reset() for step, (image, gray_image, mask) in enumerate(self.data_loader): image = image.to(self.device) mask = mask.to(self.device) gray_image = gray_image.to(self.device) pred = self.net(image) pred_flat = pred.permute(0, 2, 3, 1).contiguous().view( -1, self.num_classes) mask_flat = mask.squeeze(1).view(-1).long() # preds_flat.shape (N*224*224, 2) # masks_flat.shape (N*224*224, 1) image_gradient_loss = image_gradient_criterion( pred, gray_image) bce_loss = bce_criterion(pred_flat, mask_flat) loss = bce_loss + self.gradient_loss_weight * image_gradient_loss self.optimizer.zero_grad() loss.backward() self.optimizer.step() bce_losses.update(bce_loss.item(), self.batch_size) image_gradient_losses.update( self.gradient_loss_weight * image_gradient_loss, self.batch_size) iou = iou_loss(pred, mask) # save sample images if step % 10 == 0: print( f"Epoch: [{epoch}/{self.num_epoch}] | Step: [{step}/{self.image_len}] | " f"Bce Loss: {bce_losses.avg:.4f} | Image Gradient Loss: {image_gradient_losses.avg:.4f} | " f"IOU: {iou:.4f}") if step % self.sample_step == 0: self.save_sample_imgs(image[0], mask[0], torch.argmax(pred[0], 0), self.sample_dir, epoch, step) print('[*] Saved sample images') torch.save( self.net.state_dict(), f'{self.checkpoint_dir}/MobileHairNet_epoch-{epoch}.pth') def save_sample_imgs(self, real_img, real_mask, prediction, save_dir, epoch, step): data = [real_img, real_mask, prediction] names = ["Image", "Mask", "Prediction"] fig = plt.figure() for i, d in enumerate(data): d = d.squeeze() im = d.data.cpu().numpy() if i > 0: im = np.expand_dims(im, axis=0) im = np.concatenate((im, im, im), axis=0) im = (im.transpose(1, 2, 0) + 1) / 2 f = fig.add_subplot(1, 3, i + 1) f.imshow(im) f.set_title(names[i]) f.set_xticks([]) f.set_yticks([]) p = os.path.join(save_dir, "epoch-%s_step-%s.png" % (epoch, step)) plt.savefig(p)
class HairSegmentation(object): def __init__(self, training_data_path, valid_data_path, test_data_path, resolution, num_classes, decay_epoch, lr, rho, eps, decay, gradient_loss_weight, resume_epochs, log_step, sample_step, num_epochs, batch_size, train_results_dir, valid_results_dir, test_results_dir, model_save_dir, log_dir): self.training_data_path = training_data_path self.valid_data_path = valid_data_path self.test_data_path = test_data_path self.resolution = resolution self.num_classes = num_classes self.decay_epoch = decay_epoch self.lr = lr self.rho = rho self.eps = eps self.decay = decay self.gradient_loss_weight = gradient_loss_weight self.resume_epochs = resume_epochs self.log_step = log_step self.sample_step = sample_step self.num_epochs = num_epochs self.batch_size = batch_size self.train_results_dir = train_results_dir self.valid_results_dir = valid_results_dir self.test_results_dir = test_results_dir self.model_save_dir = model_save_dir self.log_dir = log_dir self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") self.colors = [(0, 0, 255), (255, 0, 0), (255, 0, 255), (255, 166, 0), (255, 255, 0), (0, 255, 0), (0, 191, 255), (255, 192, 203)] self.create_generator() self.build_model() self.writer = tensorboardX.SummaryWriter(self.log_dir) def create_generator(self): self.transform = transforms.Compose([ transforms.Resize((self.resolution, self.resolution)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_data = Generator(self.training_data_path, 'train', self.resolution) self.train_dataloader = DataLoader(train_data, shuffle=True, batch_size=self.batch_size, num_workers=4, drop_last=True) valid_data = Generator(self.valid_data_path, 'valid', self.resolution) self.valid_dataloader = DataLoader(valid_data, shuffle=True, batch_size=self.batch_size, num_workers=4, drop_last=True) test_data = Generator(self.test_data_path, 'test', self.resolution) self.test_dataloader = DataLoader(test_data, shuffle=True, batch_size=self.batch_size, num_workers=4, drop_last=True) def build_model(self): self.net = HairMatteNet() self.net.to(self.device) self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=self.eps, rho=self.rho, weight_decay=self.decay) def restore_model(self, resume_epochs): print('Loading the trained models from epoch {}...'.format( resume_epochs)) net_path = os.path.join( self.model_save_dir, '{}_epoch-HairMatteNet.ckpt'.format(resume_epochs)) self.net.load_state_dict( torch.load(net_path, map_location=lambda storage, loc: storage)) def train_epoch(self, epoch, start_time): self.net.train() for i, data in enumerate(self.train_dataloader, 0): image = data[0].to(self.device) gray_image = data[1].to(self.device) mask = data[2].to(self.device) pred = self.net(image) pred_flat = pred.permute(0, 2, 3, 1).contiguous().view( -1, self.num_classes) mask_flat = mask.squeeze(1).view(-1).long() image_gradient_loss = self.image_gradient_criterion( pred, gray_image) bce_loss = self.bce_criterion(pred_flat, mask_flat) loss = bce_loss + self.gradient_loss_weight * image_gradient_loss iou = iou_metric(pred, mask) self.optimizer.zero_grad() loss.backward() self.optimizer.step() losses = {} losses['train_bce_loss'] = bce_loss.item() losses[ 'train_image_gradient_loss'] = self.gradient_loss_weight * image_gradient_loss losses['train_loss'] = loss losses['train_iou'] = iou if (i + 1) % self.log_step == 0: et = time.time() - start_time et = str(datetime.timedelta(seconds=et))[:-7] log = "Elapsed [{}], Iteration [{}/{}], Epoch [{}/{}]".format( et, i + 1, len(self.train_dataloader), epoch, self.num_epochs) for tag, value in losses.items(): log += ", {}: {:.4f}".format(tag, value) self.writer.add_scalar( tag, value, epoch * len(self.train_dataloader) + i + 1) print(log) if (i + 1) % self.sample_step == 0: with torch.no_grad(): out_results = [] for j in range(10): out_results.append(denorm(image[j:j + 1]).data.cpu()) out_results.append( mask.expand(-1, 3, -1, -1)[j:j + 1].data.cpu()) out_results.append( torch.argmax(pred[j:j + 1], 1).unsqueeze(0).expand(-1, 3, -1, -1).data.cpu()) color = random.choice(self.colors) result = dye_hair(denorm(image[j:j + 1]), mask[j:j + 1], color) result = self.transform( Image.fromarray(result)).unsqueeze(0) out_results.append(denorm(result)) result = dye_hair( denorm(image[j:j + 1]), torch.argmax(pred[j:j + 1], 1).unsqueeze(0), color) result = self.transform( Image.fromarray(result)).unsqueeze(0) out_results.append(denorm(result)) results_concat = torch.cat(out_results) results_path = os.path.join( self.train_results_dir, '{}_epoch_train_results.jpg'.format(epoch)) save_image(results_concat, results_path, nrow=5, padding=0) print('Saved real and fake images into {}...'.format( results_path)) if (epoch + 1) % 2 == 0: net_path = os.path.join(self.model_save_dir, '{}_epoch-HairMatteNet.ckpt'.format(epoch)) torch.save(self.net.state_dict(), net_path) print('Saved model checkpoints into {}...'.format( self.model_save_dir)) def valid_epoch(self, epoch): self.net.eval() losses = { 'valid_bce_loss': 0, 'valid_image_gradient_loss': 0, 'valid_loss': 0, 'valid_iou': 0 } for i, data in enumerate(self.valid_dataloader, 0): image = data[0].to(self.device) gray_image = data[1].to(self.device) mask = data[2].to(self.device) with torch.no_grad(): pred = self.net(image) pred_flat = pred.permute(0, 2, 3, 1).contiguous().view( -1, self.num_classes) mask_flat = mask.squeeze(1).view(-1).long() image_gradient_loss = self.image_gradient_criterion( pred, gray_image) bce_loss = self.bce_criterion(pred_flat, mask_flat) loss = bce_loss + self.gradient_loss_weight * image_gradient_loss iou = iou_metric(pred, mask) losses['valid_bce_loss'] += bce_loss.item() losses[ 'valid_image_gradient_loss'] += self.gradient_loss_weight * image_gradient_loss losses['valid_loss'] += loss losses['valid_iou'] += iou if i == 0: with torch.no_grad(): out_results = [] for j in range(10): out_results.append(denorm(image[j:j + 1]).data.cpu()) out_results.append( mask.expand(-1, 3, -1, -1)[j:j + 1].data.cpu()) out_results.append( torch.argmax(pred[j:j + 1], 1).unsqueeze(0).expand(-1, 3, -1, -1).data.cpu()) color = random.choice(self.colors) result = dye_hair(denorm(image[j:j + 1]), mask[j:j + 1], color) result = self.transform( Image.fromarray(result)).unsqueeze(0) out_results.append(denorm(result)) result = dye_hair( denorm(image[j:j + 1]), torch.argmax(pred[j:j + 1], 1).unsqueeze(0), color) result = self.transform( Image.fromarray(result)).unsqueeze(0) out_results.append(denorm(result)) results_concat = torch.cat(out_results) results_path = os.path.join( self.valid_results_dir, '{}_epoch_valid_results.jpg'.format(epoch)) save_image(results_concat, results_path, nrow=5, padding=0) print('Saved real and fake images into {}...'.format( results_path)) losses['valid_bce_loss'] /= i losses['valid_image_gradient_loss'] /= i losses['valid_loss'] /= i losses['valid_iou'] /= i log = "Eval ========================= Epoch [{}/{}]".format( epoch, self.num_epochs) for tag, value in losses.items(): log += ", {}: {:.4f}".format(tag, value) self.writer.add_scalar( tag, value, epoch * len(self.train_dataloader) + len(self.train_dataloader) + 1) print(log) def train(self): if self.resume_epochs != 0: self.restore_model(self.resume_epochs) self.resume_epochs += 1 self.image_gradient_criterion = ImageGradientLoss().to(self.device) self.bce_criterion = nn.CrossEntropyLoss().to(self.device) start_time = time.time() for epoch in range(self.resume_epochs, self.num_epochs, 1): self.train_epoch(epoch, start_time) self.valid_epoch(epoch) self.writer.close() def test(self): self.restore_model(self.resume_epochs) self.net.eval() metrics = {'iou': 0, 'f1_score': 0, 'acc': 0} for i, data in enumerate(self.valid_dataloader, 0): image = data[0].to(self.device) mask = data[2].to(self.device) with torch.no_grad(): pred = self.net(image) metrics['iou'] += iou_metric(pred, mask) metrics['f1_score'] += F1_metric(pred, mask) metrics['acc'] += acc_metric(pred, mask) if i == 0: with torch.no_grad(): out_results = [] for j in range(10): out_results.append(denorm(image[j:j + 1]).data.cpu()) out_results.append( mask.expand(-1, 3, -1, -1)[j:j + 1].data.cpu()) out_results.append( torch.argmax(pred[j:j + 1], 1).unsqueeze(0).expand(-1, 3, -1, -1).data.cpu()) color = random.choice(self.colors) result = dye_hair(denorm(image[j:j + 1]), mask[j:j + 1], color) result = self.transform( Image.fromarray(result)).unsqueeze(0) out_results.append(denorm(result)) result = dye_hair( denorm(image[j:j + 1]), torch.argmax(pred[j:j + 1], 1).unsqueeze(0), color) result = self.transform( Image.fromarray(result)).unsqueeze(0) out_results.append(denorm(result)) results_concat = torch.cat(out_results) results_path = os.path.join( self.test_results_dir, '{}_epoch_test_results.jpg'.format(self.resume_epochs)) save_image(results_concat, results_path, nrow=5, padding=0) print('Saved real and fake images into {}...'.format( results_path)) metrics['iou'] /= i metrics['f1_score'] /= i metrics['acc'] /= i log = "Average metrics, Epoch {}".format(self.resume_epochs) for tag, value in metrics.items(): log += ", {}: {:.4f}".format(tag, value) print(log)