def run(opt, model, data_loader, mode):
    """Run the try-on model over ``data_loader`` and save its outputs.

    For every batch the model output is split along dim 1 into a rendered
    person image and a composition mask. The final try-on image blends the
    input cloth with the rendered person using that mask. Try-on images are
    written to ``<opt.data_root>/<mode>/tryon-person``. A nested 3x3 list of
    visual tensors is passed to ``save_visual`` with the directory
    ``<opt.out_dir>/<opt.name>/<mode>``.

    Args:
        opt: Options object; reads ``gpu_id``, ``data_root``, ``out_dir`` and
            ``name``.
        model: Network called as ``model(torch.cat([feature, cloth], 1))``.
        data_loader: Iterable of dict batches. Values whose key contains
            ``'name'`` are kept as-is; every other value is moved to the
            selected device.
        mode: Split name used as a sub-directory of the output paths.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(opt.gpu_id))
    else:
        device = torch.device('cpu')
    model = model.to(device)

    tryon_dir = os.path.join(opt.data_root, mode, 'tryon-person')
    mkdir(tryon_dir)
    visual_dir = os.path.join(opt.out_dir, opt.name, mode)
    mkdir(visual_dir)

    data_iter = tqdm(data_loader, total=len(data_loader), bar_format='{l_bar}{r_bar}')
    # Pure inference: disable autograd so no computation graph is built and
    # kept alive for each batch (saves memory; outputs are unchanged).
    with torch.no_grad():
        for _data in data_iter:
            # Load tensors onto the device; name entries (e.g. 'cloth_name')
            # are passed through untouched.
            data = {
                key: value if 'name' in key else value.to(device)
                for key, value in _data.items()
            }

            cloth = data['cloth']
            cloth_mask = data['cloth_mask']
            person = data['person']
            cloth_name = data['cloth_name']

            outputs = model(torch.cat([data['feature'], cloth], 1))  # (batch, channel, height, width)
            # First 3 channels -> rendered person, remainder -> composition
            # mask (presumably 1 channel; TODO confirm against the model).
            rendered_person, composition_mask = torch.split(outputs, 3, 1)
            rendered_person = torch.tanh(rendered_person)        # range [-1, 1]
            composition_mask = torch.sigmoid(composition_mask)   # range [0, 1]
            # Mask selects the cloth pixels; the rest comes from the rendering.
            tryon_person = cloth * composition_mask + rendered_person * (1 - composition_mask)

            # Masks are remapped with x*2-1 so they share the [-1, 1] display
            # range of the images (assumes cloth_mask is in [0, 1]).
            visuals = [[data['head'], data['shape'], data['pose']],
                       [cloth, cloth_mask * 2 - 1, composition_mask * 2 - 1],
                       [rendered_person, tryon_person, person]]

            save_images(tryon_person, cloth_name, tryon_dir)
            save_visual(visuals, cloth_name, visual_dir)
def run(opt, model, data_loader, mode):
    """Warp cloth images with the model's predicted sampling grid and save them.

    For every batch the model returns a sampling grid that is applied with
    ``F.grid_sample`` to the cloth image and to the cloth mask. The results
    are saved under ``<opt.data_root>/<mode>/warp-cloth`` and
    ``<opt.data_root>/<mode>/warp-cloth-mask``. Unless ``mode == 'train'``, a
    nested 3x3 list of visual tensors is also passed to ``save_visual`` with
    the directory ``<opt.out_dir>/<opt.name>/<mode>``.

    Args:
        opt: Options object; reads ``gpu_id``, ``data_root``, ``out_dir`` and
            ``name``.
        model: Network called as ``model(feature, cloth)``; the first element
            of its returned pair is used as the ``grid_sample`` grid.
        data_loader: Iterable of dict batches. Values whose key contains
            ``'name'`` are kept as-is; every other value is moved to the
            selected device.
        mode: Split name used as a sub-directory of the output paths; visuals
            are skipped when it equals ``'train'``.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda:' + str(opt.gpu_id))
    else:
        device = torch.device('cpu')
    model = model.to(device)

    warp_cloth_dir = os.path.join(opt.data_root, mode, 'warp-cloth')
    warp_cloth_mask_dir = os.path.join(opt.data_root, mode, 'warp-cloth-mask')
    visual_dir = os.path.join(opt.out_dir, opt.name, mode)
    for d in (warp_cloth_dir, warp_cloth_mask_dir, visual_dir):
        mkdir(d)

    data_iter = tqdm(data_loader, total=len(data_loader), bar_format='{l_bar}{r_bar}')
    # Pure inference: disable autograd so no computation graph is built and
    # kept alive for each batch (saves memory; outputs are unchanged).
    with torch.no_grad():
        for _data in data_iter:
            # Load tensors onto the device; name entries (e.g. 'cloth_name')
            # are passed through untouched.
            data = {
                key: value if 'name' in key else value.to(device)
                for key, value in _data.items()
            }

            cloth = data['cloth']
            person = data['person']
            body_mask = data['body_mask']
            cloth_name = data['cloth_name']

            grid, _ = model(data['feature'], cloth)
            # NOTE(review): align_corners is not passed, so PyTorch's default
            # (False since 1.3.0) applies; confirm this matches how the
            # model's grid was trained.
            # Border padding for the image, zero padding for the mask so that
            # out-of-range samples fall outside the warped mask.
            warped_cloth = F.grid_sample(cloth, grid, padding_mode='border')
            warped_cloth_mask = F.grid_sample(data['cloth_mask'], grid, padding_mode='zeros')

            save_images(warped_cloth, cloth_name, warp_cloth_dir)
            # x*2-1 remaps the mask to [-1, 1] (assumes it is in [0, 1]).
            save_images(warped_cloth_mask * 2 - 1, cloth_name, warp_cloth_mask_dir)

            if mode == 'train':  # No visuals
                continue

            warped_grid = F.grid_sample(data['grid'], grid, padding_mode='zeros')
            # Paste the warped cloth into the non-body region of the person,
            # and build the matching reference image from 'cloth_parse'.
            warped_person = body_mask * person + (1 - body_mask) * warped_cloth
            gt = body_mask * person + (1 - body_mask) * data['cloth_parse']

            visuals = [[data['head'], data['shape'], data['pose']],
                       [cloth, warped_cloth, warped_grid],
                       [warped_person, gt, person]]
            save_visual(visuals, cloth_name, visual_dir)