def vismodel():
    import sys
    import torch
    import tensorwatch as tw
    from networks import ResnetConditionHR

    netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
    tw.draw_model(netM, [1, 3, 512, 512])
def setup(opts):
    # initialize network
    netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
    netM = nn.DataParallel(netM)
    checkpoint_path = opts['checkpoint']
    netM.load_state_dict(torch.load(checkpoint_path))
    netM.cuda()
    netM.eval()
    cudnn.benchmark = True
    return netM
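# A hypothetical usage sketch for setup(); the checkpoint path and the 512x512
# input shapes below are assumptions for illustration, not files shipped with
# this code. The model returns an (alpha, foreground) pair, as used elsewhere
# in this repo.
if __name__ == '__main__':
    netM = setup({'checkpoint': 'Models/real-fixed-cam/netG_epoch_12.pth'})
    with torch.no_grad():
        image = torch.randn(1, 3, 512, 512).cuda()     # input frame
        bg = torch.randn(1, 3, 512, 512).cuda()        # captured background
        seg = torch.randn(1, 1, 512, 512).cuda()       # segmentation mask
        multi_fr = torch.randn(1, 4, 512, 512).cuda()  # stacked motion frames
        alpha_pred, fg_pred = netM(image, bg, seg, multi_fr)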
# Original Data
traindata = VideoData(
    csv_file='Video_data_train.csv',
    data_config=data_config_train,
    transform=None
)  # Write a dataloader function that can read the database provided by the .csv file
train_loader = torch.utils.data.DataLoader(traindata,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.batch_size,
                                           collate_fn=collate_filter_none)

print('\n[Phase 2] : Initialization')
netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                         output_nc=4,
                         n_blocks1=args.n_blocks1,
                         n_blocks2=args.n_blocks2)
netB = nn.DataParallel(netB)
netB.load_state_dict(torch.load(args.init_model))
netB.cuda()
netB.eval()
for param in netB.parameters():  # freeze netB
    param.requires_grad = False

netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                         output_nc=4,
                         n_blocks1=args.n_blocks1,
                         n_blocks2=args.n_blocks2)
netG.apply(conv_init)
netG = nn.DataParallel(netG)
netG.cuda()
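# A minimal sketch of the collate_filter_none helper assumed above (its
# definition is not part of this excerpt; the benchmark snippets later use a
# _collate_filter_none variant, presumably the same idea): drop samples the
# dataset returned as None, then fall back to PyTorch's default collation.
from torch.utils.data.dataloader import default_collate

def collate_filter_none(batch):
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)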
# input data path
data_path = args.input_dir

# target background path
back_img10 = cv2.imread(args.target_back)
back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)

# Green-screen background
back_img20 = np.zeros(back_img10.shape)
back_img20[..., 0] = 120
back_img20[..., 1] = 255
back_img20[..., 2] = 155

# initialize network
fo = glob.glob(model_main_dir + 'netG_epoch_*')
model_name1 = fo[0]
netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
netM = nn.DataParallel(netM)
netM.load_state_dict(torch.load(model_name1))
netM.cuda()
netM.eval()
cudnn.benchmark = True

reso = (512, 512)  # input resolution to the network

# Create a list of test images
test_imgs = [
    f for f in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')
]
test_imgs.sort()
def __init__(self, device=None, jit=True):
    self.device = device
    self.jit = jit
    self.opt = Namespace(**{
        'n_blocks1': 7,
        'n_blocks2': 3,
        'batch_size': 1,
        'resolution': 512,
        'name': 'Real_fixed'
    })
    scriptdir = os.path.dirname(os.path.realpath(__file__))
    csv_file = "Video_data_train_processed.csv"
    with open("Video_data_train.csv", "r") as r:
        with open(csv_file, "w") as w:
            w.write(r.read().format(scriptdir=scriptdir))
    data_config_train = {'reso': (self.opt.resolution, self.opt.resolution)}
    traindata = VideoData(csv_file=csv_file,
                          data_config=data_config_train,
                          transform=None)
    self.train_loader = torch.utils.data.DataLoader(
        traindata,
        batch_size=self.opt.batch_size,
        shuffle=True,
        num_workers=self.opt.batch_size,
        collate_fn=_collate_filter_none)

    netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=self.opt.n_blocks1,
                             n_blocks2=self.opt.n_blocks2)
    if self.device == 'cuda':
        netB.cuda()
    netB.eval()
    for param in netB.parameters():  # freeze netB
        param.requires_grad = False
    self.netB = netB

    netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=self.opt.n_blocks1,
                             n_blocks2=self.opt.n_blocks2)
    netG.apply(conv_init)
    self.netG = netG
    if self.device == 'cuda':
        self.netG.cuda()

    # TODO(asuhan): is this needed?
    torch.backends.cudnn.benchmark = True

    netD = MultiscaleDiscriminator(input_nc=3,
                                   num_D=1,
                                   norm_layer=nn.InstanceNorm2d,
                                   ndf=64)
    netD.apply(conv_init)
    netD = nn.DataParallel(netD)
    self.netD = netD
    if self.device == 'cuda':
        self.netD.cuda()

    self.l1_loss = alpha_loss()
    self.c_loss = compose_loss()
    self.g_loss = alpha_gradient_loss()
    self.GAN_loss = GANloss()
    self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
    self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)
    self.log_writer = SummaryWriter(scriptdir)
    self.model_dir = scriptdir
    self._maybe_trace()
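# A hedged sketch of what _maybe_trace might do (its body is not part of this
# excerpt): trace the generator with example inputs when jit is requested. The
# example shapes mirror input_nc=(3, 3, 1, 4) at 512x512 and are assumptions.
def _maybe_trace(self):
    if self.jit:
        example = tuple(torch.randn(1, c, 512, 512) for c in (3, 3, 1, 4))
        if self.device == 'cuda':
            example = tuple(t.cuda() for t in example)
        self.netG = torch.jit.trace(self.netG, example)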
else:
    traindata = RealDataWoMotion(
        csv_file='tool/rgbd.csv',
        data_config=data_config_train,
        transform=None
    )  # Write a dataloader function that can read the database provided by the .csv file

train_loader = torch.utils.data.DataLoader(traindata,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.batch_size,
                                           collate_fn=collate_filter_none)

print('\n[Phase 2] : Initialization')
netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                         output_nc=4,
                         n_blocks1=args.n_blocks1_t,
                         n_blocks2=args.n_blocks2_t)
netB = nn.DataParallel(netB)
netB.load_state_dict(torch.load(args.init_model))
netB.cuda()
netB.eval()
for param in netB.parameters():  # freeze netB
    param.requires_grad = False

netG = ResnetConditionHR_mo(input_nc=(3, 3, 1, 4),
                            output_nc=4,
                            n_blocks1=args.n_blocks1_s,
                            n_blocks2=args.n_blocks2_s)
netG.apply(conv_init)
netG = nn.DataParallel(netG)
netG.cuda()
def inference(
    output_dir,
    input_dir,
    sharpen=False,
    mask_ops="erode,3,10;dilate,5,5;blur,31,0",
    video=True,
    target_back=None,
    back=None,
    trained_model="real-fixed-cam",
    mask_suffix="_masksDL",
    outputs=["out"],
    output_suffix="",
):
    # input model
    model_main_dir = "Models/" + trained_model + "/"
    # input data path
    data_path = input_dir

    alpha_output = "out" in outputs
    matte_output = "matte" in outputs
    fg_output = "fg" in outputs
    compose_output = "compose" in outputs

    # initialize network
    fo = glob.glob(model_main_dir + "netG_epoch_*")
    model_name1 = fo[0]
    netM = ResnetConditionHR(
        input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3
    )
    netM = nn.DataParallel(netM)
    netM.load_state_dict(torch.load(model_name1))
    netM.cuda()
    netM.eval()
    cudnn.benchmark = True

    reso = (512, 512)  # input resolution to the network

    # load captured background for video mode, fixed camera
    if back is not None:
        bg_im0 = cv2.imread(back)
        bg_im0 = cv2.cvtColor(bg_im0, cv2.COLOR_BGR2RGB)
        if sharpen:
            bg_im0 = sharpen_image(bg_im0)

    # Create a list of test images
    test_imgs = [
        f
        for f in os.listdir(data_path)
        if os.path.isfile(os.path.join(data_path, f)) and f.endswith("_img.png")
    ]
    test_imgs.sort()

    # output directory
    result_path = output_dir
    if not os.path.exists(result_path):
        os.makedirs(result_path)

    # mask preprocessing ops: "op,kernel_size,iterations" triples separated by ";"
    ops = []
    if mask_ops:
        ops_list = mask_ops.split(";")
        for i, op_st in enumerate(ops_list):
            op_list = op_st.split(",")
            op = op_list[0]
            ks = int(op_list[1])
            it = int(op_list[2])
            if op != "blur":
                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ks, ks))
            else:
                kernel = (ks, ks)
            ops.append((op, kernel, it))

    for i in tqdm(range(0, len(test_imgs))):
        filename = test_imgs[i]

        # original image
        bgr_img = cv2.imread(os.path.join(data_path, filename))
        bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

        if back is None:
            # captured background image
            bg_im0 = cv2.imread(
                os.path.join(data_path, filename.replace("_img", "_back"))
            )
            bg_im0 = cv2.cvtColor(bg_im0, cv2.COLOR_BGR2RGB)

        # segmentation mask
        rcnn = cv2.imread(
            os.path.join(data_path, filename.replace("_img", mask_suffix)), 0
        )

        if video:  # if video mode, load target background frames
            # target background path
            if compose_output:
                back_img10 = cv2.imread(
                    os.path.join(target_back, filename.replace("_img.png", ".png"))
                )
                back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)
            # Green-screen background
            back_img20 = np.zeros(bgr_img.shape)
            back_img20[..., 0] = 120
            back_img20[..., 1] = 255
            back_img20[..., 2] = 155

            # create multiple frames with adjoining frames
            gap = 20
            multi_fr_w = np.zeros((bgr_img.shape[0], bgr_img.shape[1], 4))
            idx = [i - 2 * gap, i - gap, i + gap, i + 2 * gap]
            for t in range(0, 4):
                if idx[t] < 0:
                    idx[t] = len(test_imgs) + idx[t]
                elif idx[t] >= len(test_imgs):
                    idx[t] = idx[t] - len(test_imgs)
                file_tmp = test_imgs[idx[t]]
                bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp))
                multi_fr_w[..., t] = cv2.cvtColor(bgr_img_mul, cv2.COLOR_BGR2GRAY)
        else:
            if i == 0:
                if compose_output:
                    # target background path
                    back_img10 = cv2.imread(target_back)
                    back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)
                # Green-screen background
                back_img20 = np.zeros(bgr_img.shape)
                back_img20[..., 0] = 120
                back_img20[..., 1] = 255
                back_img20[..., 2] = 155

            # create the multi-frame input by repeating the current frame
            multi_fr_w = np.zeros((bgr_img.shape[0], bgr_img.shape[1], 4))
            multi_fr_w[..., 0] = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)
            multi_fr_w[..., 1] = multi_fr_w[..., 0]
            multi_fr_w[..., 2] = multi_fr_w[..., 0]
            multi_fr_w[..., 3] = multi_fr_w[..., 0]

        # crop tightly
        bgr_img0 = bgr_img
        try:
            bbox = get_bbox(rcnn, R=bgr_img0.shape[0], C=bgr_img0.shape[1])
        except ValueError:
            R0 = bgr_img0.shape[0]
            C0 = bgr_img0.shape[1]
            if compose_output:
                back_img10 = cv2.resize(back_img10, (C0, R0))
                back_img20 = cv2.resize(back_img20, (C0, R0)).astype(np.uint8)
            # There is no mask input, create empty images
            if alpha_output:
                cv2.imwrite(
                    result_path
                    + "/"
                    + filename.replace("_img", "_out" + output_suffix),
                    rcnn,
                )
            if fg_output:
                cv2.imwrite(
                    result_path
                    + "/"
                    + filename.replace("_img", "_fg" + output_suffix),
                    cv2.cvtColor(cv2.resize(rcnn, (C0, R0)), cv2.COLOR_GRAY2RGB),
                )
            if compose_output:
                cv2.imwrite(
                    result_path
                    + "/"
                    + filename.replace("_img", "_compose" + output_suffix),
                    cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB),
                )
            if matte_output:
                cv2.imwrite(
                    result_path
                    + "/"
                    + filename.replace("_img", "_matte" + output_suffix).format(i),
                    cv2.cvtColor(back_img20, cv2.COLOR_BGR2RGB),
                )
            # print("Empty: " + str(i + 1) + "/" + str(len(test_imgs)))
            continue

        crop_list = [bgr_img, bg_im0, rcnn, multi_fr_w]
        crop_list = crop_images(crop_list, reso, bbox)
        bgr_img = crop_list[0]
        bg_im = crop_list[1]
        rcnn = crop_list[2]
        multi_fr = crop_list[3]

        # sharpen original images
        if sharpen:
            bgr_img = sharpen_image(bgr_img)
            if back is None:
                bg_im = sharpen_image(bg_im)

        # process segmentation mask
        rcnn = rcnn.astype(np.float32) / 255
        rcnn[rcnn > 0.2] = 1
        K = 25

        zero_id = np.nonzero(np.sum(rcnn, axis=1) == 0)
        del_id = zero_id[0][zero_id[0] > 250]
        if len(del_id) > 0:
            del_id = [del_id[0] - 2, del_id[0] - 1, *del_id]
            rcnn = np.delete(rcnn, del_id, 0)
        rcnn = cv2.copyMakeBorder(rcnn, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE)

        for op in ops:
            if op[0] == "dilate":
                rcnn = cv2.dilate(rcnn, op[1], iterations=op[2])
            elif op[0] == "erode":
                rcnn = cv2.erode(rcnn, op[1], iterations=op[2])
            elif op[0] == "blur":
                rcnn = cv2.GaussianBlur(rcnn.astype(np.float32), op[1], op[2])

        rcnn = (255 * rcnn).astype(np.uint8)
        rcnn = np.delete(rcnn, range(reso[0], reso[0] + K), 0)

        # convert to torch tensors in [-1, 1]
        img = torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0)
        img = 2 * img.float().div(255) - 1
        bg = torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0)
        bg = 2 * bg.float().div(255) - 1
        rcnn_al = torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0)
        rcnn_al = 2 * rcnn_al.float().div(255) - 1
        multi_fr = torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0)
        multi_fr = 2 * multi_fr.float().div(255) - 1

        with torch.no_grad():
            img, bg, rcnn_al, multi_fr = (
                Variable(img.cuda()),
                Variable(bg.cuda()),
                Variable(rcnn_al.cuda()),
                Variable(multi_fr.cuda()),
            )
            input_im = torch.cat([img, bg, rcnn_al, multi_fr], dim=1)

            alpha_pred, fg_pred_tmp = netM(img, bg, rcnn_al, multi_fr)

            al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor)

            # for regions with alpha > 0.95, simply use the image as fg
            fg_pred = img * al_mask + fg_pred_tmp * (1 - al_mask)

            alpha_out = to_image(alpha_pred[0, ...])

            # refine alpha with the largest connected component
            labels = label((alpha_out > 0.05).astype(int))
            if labels.max() == 0:
                continue
            largestCC = labels == (np.argmax(np.bincount(labels.flat)[1:]) + 1)
            alpha_out = alpha_out * largestCC

            alpha_out = (255 * alpha_out[..., 0]).astype(np.uint8)

            fg_out = to_image(fg_pred[0, ...])
            fg_out = fg_out * np.expand_dims(
                (alpha_out.astype(float) / 255 > 0.01).astype(float), axis=2
            )
            fg_out = (255 * fg_out).astype(np.uint8)

            # Uncrop
            R0 = bgr_img0.shape[0]
            C0 = bgr_img0.shape[1]
            alpha_out0 = uncrop(alpha_out, bbox, R0, C0)
            fg_out0 = uncrop(fg_out, bbox, R0, C0)

        # compose
        if alpha_output:
            cv2.imwrite(
                result_path + "/" + filename.replace("_img", "_out" + output_suffix),
                alpha_out0,
            )
        if fg_output:
            cv2.imwrite(
                result_path + "/" + filename.replace("_img", "_fg" + output_suffix),
                cv2.cvtColor(fg_out0, cv2.COLOR_BGR2RGB),
            )
        if compose_output:
            back_img10 = cv2.resize(back_img10, (C0, R0))
            comp_im_tr1 = composite4(fg_out0, back_img10, alpha_out0)
            cv2.imwrite(
                result_path
                + "/"
                + filename.replace("_img", "_compose" + output_suffix),
                cv2.cvtColor(comp_im_tr1, cv2.COLOR_BGR2RGB),
            )
        if matte_output:
            back_img20 = cv2.resize(back_img20, (C0, R0))
            comp_im_tr2 = composite4(fg_out0, back_img20, alpha_out0)
            cv2.imwrite(
                result_path
                + "/"
                + filename.replace("_img", "_matte" + output_suffix).format(i),
                cv2.cvtColor(comp_im_tr2, cv2.COLOR_BGR2RGB),
            )
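# A hypothetical invocation of the inference() function above; the directory
# names are assumptions for illustration. Each input frame is expected as
# <name>_img.png with a matching <name>_back.png and <name>_masksDL.png in
# input_dir.
if __name__ == "__main__":
    inference(
        output_dir="output",
        input_dir="sample_data/input",
        video=False,
        back=None,  # with back=None, per-frame "_back" images are read instead
        mask_ops="erode,3,10;dilate,5,5;blur,31,0",
        outputs=["out", "fg"],
    )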
traindata = AdobeDataAffineHR(
    csv_file='Data_adobe/Adobe_train_data.csv',
    data_config=data_config_train,
    transform=None
)  # Write a dataloader function that can read the database provided by the .csv file
train_loader = torch.utils.data.DataLoader(traindata,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.batch_size,
                                           collate_fn=collate_filter_none)

print('\n[Phase 2] : Initialization')
net = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                        output_nc=4,
                        n_blocks1=7,
                        n_blocks2=3,
                        norm_layer=nn.BatchNorm2d)
net.apply(conv_init)
net = nn.DataParallel(net)
# net.load_state_dict(torch.load(model_dir + 'net_epoch_X'))  # uncomment this if you are initializing your model
net.cuda()
torch.backends.cudnn.benchmark = True

# Loss
l1_loss = alpha_loss()
c_loss = compose_loss()
g_loss = alpha_gradient_loss()

optimizer = optim.Adam(net.parameters(), lr=1e-4)
# optimizer.load_state_dict(torch.load(model_dir + 'optim_epoch_X'))  # uncomment this if you are initializing your model
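# A minimal sketch of the conv_init initializer applied above (its definition
# is not part of this excerpt); assumed here to be a standard per-module weight
# initializer passed to nn.Module.apply().
import torch.nn as nn

def conv_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)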
        else:
            torch_model.load_state_dict(torch.load(model_path))
        torch.onnx.export(torch_model, test_inputs, output_name + ".onnx")
        self.onnx_file = output_name + ".onnx"

    @staticmethod
    def copy_dict(state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict


if __name__ == "__main__":
    # Load the trained model from file
    model_path = 'sample.pth'
    netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
    dummy_input1 = Variable(torch.randn(1, 1, 512, 512))  # segmentation mask input
    dummy_input2 = Variable(torch.randn(1, 3, 512, 512))  # one 3 x 512 x 512 RGB image, used for both the image and background inputs
    dummy_input3 = Variable(torch.randn(1, 4, 512, 512))  # four stacked grayscale motion frames
    dummy_input = (dummy_input2, dummy_input2, dummy_input1, dummy_input3)
    toc = TorchOnnxConverter(netM, dummy_input, model_path, "sample")
    print(f"Conversion Complete, New file: {toc.onnx_file}")
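# A hedged sanity check for the exported file, assuming onnxruntime is
# installed; the graph input names depend on how torch.onnx.export named them,
# so they are queried from the session rather than hard-coded.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("sample.onnx")
feeds = {
    inp.name: np.random.randn(
        *[d if isinstance(d, int) else 1 for d in inp.shape]
    ).astype(np.float32)
    for inp in sess.get_inputs()
}
outputs = sess.run(None, feeds)
print([o.shape for o in outputs])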
def main():
    """Parses arguments and runs self-supervised adversarial training."""
    # CUDA
    # os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    # print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
    print(f'Is CUDA available: {torch.cuda.is_available()}')

    parser = argparse.ArgumentParser(
        description='Training Background Matting on Adobe Dataset')
    parser.add_argument('-n', '--name', type=str,
                        help='Name of tensorboard and model saving folders')
    parser.add_argument('-bs', '--batch_size', type=int, help='Batch Size')
    parser.add_argument('-res', '--reso', type=int, help='Input image resolution')
    parser.add_argument('-init_model', '--init_model', type=str,
                        help='Initial model file')
    parser.add_argument('-w', '--workers', type=int, default=None,
                        help='Number of workers to load data')
    parser.add_argument('-ep', '--epochs', type=int, default=15, help='Maximum Epoch')
    parser.add_argument('-n_blocks1', '--n_blocks1', type=int, default=7,
                        help='Number of residual blocks after Context Switching')
    parser.add_argument('-n_blocks2', '--n_blocks2', type=int, default=3,
                        help='Number of residual blocks for Fg and alpha each')
    args = parser.parse_args()

    if args.workers is None:
        args.workers = args.batch_size

    # Directories
    tb_dir = f'tb_summary/{args.name}'
    model_dir = f'models/{args.name}'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)

    # Input list
    data_config_train = {'reso': (args.reso, args.reso)}  # if trimap is true, rcnn is used

    # DATA LOADING
    print('\n[Phase 1] : Data Preparation')

    # Original Data
    traindata = VideoData(
        csv_file='Video_data_train.csv',
        data_config=data_config_train,
        transform=None
    )  # Write a dataloader function that can read the database provided by the .csv file
    train_loader = DataLoader(traindata,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              collate_fn=collate_filter_none)

    print('\n[Phase 2] : Initialization')

    netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=args.n_blocks1,
                             n_blocks2=args.n_blocks2)
    netB = nn.DataParallel(netB)
    netB.load_state_dict(torch.load(args.init_model))
    netB.cuda()
    netB.eval()
    for param in netB.parameters():  # freeze netB
        param.requires_grad = False

    netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=args.n_blocks1,
                             n_blocks2=args.n_blocks2)
    netG.apply(conv_init)
    netG = nn.DataParallel(netG)
    netG.cuda()
    torch.backends.cudnn.benchmark = True

    netD = MultiscaleDiscriminator(input_nc=3, num_D=1,
                                   norm_layer=nn.InstanceNorm2d, ndf=64)
    netD.apply(conv_init)
    netD = nn.DataParallel(netD)
    netD.cuda()

    # Loss
    l1_loss = alpha_loss()
    c_loss = compose_loss()
    g_loss = alpha_gradient_loss()
    GAN_loss = GANloss()

    optimizerG = Adam(netG.parameters(), lr=1e-4)
    optimizerD = Adam(netD.parameters(), lr=1e-5)

    log_writer = SummaryWriter(tb_dir)

    print('Starting Training')
    step = 50
    KK = len(train_loader)
    wt = 1
    for epoch in range(0, args.epochs):
        netG.train()
        netD.train()
        lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        t0 = get_time()

        for i, data in enumerate(train_loader):
            # Initiating
            bg = data['bg'].cuda()
            image = data['image'].cuda()
            seg = data['seg'].cuda()
            multi_fr = data['multi_fr'].cuda()
            seg_gt = data['seg-gt'].cuda()
            back_rnd = data['back-rnd'].cuda()

            mask0 = torch.ones(seg.shape).cuda()

            tr0 = get_time()

            # pseudo-supervision
            alpha_pred_sup, fg_pred_sup = netB(image, bg, seg, multi_fr)
            mask = (alpha_pred_sup > -0.98).type(torch.FloatTensor).cuda()
            mask1 = (seg_gt > 0.95).type(torch.FloatTensor).cuda()

            # Train Generator
            alpha_pred, fg_pred = netG(image, bg, seg, multi_fr)
            # pseudo-supervised losses
            al_loss = l1_loss(alpha_pred_sup, alpha_pred, mask0) + \
                0.5 * g_loss(alpha_pred_sup, alpha_pred, mask0)
            fg_loss = l1_loss(fg_pred_sup, fg_pred, mask)

            # compose into same background
            comp_loss = c_loss(image, alpha_pred, fg_pred, bg, mask1)

            # randomly permute the background
            perm = torch.LongTensor(np.random.permutation(bg.shape[0]))
            bg_sh = bg[perm, :, :, :]

            al_mask = (alpha_pred > 0.95).type(torch.FloatTensor).cuda()

            # Choose the target background for composition:
            # back_rnd: a separate set of captured background videos
            # bg_sh: randomly permuted captured backgrounds from the same minibatch
            if np.random.random_sample() > 0.5:
                bg_sh = back_rnd

            image_sh = compose_image_withshift(
                alpha_pred, image * al_mask + fg_pred * (1 - al_mask), bg_sh, seg)

            fake_response = netD(image_sh)

            loss_ganG = GAN_loss(fake_response, label_type=True)
            lossG = loss_ganG + wt * (0.05 * comp_loss + 0.05 * al_loss + 0.05 * fg_loss)

            optimizerG.zero_grad()
            lossG.backward()
            optimizerG.step()

            # Train Discriminator
            fake_response = netD(image_sh)
            real_response = netD(image)

            loss_ganD_fake = GAN_loss(fake_response, label_type=False)
            loss_ganD_real = GAN_loss(real_response, label_type=True)
            lossD = (loss_ganD_real + loss_ganD_fake) * 0.5

            # Update the discriminator once every 5 generator updates
            if i % 5 == 0:
                optimizerD.zero_grad()
                lossD.backward()
                optimizerD.step()

            lG += lossG.data
            lD += lossD.data
            GenL += loss_ganG.data
            DisL_r += loss_ganD_real.data
            DisL_f += loss_ganD_fake.data
            alL += al_loss.data
            fgL += fg_loss.data
            compL += comp_loss.data

            log_writer.add_scalar('Generator Loss', lossG.data, epoch * KK + i + 1)
            log_writer.add_scalar('Discriminator Loss', lossD.data, epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Fake', loss_ganG.data, epoch * KK + i + 1)
            log_writer.add_scalar('Discriminator Loss: Real', loss_ganD_real.data, epoch * KK + i + 1)
            log_writer.add_scalar('Discriminator Loss: Fake', loss_ganD_fake.data, epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Alpha', al_loss.data, epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Fg', fg_loss.data, epoch * KK + i + 1)
            log_writer.add_scalar('Generator Loss: Comp', comp_loss.data, epoch * KK + i + 1)

            t1 = get_time()
            elapse += t1 - t0
            elapse_run += t1 - tr0
            t0 = t1

            if i % step == (step - 1):
                print(f'[{epoch + 1}, {i + 1:5d}] '
                      f'Gen-loss: {lG / step:.4f} '
                      f'Disc-loss: {lD / step:.4f} '
                      f'Alpha-loss: {alL / step:.4f} '
                      f'Fg-loss: {fgL / step:.4f} '
                      f'Comp-loss: {compL / step:.4f} '
                      f'Time-all: {elapse / step:.4f} '
                      f'Time-fwbw: {elapse_run / step:.4f}')
                lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

                write_tb_log(image, 'image', log_writer, i)
                write_tb_log(seg, 'seg', log_writer, i)
                write_tb_log(alpha_pred_sup, 'alpha-sup', log_writer, i)
                write_tb_log(alpha_pred, 'alpha_pred', log_writer, i)
                write_tb_log(fg_pred_sup * mask, 'fg-pred-sup', log_writer, i)
                write_tb_log(fg_pred * mask, 'fg_pred', log_writer, i)

                # composition
                alpha_pred = (alpha_pred + 1) / 2
                comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
                write_tb_log(comp, 'composite-same', log_writer, i)
                write_tb_log(image_sh, 'composite-diff', log_writer, i)

                del comp

            del bg, image, seg, multi_fr, seg_gt, back_rnd
            del mask0, alpha_pred_sup, fg_pred_sup, mask, mask1
            del alpha_pred, fg_pred, al_loss, fg_loss, comp_loss
            del bg_sh, image_sh, fake_response, real_response
            del lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG

        if epoch % 2 == 0:
            ep = epoch + 1
            torch.save(netG.state_dict(), f'{model_dir}/netG_epoch_{ep}.pth')
            torch.save(optimizerG.state_dict(), f'{model_dir}/optimG_epoch_{ep}.pth')
            torch.save(netD.state_dict(), f'{model_dir}/netD_epoch_{ep}.pth')
            torch.save(optimizerD.state_dict(), f'{model_dir}/optimD_epoch_{ep}.pth')

            # Halve the weight every 2 epochs to put more stress on the
            # discriminator and less on pseudo-supervision
            wt = wt / 2
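# Minimal sketches of the get_time and write_tb_log helpers used in the
# training loops above (their definitions are not part of this excerpt);
# write_tb_log is assumed to rescale [-1, 1] tensors and log an image grid.
import time
import torchvision

def get_time():
    return time.time()

def write_tb_log(tensor, name, log_writer, i):
    grid = torchvision.utils.make_grid(
        (tensor[0:4].detach().cpu() + 1) / 2, nrow=2)  # map [-1, 1] -> [0, 1]
    log_writer.add_image(name, grid, i)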
def main():
    """Parses arguments and trains Background Matting on the Adobe dataset."""
    # CUDA
    # os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    # print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
    print(f'Is CUDA available: {torch.cuda.is_available()}')

    parser = argparse.ArgumentParser(
        description='Training Background Matting on Adobe Dataset')
    parser.add_argument('-n', '--name', type=str,
                        help='Name of tensorboard and model saving folders')
    parser.add_argument('-bs', '--batch_size', type=int, help='Batch Size')
    parser.add_argument('-res', '--reso', type=int, help='Input image resolution')
    parser.add_argument('-cont', '--continue', action='store_true',
                        help='Indicates to continue training from the latest saved model')
    parser.add_argument('-w', '--workers', type=int, default=None,
                        help='Number of workers to load data')
    parser.add_argument('-ep', '--epochs', type=int, default=60, help='Maximum Epoch')
    parser.add_argument('-n_blocks1', '--n_blocks1', type=int, default=7,
                        help='Number of residual blocks after Context Switching')
    parser.add_argument('-n_blocks2', '--n_blocks2', type=int, default=3,
                        help='Number of residual blocks for Fg and alpha each')
    args = parser.parse_args()

    if args.workers is None:
        args.workers = args.batch_size
    continue_training = getattr(args, 'continue')

    # Directories
    tb_dir = f'tb_summary/{args.name}'
    model_dir = f'models/{args.name}'
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)

    # Input list
    data_config_train = {
        'reso': [args.reso, args.reso],
        'trimapK': [5, 5],
        'noise': True
    }  # choice of data loading parameters

    # DATA LOADING
    print('\n[Phase 1] : Data Preparation')

    # Original Data
    traindata = AdobeDataAffineHR(
        csv_file='Data_adobe/Adobe_train_data.csv',
        data_config=data_config_train,
        transform=None
    )  # Write a dataloader function that can read the database provided by the .csv file
    train_loader = DataLoader(traindata,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              collate_fn=collate_filter_none)

    print('\n[Phase 2] : Initialization')

    # Find the latest saved model
    model, optim = '', ''
    start_epoch = 0
    if continue_training:
        for name in os.listdir(model_dir):
            if name.endswith('.pth') and name.startswith('net_epoch_'):
                ep = int(name[len('net_epoch_'):-4])
                if ep > start_epoch:
                    start_epoch = ep
                    model = name
        if model:
            model = f'{model_dir}/{model}'
            optim = f'{model_dir}/optim_epoch_{start_epoch}.pth'
        else:
            continue_training = False

    net = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                            output_nc=4,
                            n_blocks1=7,
                            n_blocks2=3,
                            norm_layer=nn.BatchNorm2d)
    net.apply(conv_init)
    net = nn.DataParallel(net)
    if continue_training:
        net.load_state_dict(torch.load(model))
    net.cuda()
    torch.backends.cudnn.benchmark = True

    # Loss
    l1_loss = alpha_loss()
    c_loss = compose_loss()
    g_loss = alpha_gradient_loss()

    optimizer = Adam(net.parameters(), lr=1e-4)
    if continue_training:
        optimizer.load_state_dict(torch.load(optim))

    log_writer = SummaryWriter(tb_dir)

    print('Starting Training')
    step = 50  # steps at which to visualize training images in tensorboard
    KK = len(train_loader)

    for epoch in range(start_epoch, args.epochs):
        net.train()
        netL, alL, fgL, fg_cL, al_fg_cL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0
        t0 = get_time()
        testL = 0
        ct_tst = 0

        for i, data in enumerate(train_loader):
            # Initiating
            fg = data['fg'].cuda()
            bg = data['bg'].cuda()
            alpha = data['alpha'].cuda()
            image = data['image'].cuda()
            bg_tr = data['bg_tr'].cuda()
            seg = data['seg'].cuda()
            multi_fr = data['multi_fr'].cuda()

            mask = (alpha > -0.99).type(torch.FloatTensor).cuda()
            mask0 = torch.ones(alpha.shape).cuda()
            tr0 = get_time()

            alpha_pred, fg_pred = net(image, bg_tr, seg, multi_fr)

            # Put needed losses here
            al_loss = l1_loss(alpha, alpha_pred, mask0)
            fg_loss = l1_loss(fg, fg_pred, mask)

            al_mask = (alpha_pred > 0.95).type(torch.FloatTensor).cuda()
            fg_pred_c = image * al_mask + fg_pred * (1 - al_mask)
            fg_c_loss = c_loss(image, alpha_pred, fg_pred_c, bg, mask0)

            al_fg_c_loss = g_loss(alpha, alpha_pred, mask0)

            loss = al_loss + 2 * fg_loss + fg_c_loss + al_fg_c_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            netL += loss.data
            alL += al_loss.data
            fgL += fg_loss.data
            fg_cL += fg_c_loss.data
            al_fg_cL += al_fg_c_loss.data

            log_writer.add_scalar('training_loss', loss.data, epoch * KK + i + 1)
            log_writer.add_scalar('alpha_loss', al_loss.data, epoch * KK + i + 1)
            log_writer.add_scalar('fg_loss', fg_loss.data, epoch * KK + i + 1)
            log_writer.add_scalar('comp_loss', fg_c_loss.data, epoch * KK + i + 1)
            log_writer.add_scalar('alpha_gradient_loss', al_fg_c_loss.data, epoch * KK + i + 1)

            t1 = get_time()
            elapse += t1 - t0
            elapse_run += t1 - tr0
            t0 = t1
            testL += loss.data
            ct_tst += 1

            if i % step == (step - 1):
                print(f'[{epoch + 1}, {i + 1:5d}] '
                      f'Total-loss: {netL / step:.4f} '
                      f'Alpha-loss: {alL / step:.4f} '
                      f'Fg-loss: {fgL / step:.4f} '
                      f'Comp-loss: {fg_cL / step:.4f} '
                      f'Alpha-gradient-loss: {al_fg_cL / step:.4f} '
                      f'Time-all: {elapse / step:.4f} '
                      f'Time-fwbw: {elapse_run / step:.4f}')
                netL, alL, fgL, fg_cL, al_fg_cL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0

                write_tb_log(image, 'image', log_writer, i)
                write_tb_log(seg, 'seg', log_writer, i)
                write_tb_log(alpha, 'alpha', log_writer, i)
                write_tb_log(alpha_pred, 'alpha_pred', log_writer, i)
                write_tb_log(fg * mask, 'fg', log_writer, i)
                write_tb_log(fg_pred * mask, 'fg_pred', log_writer, i)
                write_tb_log(multi_fr[0:4, 0, ...].unsqueeze(1), 'multi_fr', log_writer, i)

                # composition
                alpha_pred = (alpha_pred + 1) / 2
                comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
                write_tb_log(comp, 'composite', log_writer, i)
                del comp

            del fg, bg, alpha, image, alpha_pred, fg_pred, bg_tr, seg, multi_fr

        # Saving
        torch.save(net.state_dict(), f'{model_dir}/net_epoch_{epoch + 1}.pth')
        torch.save(optimizer.state_dict(), f'{model_dir}/optim_epoch_{epoch + 1}.pth')
def __init__(self, device=None, jit=True):
    self.device = device
    self.jit = jit
    self.opt = Namespace(**{
        'n_blocks1': 7,
        'n_blocks2': 3,
        'batch_size': 1,
        'resolution': 512,
        'name': 'Real_fixed'
    })
    data_config_train = {'reso': (self.opt.resolution, self.opt.resolution)}
    traindata = VideoData(csv_file='Video_data_train.csv',
                          data_config=data_config_train,
                          transform=None)
    self.train_loader = torch.utils.data.DataLoader(
        traindata,
        batch_size=self.opt.batch_size,
        shuffle=True,
        num_workers=self.opt.batch_size,
        collate_fn=_collate_filter_none)

    netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=self.opt.n_blocks1,
                             n_blocks2=self.opt.n_blocks2)
    if self.device == 'cuda':
        netB.cuda()
    netB.eval()
    for param in netB.parameters():  # freeze netB
        param.requires_grad = False
    self.netB = netB

    netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=self.opt.n_blocks1,
                             n_blocks2=self.opt.n_blocks2)
    netG.apply(conv_init)
    self.netG = netG
    if self.device == 'cuda':
        self.netG.cuda()

    # TODO(asuhan): is this needed?
    torch.backends.cudnn.benchmark = True

    netD = MultiscaleDiscriminator(input_nc=3,
                                   num_D=1,
                                   norm_layer=nn.InstanceNorm2d,
                                   ndf=64)
    netD.apply(conv_init)
    netD = nn.DataParallel(netD)
    self.netD = netD
    if self.device == 'cuda':
        self.netD.cuda()

    self.l1_loss = alpha_loss()
    self.c_loss = compose_loss()
    self.g_loss = alpha_gradient_loss()
    self.GAN_loss = GANloss()
    self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
    self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

    tb_dir = ('/home/circleci/project/benchmark/models/Background-Matting/'
              'TB_Summary/' + self.opt.name)
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)
    self.log_writer = SummaryWriter(tb_dir)
    self.model_dir = ('/home/circleci/project/benchmark/models/Background-Matting/'
                      'Models/' + self.opt.name)
    if not os.path.exists(self.model_dir):
        os.makedirs(self.model_dir)
    self._maybe_trace()
else:
    args.video = False
    print('Using image mode')
    # target background path
    back_img10 = cv2.imread(args.target_back)
    back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)
    # Green-screen background
    back_img20 = np.zeros(back_img10.shape)
    back_img20[..., 0] = 120
    back_img20[..., 1] = 255
    back_img20[..., 2] = 155

# initialize network
fo = glob.glob(model_main_dir + 'netG_epoch_*')
model_name1 = fo[0]
netM = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3)
netM = nn.DataParallel(netM)
netM.load_state_dict(torch.load(model_name1, map_location=torch.device('cpu')))
if cuda.is_available():
    netM.cuda()
    cudnn.benchmark = True
else:
    netM.cpu()
netM.eval()

reso = (512, 512)  # input resolution to the network

# load captured background for video mode, fixed camera
if args.back is not None:
    bg_im0 = cv2.imread(args.back)