def train_fn(train_loader, model, optimizer, device, epoch):
    total_loss = AverageMeter()
    accuracies = AverageMeter()
    model.train()
    t = tqdm(train_loader)
    for step, d in enumerate(t):
        spect = d["spect"].to(device)
        targets = d["target"].to(device)

        outputs = model(spect)
        loss = utility.loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        # xm.optimizer_step(optimizer, barrier=True)
        optimizer.step()

        acc, n_position = utility.get_position_accuracy(outputs, targets)
        total_loss.update(loss.item(), n_position)
        accuracies.update(acc, n_position)
        t.set_description(
            f"Train E:{epoch + 1} - Loss:{total_loss.avg:0.4f} - Acc:{accuracies.avg:0.4f}"
        )
    return total_loss.avg
def valid_fn(valid_loader, model, device, epoch):
    total_loss = AverageMeter()
    accuracies = AverageMeter()
    model.eval()
    t = tqdm(valid_loader)
    for step, d in enumerate(t):
        with torch.no_grad():
            spect = d["spect"].to(device)
            targets = d["target"].to(device)

            outputs = model(spect)
            loss = utility.loss_fn(outputs, targets)

            acc, n_position = utility.get_position_accuracy(outputs, targets)
            total_loss.update(loss.item(), n_position)
            accuracies.update(acc, n_position)
            t.set_description(
                f"Eval E:{epoch + 1} - Loss:{total_loss.avg:0.4f} - Acc:{accuracies.avg:0.4f}"
            )
    return total_loss.avg, accuracies.avg
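# The AverageMeter used throughout these loops is not defined in this section. Below is a minimal
# sketch, assuming the interface implied by the call sites: update(val, n=1) accumulates a running
# sum weighted by n, while .val, .avg, .sum, .count and reset() expose the statistics. This is the
# common PyTorch-example pattern, not necessarily the exact class used in this repository.
class AverageMeter(object):
    """Tracks the most recent value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # val is the metric for the current batch; n is how many samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count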
def train(self):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    self.model.train()
    end = time.time()
    for batch, (inputs, labels, _) in enumerate(tqdm(self.loader_train)):
        data_time.update(time.time() - end)
        inputs = inputs.cuda()
        labels = labels.cuda()

        r = np.random.rand(1)
        if r < self.args.prob_mix and self.args.mix_type != 'none':
            outputs, loss, labels = utility.mix_regularization(
                inputs, labels, self.model, self.loss,
                self.args.mix_type, self.args.mix_beta)
        else:
            # No mixing: compute the output directly
            outputs = self.model(inputs)
            loss = self.loss(outputs, labels.long())

        prec1, prec5 = utility.accuracy(outputs.data, labels, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if (batch + 1) % self.args.print_every == 0:
            print('-------------------------------------------------------')
            print('Epoch: [{0}][{1}/{2}]'.format(
                self.current_epoch + 1, batch + 1, len(self.loader_train)))
            print('data_time: {data_time:.3f}, batch time: {batch_time:.3f}'.format(
                data_time=data_time.val, batch_time=batch_time.val))
            print('loss: {loss:.5f}'.format(loss=losses.avg))
            print('[Training] Top-1 accuracy: {top1_acc:.2f}%, Top-5 accuracy: {top5_acc:.2f}%'.format(
                top1_acc=top1.avg, top5_acc=top5.avg))

    self.current_epoch += 1
    self.load_epoch += 1
    if self.current_epoch > self.warmup:
        self.scheduler.step()
    self.metrics['train_loss'].append(losses.avg)
    self.metrics['train_acc'].append(top1.avg)
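# utility.mix_regularization is only called above; its body is not part of this section. The sketch
# below is a hedged guess at one possible implementation (plain mixup only), matching the observed
# call signature (inputs, labels, model, criterion, mix_type, mix_beta) -> (outputs, loss, labels).
# Handling of other mix_type values (e.g. CutMix) is omitted; treat this as an assumption, not the
# repository's actual helper.
import numpy as np
import torch


def mix_regularization(inputs, labels, model, criterion, mix_type, mix_beta):
    if mix_type == 'mixup':
        # Sample the mixing coefficient and a random permutation of the batch
        lam = np.random.beta(mix_beta, mix_beta)
        index = torch.randperm(inputs.size(0), device=inputs.device)
        mixed = lam * inputs + (1 - lam) * inputs[index]
        outputs = model(mixed)
        # Loss is the convex combination of the losses against both label sets
        loss = lam * criterion(outputs, labels.long()) + \
            (1 - lam) * criterion(outputs, labels[index].long())
        return outputs, loss, labels
    # Fallback: no mixing
    outputs = model(inputs)
    return outputs, criterion(outputs, labels.long()), labels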
def train_epoch(self, epoch=None):
    tt = tqdm(range(self.data_loader.train_iterations),
              total=self.data_loader.train_iterations,
              desc="epoch-{}-".format(epoch))
    loss_per_epoch = AverageMeter()
    for cur_it in tt:
        # One train step on the current batch
        loss = self.train_step()
        # Update metrics returned from the train_step func
        loss_per_epoch.update(loss)

    self.sess.run(self.model.global_epoch_inc)
    self.model.save(self.sess, self.config.checkpoint_dir)

    print("""
Epoch-{} loss:{:.4f}
        """.format(epoch, loss_per_epoch.val))
    tt.close()
def test(self, epoch):
    # Initialize tqdm
    tt = tqdm(range(self.data_loader.test_iterations),
              total=self.data_loader.test_iterations,
              desc="Val-{}-".format(epoch))
    loss_per_epoch = AverageMeter()
    # Iterate over batches
    for cur_it in tt:
        # One evaluation step on the current batch
        feed_dict = {}
        for step in range(self.model.time_steps):
            input, label = self.data_loader.next_batch()
            feed_dict[self.model.train_inputs[step]] = input.reshape(-1, 1)
            feed_dict[self.model.train_labels[step]] = label.reshape(-1, 1)
        feed_dict.update({self.is_training: False})

        loss = self.sess.run([self.loss_node], feed_dict=feed_dict)
        loss = loss[0]
        # Update metrics returned from the evaluation step
        loss_per_epoch.update(loss)
        # Summarize
        # summaries_dict = {'test/loss_per_epoch': loss_per_epoch.val,
        #                   'test/acc_per_epoch': acc_per_epoch.val}
        # self.summarizer.summarize(self.model.global_step_tensor.eval(self.sess), summaries_dict)

    print("""
Val-{} loss:{:.4f}
        """.format(epoch, loss_per_epoch.val))
    tt.close()
def test(args, logger, device_ids):
    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    ckpt = torch.load("./ckpts/ckpt_best_alpha.tar")
    model.load_state_dict(ckpt["state_dict"])
    if args.cuda:
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        # model = model.cuda(device=device_ids[0])
    else:
        device = torch.device("cpu")
    model = model.to(device)

    torch.set_grad_enabled(False)
    model.eval()

    test_names = gen_test_names()
    with open(os.path.join(args.raw_data_path, "Combined_Dataset/Test_set/test_fg_names.txt")) as f:
        fg_files = f.read().splitlines()
    with open(os.path.join(args.raw_data_path, "Combined_Dataset/Test_set/test_bg_names.txt")) as f:
        bg_files = f.read().splitlines()

    out_path = os.path.join(args.raw_data_path, "pred/")
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    logger.info("Start testing")
    avg_sad = AverageMeter()
    avg_mse = AverageMeter()
    for index, name in enumerate(test_names):
        # File names
        fcount = int(name.split('.')[0].split('_')[0])
        bcount = int(name.split('.')[0].split('_')[1])
        img_name = fg_files[fcount]
        bg_name = bg_files[bcount]
        merged_name = bg_name.split(".")[0] + "!" + img_name.split(".")[0] + "!" + str(fcount) + "!" + str(index) + ".png"
        trimap_name = img_name.split(".")[0] + "_" + str(index % 20) + ".png"

        # Read files
        merged = os.path.join(args.raw_data_path, "test/merged/", merged_name)
        alpha = os.path.join(args.raw_data_path, "test/mask/", img_name)
        trimap = os.path.join(args.raw_data_path, "Combined_Dataset/Test_set/Adobe-licensed images/trimaps/", trimap_name)
        merged = cv.imread(merged)
        # merged = cv.resize(merged, None, fx=0.75, fy=0.75)
        merged = cv.cvtColor(merged, cv.COLOR_BGR2RGB)
        trimap = cv.imread(trimap)
        # trimap = cv.resize(trimap, None, fx=0.75, fy=0.75)
        alpha = cv.imread(alpha, 0)
        # alpha = cv.resize(alpha, None, fx=0.75, fy=0.75)

        # Process the merged image
        merged = transforms.ToPILImage()(merged)
        out_merged = merged.copy()
        merged = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])(merged)
        h, w = merged.shape[1:3]
        h_crop, w_crop = h, w
        for i in range(h):
            if (h - i) % 16 == 0:
                h_crop = h - i
                break
        h_margin = int((h - h_crop) / 2)
        for i in range(w):
            if (w - i) % 16 == 0:
                w_crop = w - i
                break
        w_margin = int((w - w_crop) / 2)

        # Write the cropped gt alpha
        alpha = alpha[h_margin : h_margin + h_crop, w_margin : w_margin + w_crop]
        cv.imwrite(out_path + "{:04d}_gt_alpha.png".format(index), alpha)

        # Generate and write the cropped gt trimap
        gt_trimap = np.zeros(alpha.shape)
        gt_trimap.fill(128)
        gt_trimap[alpha <= 0] = 0
        gt_trimap[alpha >= 255] = 255
        cv.imwrite(out_path + "{:04d}_gt_trimap.png".format(index), gt_trimap)

        # Concatenate the 4-channel input and crop it so the network receives a valid size
        x = torch.zeros((1, 4, h, w), dtype=torch.float)
        x[0, 0:3, :, :] = merged
        x[0, 3, :, :] = torch.from_numpy(trimap[:, :, 0] / 255.)
        x = x[:, :, h_margin : h_margin + h_crop, w_margin : w_margin + w_crop]

        # Write the cropped input images
        out_merged = transforms.ToTensor()(out_merged)
        out_merged = out_merged[:, h_margin : h_margin + h_crop, w_margin : w_margin + w_crop]
        out_merged = transforms.ToPILImage()(out_merged)
        out_merged.save(out_path + "{:04d}_input_merged.png".format(index))
        out_trimap = transforms.ToPILImage()(x[0, 3, :, :])
        out_trimap.save(out_path + "{:04d}_input_trimap.png".format(index))

        # Test
        x = x.type(torch.FloatTensor).to(device)
        _, pred_trimap, pred_alpha, _, _ = model(x)
        cropped_trimap = x[0, 3, :, :].unsqueeze(dim=0).unsqueeze(dim=0)
        pred_alpha[cropped_trimap <= 0] = 0.0
        pred_alpha[cropped_trimap >= 1] = 1.0

        # Output the predicted images
        pred_trimap = (pred_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
        pred_trimap = transforms.ToPILImage()(pred_trimap[0, :, :, :])
        pred_trimap.save(out_path + "{:04d}_pred_trimap.png".format(index))
        out_pred_alpha = transforms.ToPILImage()(pred_alpha[0, :, :, :].cpu())
        out_pred_alpha.save(out_path + "{:04d}_pred_alpha.png".format(index))

        sad = compute_sad(pred_alpha, alpha)
        mse = compute_mse(pred_alpha, alpha, trimap)
        avg_sad.update(sad.item())
        avg_mse.update(mse.item())
        logger.info("{:04d}/{} | SAD: {:.1f} | MSE: {:.3f} | Avg SAD: {:.1f} | Avg MSE: {:.3f}".format(
            index, len(test_names), sad.item(), mse.item(), avg_sad.avg, avg_mse.avg))

    logger.info("Average SAD: {:.1f} | Average MSE: {:.3f}".format(avg_sad.avg, avg_mse.avg))
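# compute_sad and compute_mse are assumed to follow the standard Composition-1k matting metrics:
# SAD as the sum of absolute differences (commonly reported in units of 1000) and MSE averaged over
# the unknown (value 128) region of the trimap. The sketch below is an assumption about those
# helpers, written for a [1, 1, H, W] prediction tensor and uint8 ground-truth arrays of matching
# size; the repository's own implementation may scale, crop, or normalize differently.
import numpy as np
import torch


def compute_sad(pred_alpha, gt_alpha):
    pred = pred_alpha.squeeze().cpu()
    gt = torch.from_numpy(gt_alpha.astype(np.float32) / 255.0)
    return torch.abs(pred - gt).sum() / 1000.0


def compute_mse(pred_alpha, gt_alpha, trimap):
    pred = pred_alpha.squeeze().cpu()
    gt = torch.from_numpy(gt_alpha.astype(np.float32) / 255.0)
    # Evaluate only inside the unknown region of the trimap (assumed to match the prediction size)
    unknown = torch.from_numpy((trimap[:, :, 0] == 128).astype(np.float32))
    diff_sq = (pred - gt) ** 2 * unknown
    return diff_sq.sum() / (unknown.sum() + 1e-8)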
def train(args, logger, device_ids):
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0001)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        # model = convert_model(model)
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True, drop_last=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True, drop_last=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"]
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
        best_alpha_loss = ckpt["best_alpha_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')
        best_alpha_loss = float('inf')

    max_iter = 43100 * (1 - args.valid_portion / 100) / args.batch_size * args.epochs
    tensorboard_iter = cur_iter * (args.batch_size / 16)

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (_, inputs, gts) in enumerate(train_loader):
            # cur_lr, peak_lr = lr_scheduler(optimizer=optimizer, cur_iter=cur_iter, peak_lr=peak_lr, end_lr=0.000001,
            #                                decay_iters=args.decay_iters, decay_power=0.8, power=0.5)
            cur_lr = lr_scheduler(optimizer=optimizer, init_lr=args.lr, cur_iter=cur_iter, max_iter=max_iter,
                                  max_decay_times=30, decay_rate=0.9)

            # img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
            inputs = inputs.to(device)
            gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
            gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=log_sigma_t_sqr,
                                                        log_sigma_a_sqr=log_sigma_a_sqr)
            sigma_t, sigma_a = torch.exp(log_sigma_t_sqr.mean() / 2), torch.exp(log_sigma_a_sqr.mean() / 2)
            optimizer.zero_grad()
            L_overall.backward()
            clip_gradient(optimizer, 5)
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())
            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, tensorboard_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, tensorboard_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, tensorboard_iter)
writer.add_scalar("other/sigma_t", sigma_t.item(), tensorboard_iter) writer.add_scalar("other/sigma_a", sigma_a.item(), tensorboard_iter) writer.add_scalar("other/lr", cur_lr, tensorboard_iter) avg_lo.reset() avg_lt.reset() avg_la.reset() cur_iter += 1 tensorboard_iter = cur_iter * (args.batch_size / 16) # Validation logger.info("Validating after the {}th epoch".format(epoch)) avg_loss = AverageMeter() avg_l_t = AverageMeter() avg_l_a = AverageMeter() torch.cuda.empty_cache() torch.set_grad_enabled(False) model.eval() with tqdm(total=len(valid_loader)) as pbar: for index, (display_rgb, inputs, gts) in enumerate(valid_loader): inputs = inputs.to(device) # [bs, 4, 320, 320] gt_alpha = (gts[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device) # [bs, 1, 320, 320] gt_trimap = gts[:, 1, :, :].type(torch.LongTensor).to(device) # [bs, 320, 320] trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(inputs) L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax, pred_alpha=alpha_estimation, gt_trimap=gt_trimap, gt_alpha=gt_alpha, log_sigma_t_sqr=log_sigma_t_sqr, log_sigma_a_sqr=log_sigma_a_sqr) avg_loss.update(L_overall_valid.item()) avg_l_t.update(L_t_valid.item()) avg_l_a.update(L_a_valid.item()) if index == 0: input_rbg = torchvision.utils.make_grid(display_rgb, normalize=False, scale_each=True) writer.add_image('input/rbg_image', input_rbg, tensorboard_iter) input_trimap = inputs[:, 3, :, :].unsqueeze(dim=1) input_trimap = torchvision.utils.make_grid(input_trimap, normalize=False, scale_each=True) writer.add_image('input/trimap', input_trimap, tensorboard_iter) output_alpha = alpha_estimation.clone() output_alpha[t_argmax.unsqueeze(dim=1) == 0] = 0.0 output_alpha[t_argmax.unsqueeze(dim=1) == 2] = 1.0 output_alpha = torchvision.utils.make_grid(output_alpha, normalize=False, scale_each=True) writer.add_image('output/alpha', output_alpha, tensorboard_iter) trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1) trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False, scale_each=True) writer.add_image('pred/trimap_adaptation', trimap_adaption_res, tensorboard_iter) alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=False, scale_each=True) writer.add_image('pred/alpha_estimation', alpha_estimation_res, tensorboard_iter) gt_alpha = gt_alpha gt_alpha = torchvision.utils.make_grid(gt_alpha, normalize=False, scale_each=True) writer.add_image('gt/alpha', gt_alpha, tensorboard_iter) gt_trimap = (gt_trimap.type(torch.FloatTensor) / 2).unsqueeze(dim=1) gt_trimap = torchvision.utils.make_grid(gt_trimap, normalize=False, scale_each=True) writer.add_image('gt/trimap', gt_trimap, tensorboard_iter) pbar.update() logger.info("Average loss overall: {:.4e}".format(avg_loss.avg)) logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg)) logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg)) writer.add_scalar("valid_loss/L_overall", avg_loss.avg, tensorboard_iter) writer.add_scalar("valid_loss/L_t", avg_l_t.avg, tensorboard_iter) writer.add_scalar("valid_loss/L_a", avg_l_a.avg, tensorboard_iter) is_best = avg_loss.avg < best_loss best_loss = min(avg_loss.avg, best_loss) is_alpha_best = avg_l_a.avg < best_alpha_loss best_alpha_loss = min(avg_l_a.avg, best_alpha_loss) if is_best or is_alpha_best or args.save_ckpt: if not os.path.exists("ckpts"): os.makedirs("ckpts") 
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, is_alpha_best=is_alpha_best, logger=logger,
                            model=model, optimizer=optimizer, epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr,
                            best_loss=best_loss, best_alpha_loss=best_alpha_loss)

    writer.close()
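# clip_gradient(optimizer, 5), used in the training loop above, is assumed to clamp every parameter
# gradient to [-grad_clip, grad_clip] by value, the usual helper paired with optimizer.step() in this
# kind of training code. A minimal sketch under that assumption:
def clip_gradient(optimizer, grad_clip):
    """Clamp gradients of all parameters handled by the optimizer to [-grad_clip, grad_clip]."""
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)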
def train(model, optimizer, device, args, logger, multi_gpu):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, 'train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, 'valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True)

    if args.resume:
        logger.info("Start training from saved ckpt")
        ckpt = torch.load(args.ckpt_path)
        model = ckpt["model"].module
        model = model.to(device)
        optimizer = ckpt["optimizer"]
        start_epoch = ckpt["epoch"] + 1
        max_iter = ckpt["max_iter"]
        cur_iter = ckpt["cur_iter"]
        init_lr = ckpt["init_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        max_iter = 43100 * (1 - args.valid_portion) / args.batch_size * args.epochs
        cur_iter = 0
        init_lr = args.lr
        best_loss = float('inf')

    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr = poly_lr_scheduler(optimizer=optimizer, init_lr=init_lr, iter=cur_iter, max_iter=max_iter)

            img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
            gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=model.log_sigma_t_sqr,
                                                        log_sigma_a_sqr=model.log_sigma_a_sqr)
            # if multi_gpu:
            #     L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), L_overall.item(), L_t.item(), L_a.item()))
                writer.add_scalar("loss/L_overall", L_overall.item(), cur_iter)
                writer.add_scalar("loss/L_t", L_t.item(), cur_iter)
                writer.add_scalar("loss/L_a", L_a.item(), cur_iter)
                sigma_t = torch.exp(model.log_sigma_t_sqr / 2)
                sigma_a = torch.exp(model.log_sigma_a_sqr / 2)
                writer.add_scalar("sigma/sigma_t", sigma_t, cur_iter)
                writer.add_scalar("sigma/sigma_a", sigma_a, cur_iter)
                writer.add_scalar("lr", cur_lr, cur_iter)
            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
                gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption,
                                                                              pred_trimap_argmax=t_argmax,
                                                                              pred_alpha=alpha_estimation,
                                                                              gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                                              log_sigma_t_sqr=model.log_sigma_t_sqr,
                                                                              log_sigma_a_sqr=model.log_sigma_a_sqr)
                # if multi_gpu:
                #     L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = torchvision.utils.make_grid(t_argmax.type(torch.FloatTensor) / 2,
                                                                      normalize=True, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or (args.save_ckpt and epoch % 10 == 0):
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            logger.info("Checkpoint saved")
            if is_best:
                logger.info("Best checkpoint saved")
            save_checkpoint(epoch, model, optimizer, cur_iter, max_iter, init_lr, avg_loss.avg, is_best, args.ckpt_path)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
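# poly_lr_scheduler above is assumed to apply the standard polynomial decay
# lr = init_lr * (1 - iter / max_iter) ** power to every param group and return the new rate.
# The power value is an assumption (0.9 is a common default); the repository's helper may differ.
def poly_lr_scheduler(optimizer, init_lr, iter, max_iter, power=0.9):
    cur_lr = init_lr * (1 - float(iter) / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group["lr"] = cur_lr
    return cur_lr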
def train(args, logger, device_ids):
    torch.manual_seed(7)
    writer = SummaryWriter()

    logger.info("Loading network")
    model = AdaMatting(in_channel=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0)
    if args.resume != "":
        ckpt = torch.load(args.resume)
        # for key, _ in ckpt.items():
        #     print(key)
        model.load_state_dict(ckpt["state_dict"])
        optimizer.load_state_dict(ckpt["optimizer"])
    if args.cuda:
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()
        device = torch.device("cuda:{}".format(device_ids[0]))
        if len(device_ids) > 1:
            logger.info("Loading with multiple GPUs")
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        # model = model.cuda(device=device_ids[0])
    else:
        device = torch.device("cpu")
    model = model.to(device)

    logger.info("Initializing data loaders")
    train_dataset = AdaMattingDataset(args.raw_data_path, "train")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=16, pin_memory=True)
    valid_dataset = AdaMattingDataset(args.raw_data_path, "valid")
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=16, pin_memory=True)

    if args.resume != "":
        logger.info("Start training from saved ckpt")
        start_epoch = ckpt["epoch"] + 1
        cur_iter = ckpt["cur_iter"] + 1
        peak_lr = ckpt["peak_lr"]
        best_loss = ckpt["best_loss"]
    else:
        logger.info("Start training from scratch")
        start_epoch = 0
        cur_iter = 0
        peak_lr = args.lr
        best_loss = float('inf')

    avg_lo = AverageMeter()
    avg_lt = AverageMeter()
    avg_la = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        # Training
        torch.set_grad_enabled(True)
        model.train()
        for index, (img, gt) in enumerate(train_loader):
            cur_lr, peak_lr = lr_scheduler(optimizer=optimizer, cur_iter=cur_iter, peak_lr=peak_lr, end_lr=0.00001,
                                           decay_iters=args.decay_iters, decay_power=0.9, power=0.9)

            img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
            gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
            gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

            optimizer.zero_grad()
            trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
            L_overall, L_t, L_a = task_uncertainty_loss(pred_trimap=trimap_adaption, pred_trimap_argmax=t_argmax,
                                                        pred_alpha=alpha_estimation, gt_trimap=gt_trimap,
                                                        gt_alpha=gt_alpha, log_sigma_t_sqr=log_sigma_t_sqr,
                                                        log_sigma_a_sqr=log_sigma_a_sqr)
            L_overall, L_t, L_a = L_overall.mean(), L_t.mean(), L_a.mean()
            sigma_t, sigma_a = log_sigma_t_sqr.mean(), log_sigma_a_sqr.mean()
            optimizer.zero_grad()
            L_overall.backward()
            optimizer.step()

            avg_lo.update(L_overall.item())
            avg_lt.update(L_t.item())
            avg_la.update(L_a.item())
            if cur_iter % 10 == 0:
                logger.info("Epoch: {:03d} | Iter: {:05d}/{} | Loss: {:.4e} | L_t: {:.4e} | L_a: {:.4e}"
                            .format(epoch, index, len(train_loader), avg_lo.avg, avg_lt.avg, avg_la.avg))
                writer.add_scalar("loss/L_overall", avg_lo.avg, cur_iter)
                writer.add_scalar("loss/L_t", avg_lt.avg, cur_iter)
                writer.add_scalar("loss/L_a", avg_la.avg, cur_iter)
                sigma_t = torch.exp(sigma_t / 2)
                sigma_a = torch.exp(sigma_a / 2)
                writer.add_scalar("other/sigma_t", sigma_t.item(), cur_iter)
                writer.add_scalar("other/sigma_a", sigma_a.item(), cur_iter)
                writer.add_scalar("other/lr", cur_lr, cur_iter)
                avg_lo.reset()
                avg_lt.reset()
                avg_la.reset()
            cur_iter += 1

        # Validation
        logger.info("Validating after the {}th epoch".format(epoch))
        avg_loss = AverageMeter()
        avg_l_t = AverageMeter()
        avg_l_a = AverageMeter()
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)
        model.eval()
        with tqdm(total=len(valid_loader)) as pbar:
            for index, (img, gt) in enumerate(valid_loader):
                img = img.type(torch.FloatTensor).to(device)  # [bs, 4, 320, 320]
                gt_alpha = (gt[:, 0, :, :].unsqueeze(1)).type(torch.FloatTensor).to(device)  # [bs, 1, 320, 320]
                gt_trimap = gt[:, 1, :, :].type(torch.LongTensor).to(device)  # [bs, 320, 320]

                trimap_adaption, t_argmax, alpha_estimation, log_sigma_t_sqr, log_sigma_a_sqr = model(img)
                L_overall_valid, L_t_valid, L_a_valid = task_uncertainty_loss(pred_trimap=trimap_adaption,
                                                                              pred_trimap_argmax=t_argmax,
                                                                              pred_alpha=alpha_estimation,
                                                                              gt_trimap=gt_trimap, gt_alpha=gt_alpha,
                                                                              log_sigma_t_sqr=log_sigma_t_sqr,
                                                                              log_sigma_a_sqr=log_sigma_a_sqr)
                L_overall_valid, L_t_valid, L_a_valid = L_overall_valid.mean(), L_t_valid.mean(), L_a_valid.mean()

                avg_loss.update(L_overall_valid.item())
                avg_l_t.update(L_t_valid.item())
                avg_l_a.update(L_a_valid.item())

                if index == 0:
                    trimap_adaption_res = (t_argmax.type(torch.FloatTensor) / 2).unsqueeze(dim=1)
                    trimap_adaption_res = torchvision.utils.make_grid(trimap_adaption_res, normalize=False, scale_each=True)
                    writer.add_image('valid_image/trimap_adaptation', trimap_adaption_res, cur_iter)
                    alpha_estimation_res = torchvision.utils.make_grid(alpha_estimation, normalize=True, scale_each=True)
                    writer.add_image('valid_image/alpha_estimation', alpha_estimation_res, cur_iter)
                pbar.update()

        logger.info("Average loss overall: {:.4e}".format(avg_loss.avg))
        logger.info("Average loss of trimap adaptation: {:.4e}".format(avg_l_t.avg))
        logger.info("Average loss of alpha estimation: {:.4e}".format(avg_l_a.avg))
        writer.add_scalar("valid_loss/L_overall", avg_loss.avg, cur_iter)
        writer.add_scalar("valid_loss/L_t", avg_l_t.avg, cur_iter)
        writer.add_scalar("valid_loss/L_a", avg_l_a.avg, cur_iter)

        is_best = avg_loss.avg < best_loss
        best_loss = min(avg_loss.avg, best_loss)
        if is_best or args.save_ckpt:
            if not os.path.exists("ckpts"):
                os.makedirs("ckpts")
            save_checkpoint(ckpt_path=args.ckpt_path, is_best=is_best, logger=logger, model=model, optimizer=optimizer,
                            epoch=epoch, cur_iter=cur_iter, peak_lr=peak_lr, best_loss=best_loss)

    writer.close()
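# task_uncertainty_loss is assumed to combine the trimap cross-entropy and the alpha L1 loss with the
# learned task-uncertainty weighting of Kendall et al., the scheme the AdaMatting paper builds on:
# L = exp(-s_t) * L_t + 0.5 * exp(-s_a) * L_a + 0.5 * (s_t + s_a), with s = log(sigma^2).
# Restricting the alpha loss to the predicted unknown region (via pred_trimap_argmax) and the exact
# constants are assumptions; the repository's loss may differ in detail.
import torch
import torch.nn.functional as F


def task_uncertainty_loss(pred_trimap, pred_trimap_argmax, pred_alpha,
                          gt_trimap, gt_alpha, log_sigma_t_sqr, log_sigma_a_sqr):
    # Trimap adaptation: 3-way cross-entropy per pixel
    L_t = F.cross_entropy(pred_trimap, gt_trimap)
    # Alpha estimation: L1 restricted to the predicted unknown (class 1) region
    unknown = (pred_trimap_argmax == 1).unsqueeze(1).float()
    L_a = (torch.abs(pred_alpha - gt_alpha) * unknown).sum() / (unknown.sum() + 1e-8)
    # Uncertainty-weighted combination (log-variance parameterization)
    s_t, s_a = log_sigma_t_sqr.mean(), log_sigma_a_sqr.mean()
    L_overall = torch.exp(-s_t) * L_t + 0.5 * torch.exp(-s_a) * L_a + 0.5 * (s_t + s_a)
    return L_overall, L_t, L_a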
def test(self):
    if self.current_epoch % self.args.test_every == 0:
        batch_time = AverageMeter()
        data_time = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        self.model.eval()
        end = time.time()
        with torch.no_grad():
            for batch, (inputs, labels, filename) in enumerate(tqdm(self.loader_test)):
                data_time.update(time.time() - end)
                _, _, len_of_frame, height, width = inputs.size()
                spatial_stride = (width - self.args.crop_size) // 2
                stride = len_of_frame / 10

                if len_of_frame <= self.args.clip_len:
                    avail_number = 0
                    new_len = len_of_frame
                    new_stride = 0  # only one temporal view; keeps the index update below well-defined
                else:
                    last_subclip_end_idx = round(min(len_of_frame, len_of_frame - (stride / 2) + (self.args.clip_len / 2)))
                    last_subclip_begin_idx = last_subclip_end_idx - self.args.clip_len
                    avail_number = min(last_subclip_begin_idx, 9)
                    new_stride = last_subclip_begin_idx / float(avail_number)
                    new_len = self.args.clip_len

                # Per-view test
                begin_idx = 0
                for t in range(avail_number + 1):
                    end_idx = begin_idx + new_len
                    sub_inputs_t = inputs[:, :, begin_idx:end_idx, :, :]
                    if self.args.test_view == 30:
                        begin_spatial_idx = 0
                        for st in range(3):
                            end_spatial_idx = begin_spatial_idx + self.args.crop_size
                            sub_inputs_st = sub_inputs_t[:, :, :, :, begin_spatial_idx:end_spatial_idx]
                            begin_spatial_idx = begin_spatial_idx + spatial_stride
                            sub_inputs_st = sub_inputs_st.cuda()
                            if t == 0 and st == 0:
                                outputs = torch.nn.Softmax(dim=1)(self.model(sub_inputs_st))
                            else:
                                outputs = outputs + torch.nn.Softmax(dim=1)(self.model(sub_inputs_st))
                    else:
                        sub_inputs_t = sub_inputs_t.cuda()
                        if t == 0:
                            outputs = torch.nn.Softmax(dim=1)(self.model(sub_inputs_t))
                        else:
                            outputs = outputs + torch.nn.Softmax(dim=1)(self.model(sub_inputs_t))
                    # Index update
                    begin_idx = round(begin_idx + new_stride)

                if self.args.test_view == 10:
                    outputs = outputs / (avail_number + 1)
                else:
                    outputs = outputs / (3 * (avail_number + 1))

                labels = labels.cuda()
                if self.final_test:
                    # Write the prediction into a text file here
                    final_array = utility.inference(outputs.data)
                    # Write [filename final_array] and a newline
                    self.logfile.write(filename[0][-22:])
                    for tops in range(5):
                        data_msg = ' {0}'.format(final_array[tops])
                        self.logfile.write(data_msg)
                    self.logfile.write('\n')
                else:
                    # Measure accuracy and record loss
                    prec1, prec5 = utility.accuracy(outputs.data, labels, topk=(1, 5))
                    top1.update(prec1.item(), inputs.size(0))
                    top5.update(prec5.item(), inputs.size(0))

                batch_time.update(time.time() - end)
                end = time.time()

        if self.args.is_validate:
            print('----Validation Results Summary----')
            print('Epoch: [{0}]'.format(self.current_epoch))
            print('----------------------------- Top-1 accuracy: {top1_acc:.2f}%'.format(top1_acc=top1.avg))
            print('----------------------------- Top-5 accuracy: {top5_acc:.2f}%'.format(top5_acc=top5.avg))

        # Save the model per epoch
        if not self.args.test_only:
            if self.current_epoch % self.args.save_every == 0:
                torch.save({'epoch': self.current_epoch,
                            'state_dict': self.model.state_dict(),
                            'optimizer_state_dict': self.optimizer.state_dict(),
                            'scheduler_state_dict': self.scheduler.state_dict()},
                           self.ckpt_dir + '/model_epoch' + str(self.current_epoch).zfill(3) + '.pth')

        self.metrics['val_acc'].append(top1.avg)
        self.metrics['val_acc_top5'].append(top5.avg)
    else:
        self.metrics['val_acc'].append(0.)
        self.metrics['val_acc_top5'].append(0.)
    # Write logs
    if not self.args.test_only:
        with open(self.args.out_dir + '/log_epoch.csv', 'a') as epoch_log:
            if not self.args.load:
                epoch_log.write('{}, {:.5f}, {:.5f}, {:.5f}, {:.5f}\n'.format(
                    self.current_epoch,
                    self.metrics['train_loss'][self.current_epoch - 1],
                    self.metrics['train_acc'][self.current_epoch - 1],
                    self.metrics['val_acc'][self.current_epoch - 1],
                    self.metrics['val_acc_top5'][self.current_epoch - 1]))
                plot_learning_curves(self.metrics, self.current_epoch, self.args)
            else:
                epoch_log.write('{}, {:.5f}, {:.5f}, {:.5f}, {:.5f}\n'.format(
                    self.current_epoch,
                    self.metrics['train_loss'][self.load_epoch - 1],
                    self.metrics['train_acc'][self.load_epoch - 1],
                    self.metrics['val_acc'][self.load_epoch - 1],
                    self.metrics['val_acc_top5'][self.load_epoch - 1]))
                plot_learning_curves(self.metrics, self.load_epoch, self.args)