def __init__(self, opt):
    """
    :param opt: training options (expects opt.lr)
    """
    super(cyclegan, self).__init__()
    self.Generator = UNet(22, 4)
    self.Discriminator = NLayerDiscriminator()
    self.PairDis = PairDiscriminator()
    self.criterionGAN = GANLoss("lsgan")
    self.PairGAN = GANLoss("lsgan")
    self.loss_1 = nn.L1Loss()
    self.loss_2 = nn.MSELoss()
    self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                        lr=opt.lr, betas=(0.5, 0.999))
    self._optimizer_G = torch.optim.Adam(self.Generator.parameters(),
                                         lr=opt.lr, betas=(0.5, 0.999))
    self.vgg_loss = VGGLoss()
    self.content_loss = Content_loss()
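# A minimal sketch (an assumption, not this project's actual implementation) of an
# LSGAN-style GANLoss module like the one constructed above with GANLoss("lsgan"):
# it builds a target tensor of ones/zeros matching the discriminator output and
# applies MSE (least-squares GAN) to it. The name GANLossSketch is hypothetical.
import torch
import torch.nn as nn

class GANLossSketch(nn.Module):
    def __init__(self, gan_mode="lsgan", real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))
        # "lsgan" uses MSE on raw discriminator outputs; "vanilla" would use BCEWithLogits.
        self.loss = nn.MSELoss() if gan_mode == "lsgan" else nn.BCEWithLogitsLoss()

    def forward(self, prediction, target_is_real):
        target = self.real_label if target_is_real else self.fake_label
        return self.loss(prediction, target.expand_as(prediction))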
def train_tom(opt, train_loader, model, board):
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    # NOTE: the scheduler is created here but never stepped in the loop below.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            # board.add_graph(model, torch.cat([agnostic, c], 1))
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(), loss_vgg.item(), loss_mask.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
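# A minimal sketch (an assumption; the repo's VGGLoss may differ) of a VGG19
# perceptual loss such as the criterionVGG used throughout these training loops:
# L1 distances between feature maps of a frozen, ImageNet-pretrained VGG19,
# summed over a few layers with increasing weights. Layer indices and weights
# below are illustrative.
import torch
import torch.nn as nn
from torchvision import models

class VGGPerceptualLossSketch(nn.Module):
    def __init__(self, layer_ids=(3, 8, 17, 26, 35),
                 weights=(1 / 32, 1 / 16, 1 / 8, 1 / 4, 1.0)):
        super().__init__()
        features = models.vgg19(pretrained=True).features.eval()
        for p in features.parameters():
            p.requires_grad = False  # frozen feature extractor
        self.features = features
        self.layer_ids = layer_ids
        self.weights = weights
        self.l1 = nn.L1Loss()

    def forward(self, x, y):
        loss = 0.0
        for i, layer in enumerate(self.features):
            x, y = layer(x), layer(y)
            if i in self.layer_ids:
                w = self.weights[self.layer_ids.index(i)]
                loss = loss + w * self.l1(x, y)
            if i >= max(self.layer_ids):
                break
        return loss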
def __init__(self, gen, dis, dataloader_train, dataloader_val, gpu_id,
             log_freq, save_dir, n_step, optimizer='adam'):
    if torch.cuda.is_available():
        self.device = torch.device('cuda:' + str(gpu_id))
    else:
        self.device = torch.device('cpu')
    self.gen = gen.to(self.device)
    self.dis = dis.to(self.device)
    self.dataloader_train = dataloader_train
    self.dataloader_val = dataloader_val

    if optimizer == 'adam':
        self.optim_g = torch.optim.Adam(gen.parameters(), lr=1e-4, betas=(0.5, 0.999))
        self.optim_d = torch.optim.Adam(dis.parameters(), lr=1e-4, betas=(0.5, 0.999))
    elif optimizer == 'ranger':
        self.optim_g = Ranger(gen.parameters())
        self.optim_d = Ranger(dis.parameters())

    self.criterionL1 = nn.L1Loss()
    self.criterionVGG = VGGLoss()
    self.criterionAdv = torch.nn.BCELoss()

    self.log_freq = log_freq
    self.save_dir = save_dir
    self.n_step = n_step
    self.step = 0

    print('Generator Parameters:', sum([p.nelement() for p in self.gen.parameters()]))
    print('Discriminator Parameters:', sum([p.nelement() for p in self.dis.parameters()]))
def train(opt, train_loader, G, D, board):
    human_parser = HumanParser(opt)
    human_parser.eval()
    G.train()
    D.train()
    # palette = get_palette()

    # Criterion
    criterionWarp = nn.L1Loss()
    criterionPerceptual = VGGLoss()
    criterionL1 = nn.L1Loss()
    BCE_stable = nn.BCEWithLogitsLoss()
    criterionCloth = nn.L1Loss()

    # Variables
    ya = torch.FloatTensor(opt.batch_size)
    yb = torch.FloatTensor(opt.batch_size)
    u = torch.FloatTensor(opt.batch_size, 1, 1, 1)
    grad_outputs = torch.ones(opt.batch_size)

    # Everything cuda
    if opt.cuda:
        G.cuda()
        D.cuda()
        human_parser.cuda()
        criterionWarp = criterionWarp.cuda()
        criterionPerceptual = criterionPerceptual.cuda()
        criterionL1 = criterionL1.cuda()
        BCE_stable.cuda()
        criterionCloth = criterionCloth.cuda()
        ya = ya.cuda()
        yb = yb.cuda()
        u = u.cuda()
        grad_outputs = grad_outputs.cuda()

    # DataParallel
    G = nn.DataParallel(G)
    D = nn.DataParallel(D)
    human_parser = nn.DataParallel(human_parser)

    # Optimizers
    optimizerD = torch.optim.Adam(D.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizerG = torch.optim.Adam(G.parameters(), lr=opt.lr, betas=(0.5, 0.999))

    # Fitting model
    step_start_time = time.time()
    for step in range(opt.n_iter):
        ########################
        # (1) Update D network #
        ########################
        for p in D.parameters():
            p.requires_grad = True

        for t in range(opt.Diters):
            D.zero_grad()

            inputs = train_loader.next_batch()
            pa = inputs['image'].cuda()
            ap = inputs['agnostic'].cuda()
            cb = inputs['another_cloth'].cuda()
            del inputs
            current_batch_size = pa.size(0)

            ya_pred = D(pa)
            _, pb_fake = G(cb, ap)
            # Detach pb_fake from the generator G so only D receives gradients
            yb_pred_fake = D(pb_fake.detach())

            ya.data.resize_(current_batch_size).fill_(1)
            yb.data.resize_(current_batch_size).fill_(0)

            # Relativistic average discriminator loss
            errD = (BCE_stable(ya_pred - torch.mean(yb_pred_fake), ya) +
                    BCE_stable(yb_pred_fake - torch.mean(ya_pred), yb)) / 2.0
            errD.backward()

            # Gradient penalty
            with torch.no_grad():
                u.resize_(current_batch_size, 1, 1, 1).uniform_(0, 1)
                grad_outputs.data.resize_(current_batch_size)
                x_both = pa * u + pb_fake * (1. - u)
            # We only want the gradients with respect to x_both
            x_both = Variable(x_both, requires_grad=True)
            grad = torch.autograd.grad(outputs=D(x_both),
                                       inputs=x_both,
                                       grad_outputs=grad_outputs,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            # We need to norm 3 times (over n_colors x image_size x image_size)
            # to get only a vector of size "batch_size"
            grad_penalty = opt.penalty * ((grad.norm(2, 1).norm(2, 1).norm(2, 1) - 1) ** 2).mean()
            grad_penalty.backward()
            optimizerD.step()

        ########################
        # (2) Update G network #
        ########################
        for p in D.parameters():
            p.requires_grad = False

        for t in range(opt.Giters):
            inputs = train_loader.next_batch()
            pa = inputs['image'].cuda()
            ap = inputs['agnostic'].cuda()
            ca = inputs['cloth'].cuda()
            cb = inputs['another_cloth'].cuda()
            parse_cloth = inputs['parse_cloth'].cuda()
            del inputs
            current_batch_size = pa.size(0)

            # paired data
            G.zero_grad()
            warped_cloth_a, pa_fake = G(ca, ap)
            if step >= opt.human_parser_step:
                # Only add the human parser once generated images look realistic enough
                parse_pa_fake = human_parser(pa_fake)  # (N,H,W)
                parse_ca_fake = (parse_pa_fake == 5) + \
                                (parse_pa_fake == 6) + \
                                (parse_pa_fake == 7)  # [0,1] (N,H,W)
                parse_ca_fake = parse_ca_fake.unsqueeze(1).type_as(pa_fake)  # (N,1,H,W)
                ca_fake = pa_fake * parse_ca_fake + (1 - parse_ca_fake)  # [-1,1]
                with torch.no_grad():
                    parse_pa_fake_vis = visualize_seg(parse_pa_fake)
                l_cloth_p = criterionCloth(ca_fake, warped_cloth_a)
            else:
                with torch.no_grad():
                    ca_fake = torch.zeros_like(pa_fake)
                    parse_pa_fake_vis = torch.zeros_like(pa_fake)
                    l_cloth_p = torch.zeros(1).cuda()

            l_warp = 20 * criterionWarp(warped_cloth_a, parse_cloth)
            l_perceptual = criterionPerceptual(pa_fake, pa)
            l_L1 = criterionL1(pa_fake, pa)
            loss_p = l_warp + l_perceptual + l_L1 + l_cloth_p
            loss_p.backward()
            optimizerG.step()

            # unpaired data
            G.zero_grad()
            warped_cloth_b, pb_fake = G(cb, ap)
            if step >= opt.human_parser_step:
                # Only add the human parser once generated images look realistic enough
                parse_pb_fake = human_parser(pb_fake)
                parse_cb_fake = (parse_pb_fake == 5) + \
                                (parse_pb_fake == 6) + \
                                (parse_pb_fake == 7)  # [0,1] (N,H,W)
                parse_cb_fake = parse_cb_fake.unsqueeze(1).type_as(pb_fake)  # (N,1,H,W)
                cb_fake = pb_fake * parse_cb_fake + (1 - parse_cb_fake)  # [-1,1]
                with torch.no_grad():
                    parse_pb_fake_vis = visualize_seg(parse_pb_fake)
                l_cloth_up = criterionCloth(cb_fake, warped_cloth_b)
            else:
                with torch.no_grad():
                    cb_fake = torch.zeros_like(pb_fake)
                    parse_pb_fake_vis = torch.zeros_like(pb_fake)
                    l_cloth_up = torch.zeros(1).cuda()

            with torch.no_grad():
                ya.data.resize_(current_batch_size).fill_(1)
                yb.data.resize_(current_batch_size).fill_(0)
            ya_pred = D(pa)
            yb_pred_fake = D(pb_fake)
            # Non-saturating relativistic generator loss (labels swapped)
            l_adv = 0.1 * (BCE_stable(ya_pred - torch.mean(yb_pred_fake), yb) +
                           BCE_stable(yb_pred_fake - torch.mean(ya_pred), ya)) / 2
            loss_up = l_adv + l_cloth_up
            loss_up.backward()
            optimizerG.step()

        # visuals = [
        #     [cb, warped_cloth_b, pb_fake],
        #     [ca, warped_cloth_a, pa_fake],
        #     [ap, parse_cloth, pa]
        # ]
        visuals = [[cb, warped_cloth_b, pb_fake, cb_fake, parse_pb_fake_vis],
                   [ca, warped_cloth_a, pa_fake, ca_fake, parse_pa_fake_vis],
                   [ap, parse_cloth, pa]]

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('loss_p', loss_p.item(), step + 1)
            board.add_scalar('l_warp', l_warp.item(), step + 1)
            board.add_scalar('l_perceptual', l_perceptual.item(), step + 1)
            board.add_scalar('l_L1', l_L1.item(), step + 1)
            board.add_scalar('l_cloth_p', l_cloth_p.item(), step + 1)
            board.add_scalar('loss_up', loss_up.item(), step + 1)
            board.add_scalar('l_adv', l_adv.item(), step + 1)
            board.add_scalar('l_cloth_up', l_cloth_up.item(), step + 1)
            board.add_scalar('errD', errD.item(), step + 1)
            t = time.time() - step_start_time
            print('step: %8d, time: %.3f, loss_p: %.4f, loss_up: %.4f, l_adv: %.4f, errD: %.4f'
                  % (step + 1, t, loss_p.item(), loss_up.item(), l_adv.item(), errD.item()),
                  flush=True)
            step_start_time = time.time()

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(G, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
def train_tom(opt, train_loader, model, d_g, d_l, board):
    model.cuda()
    model.train()
    d_g.cuda()
    d_g.train()
    d_l.cuda()
    d_l.train()

    # reverse label
    dis_label_G = Variable(torch.FloatTensor(opt.batch_size, 1)).fill_(0.).cuda()
    dis_label_real = Variable(torch.FloatTensor(opt.batch_size, 1)).fill_(0.).cuda()
    dis_label_fake = Variable(torch.FloatTensor(opt.batch_size, 1)).fill_(1.).cuda()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()
    criterionGAN = nn.BCELoss()  # MSE

    # optimizer
    optimizerG = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizerDG = torch.optim.Adam(d_g.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    optimizerDL = torch.optim.Adam(d_l.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    schedulerG = torch.optim.lr_scheduler.LambdaLR(
        optimizerG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))
    schedulerDG = torch.optim.lr_scheduler.LambdaLR(
        optimizerDG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))
    schedulerDL = torch.optim.lr_scheduler.LambdaLR(
        optimizerDL,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()

        # dis_label_noise
        dis_label_noise = random.random() / 10
        dis_label_real = dis_label_real.data.fill_(0.0 + random.random() * opt.noise)
        dis_label_fake = dis_label_fake.data.fill_(1.0 - random.random() * opt.noise)

        # prep
        inputs = train_loader.next_batch()
        im = inputs['image'].cuda()  # size = b * 3 * 256 * 192
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        batch_size = im.size(0)
        if batch_size != opt.batch_size:
            continue

        # D_real
        errDg_real = criterionGAN(d_g(torch.cat([agnostic, c, im], 1)), dis_label_real)

        # generate image
        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        real_crop, fake_crop = random_crop(im, p_tryon, opt.winsize)
        errDl_real = criterionGAN(d_l(real_crop), dis_label_real)

        # tom_train
        errGg_fake = criterionGAN(d_g(torch.cat([agnostic, c, p_tryon], 1)), dis_label_G)
        errGl_fake = criterionGAN(d_l(fake_crop), dis_label_G)
        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss_GAN = (errGg_fake + errGl_fake * opt.alpha) / batch_size
        loss = loss_l1 + loss_vgg + loss_mask + loss_GAN

        # D_fake
        errDg_fake = criterionGAN(d_g(torch.cat([agnostic, c, p_tryon], 1).detach()), dis_label_fake)
        loss_Dg = (errDg_fake + errDg_real) / 2
        errDl_fake = criterionGAN(d_l(fake_crop.detach()), dis_label_fake)
        loss_Dl = (errDl_fake + errDl_real) / 2

        optimizerG.zero_grad()
        loss.backward()
        optimizerG.step()

        optimizerDL.zero_grad()
        loss_Dl.backward()
        optimizerDL.step()

        optimizerDG.zero_grad()
        loss_Dg.backward()
        optimizerDG.step()

        # tensorboardX
        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time
            loss_dict = {
                "TOT": loss.item(),
                "L1": loss_l1.item(),
                "VG": loss_vgg.item(),
                "Mk": loss_mask.item(),
                "G": loss_GAN.item(),
                "DG": loss_Dg.item(),
                "DL": loss_Dl.item()
            }
            print('step: %d|time: %.3f' % (step + 1, t), end="")
            sm_image(combine_images(im, p_tryon, real_crop, fake_crop),
                     "combined%d.jpg" % step, opt.debug)
            board_add_images(board, 'combine', visuals, step + 1)
            for k, v in loss_dict.items():
                print('|%s: %.3f' % (k, v), end="")
                board.add_scalar(k, v, step + 1)
            print()

        if (step + 1) % opt.save_count == 0:
            save_checkpoints(model, d_g, d_l,
                             os.path.join(opt.checkpoint_dir, opt.stage + '_' + opt.name,
                                          "step%06d" % step, '%s.pth'))
def train_refined_gmm(opt, train_loader, model, board):
    model.cuda()
    model.train()
    loss_weight = opt.loss_weight
    # if loss_weight > 0.01:
    #     print("Error")
    #     assert False

    # criterion
    warped_criterionL1 = nn.L1Loss()
    result_criterionL1 = nn.L1Loss()
    point_criterionL1 = nn.L1Loss()
    criterionMask = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionGram = GramLoss()
    rendered_criterionL1 = nn.L1Loss()
    center_mask_criterionL1 = nn.L1Loss()
    warped_mask_criterionL1 = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        densepose_shape = inputs['densepose_shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        parse_cloth_mask = inputs['parse_cloth_mask'].cuda()
        target_shape = inputs['target_shape']
        c_point_plane = inputs['cloth_points'].cuda()
        p_point_plane = inputs['person_points'].cuda()

        grid, theta, warped_cloth, outputs = model(agnostic, c)
        # warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        compute_c_point_plane = compute_grid_point(p_point_plane, grid)

        warped_mask_loss = 0
        if opt.add_warped_mask_loss:
            warped_mask_loss += warped_mask_criterionL1(warped_mask, target_shape)

        c_rendered, m_composite = torch.split(outputs, 3, 1)
        c_rendered = torch.tanh(c_rendered)
        m_composite = torch.sigmoid(m_composite)
        c_result = warped_cloth * m_composite + c_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im],
                   [m_composite, (c_result + im) * 0.5, c_result]]

        loss_warped_cloth = warped_criterionL1(warped_cloth, im_c)
        loss_point = 0
        if opt.add_point_loss:
            loss_point = point_criterionL1(compute_c_point_plane, c_point_plane)
        loss_c_result = result_criterionL1(c_result, im_c)
        loss_mask = criterionMask(m_composite, warped_mask)
        loss_vgg = 0
        if opt.add_vgg_loss:
            loss_vgg = criterionVGG(c_result, im_c)
        loss_gram = 0
        if opt.add_gram_loss:
            loss_gram += criterionGram(c_result, im_c)
        loss_render = 0
        if opt.add_render_loss:
            loss_render += rendered_criterionL1(c_rendered, im_c)
        loss_mask_constrain = 0
        if opt.add_mask_constrain:
            center_mask = m_composite * parse_cloth_mask
            ground_mask = torch.ones_like(parse_cloth_mask, dtype=torch.float)
            ground_mask = ground_mask * warped_mask * parse_cloth_mask
            loss_mask_constrain = center_mask_criterionL1(center_mask, ground_mask)
            # print("loss_mask_constrain = ", loss_mask_constrain)
            loss_mask_constrain = loss_mask_constrain * opt.mask_constrain_weight
            # print("loss_mask_constrain = ", loss_mask_constrain)

        # print("loss cloth = ", loss_warped_cloth)
        # print("loss point = ", loss_point)
        # print("loss render = ", loss_render)
        # print("loss_c_result = ", loss_c_result)
        loss = (loss_warped_cloth + loss_weight * loss_point + loss_c_result + loss_mask +
                loss_vgg + loss_render + loss_mask_constrain + warped_mask_loss + loss_gram)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f' % (step + 1, t, loss.item()), flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
def __init__(self):
    super(Content_loss, self).__init__()
    self.l1loss = nn.L1Loss()
    self.vgg_loss = VGGLoss()
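# Hypothetical companion forward() for the Content_loss module above (an assumption,
# since only its __init__ appears here): the sum of a pixel-wise L1 term and a VGG
# perceptual term between the prediction and the target.
def forward(self, pred, target):
    return self.l1loss(pred, target) + self.vgg_loss(pred, target)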
def __init__(self, hyperparameters):
    super(MUNIT_Trainer, self).__init__()
    lr = hyperparameters["lr"]
    self.newsize = hyperparameters["crop_image_height"]
    self.semantic_w = hyperparameters["semantic_w"] > 0
    self.recon_mask = hyperparameters["recon_mask"] == 1
    self.dann_scheduler = None
    self.full_adaptation = hyperparameters["adaptation"]["full_adaptation"] == 1
    dim = hyperparameters["gen"]["dim"]
    n_downsample = hyperparameters["gen"]["n_downsample"]
    latent_dim = dim * (2 ** n_downsample)

    if "domain_adv_w" in hyperparameters.keys():
        self.domain_classif_ab = hyperparameters["domain_adv_w"] > 0
    else:
        self.domain_classif_ab = False
    self.use_classifier_sr = hyperparameters["adaptation"]["dfeat_lambda"] > 0
    self.train_seg = hyperparameters["adaptation"]["sem_seg_lambda"] > 0
    self.use_output_classifier_sr = hyperparameters["adaptation"]["output_classifier_lambda"] > 0

    self.gen = SpadeGen(hyperparameters["input_dim_a"], hyperparameters["gen"])

    # Note: the "+1" is for the masks
    if hyperparameters["dis"]["type"] == "patchgan":
        print("Using patchgan discriminator...")
        self.dis_a = MultiscaleDiscriminator(hyperparameters["input_dim_a"], hyperparameters["dis"])  # discriminator for domain a
        self.dis_b = MultiscaleDiscriminator(hyperparameters["input_dim_b"], hyperparameters["dis"])  # discriminator for domain b
        self.dis_a_masked = MultiscaleDiscriminator(hyperparameters["input_dim_a"], hyperparameters["dis"])  # masked discriminator for domain a
        self.dis_b_masked = MultiscaleDiscriminator(hyperparameters["input_dim_b"], hyperparameters["dis"])  # masked discriminator for domain b
    else:
        self.dis_a = MsImageDis(hyperparameters["input_dim_a"], hyperparameters["dis"])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters["input_dim_b"], hyperparameters["dis"])  # discriminator for domain b
        self.dis_a_masked = MsImageDis(hyperparameters["input_dim_a"], hyperparameters["dis"])  # masked discriminator for domain a
        self.dis_b_masked = MsImageDis(hyperparameters["input_dim_b"], hyperparameters["dis"])  # masked discriminator for domain b
    self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    # fix the noise used in sampling
    display_size = int(hyperparameters["display_size"])

    # Setup the optimizers
    beta1 = hyperparameters["beta1"]
    beta2 = hyperparameters["beta2"]
    dis_params = (list(self.dis_a.parameters()) + list(self.dis_b.parameters()) +
                  list(self.dis_a_masked.parameters()) + list(self.dis_b_masked.parameters()))
    gen_params = list(self.gen.parameters())
    self.dis_opt = torch.optim.Adam(
        [p for p in dis_params if p.requires_grad],
        lr=lr, betas=(beta1, beta2),
        weight_decay=hyperparameters["weight_decay"],
    )
    self.gen_opt = torch.optim.Adam(
        [p for p in gen_params if p.requires_grad],
        lr=lr, betas=(beta1, beta2),
        weight_decay=hyperparameters["weight_decay"],
    )
    self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
    self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

    # Network weight initialization
    self.apply(weights_init(hyperparameters["init"]))
    self.dis_a.apply(weights_init("gaussian"))
    self.dis_b.apply(weights_init("gaussian"))
    self.dis_a_masked.apply(weights_init("gaussian"))
    self.dis_b_masked.apply(weights_init("gaussian"))

    # Load VGG model if needed
    if hyperparameters["vgg_w"] > 0:
        self.criterionVGG = VGGLoss()

    # Load semantic segmentation model if needed
    if "semantic_w" in hyperparameters.keys() and hyperparameters["semantic_w"] > 0:
        self.segmentation_model = load_segmentation_model(hyperparameters["semantic_ckpt_path"], 19)
        self.segmentation_model.eval()
        for param in self.segmentation_model.parameters():
            param.requires_grad = False

    # Load domain classifier if needed
    if "domain_adv_w" in hyperparameters.keys() and hyperparameters["domain_adv_w"] > 0:
        self.domain_classifier_ab = domainClassifier(input_dim=latent_dim, dim=256)
        dann_params = list(self.domain_classifier_ab.parameters())
        self.dann_opt = torch.optim.Adam(
            [p for p in dann_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.domain_classifier_ab.apply(weights_init("gaussian"))
        self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)

    # Load classifier on features for syn/real adaptation
    if self.use_classifier_sr:
        #! Hardcoded
        self.domain_classifier_sr_b = domainClassifier(input_dim=latent_dim, dim=256)
        self.domain_classifier_sr_a = domainClassifier(input_dim=latent_dim, dim=256)
        dann_params = (list(self.domain_classifier_sr_a.parameters()) +
                       list(self.domain_classifier_sr_b.parameters()))
        self.classif_opt_sr = torch.optim.Adam(
            [p for p in dann_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.domain_classifier_sr_a.apply(weights_init("gaussian"))
        self.domain_classifier_sr_b.apply(weights_init("gaussian"))
        self.classif_sr_scheduler = get_scheduler(self.classif_opt_sr, hyperparameters)

    if self.use_output_classifier_sr:
        if hyperparameters["dis"]["type"] == "patchgan":
            self.output_classifier_sr_a = MultiscaleDiscriminator(hyperparameters["input_dim_a"], hyperparameters["dis"])  # discriminator for domain a, sr
            self.output_classifier_sr_b = MultiscaleDiscriminator(hyperparameters["input_dim_a"], hyperparameters["dis"])  # discriminator for domain b, sr
        else:
            self.output_classifier_sr_a = MsImageDis(hyperparameters["input_dim_a"], hyperparameters["dis"])  # discriminator for domain a, sr
            self.output_classifier_sr_b = MsImageDis(hyperparameters["input_dim_a"], hyperparameters["dis"])  # discriminator for domain b, sr
        dann_params = (list(self.output_classifier_sr_a.parameters()) +
                       list(self.output_classifier_sr_b.parameters()))
        self.output_classif_opt_sr = torch.optim.Adam(
            [p for p in dann_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.output_classifier_sr_b.apply(weights_init("gaussian"))
        self.output_classifier_sr_a.apply(weights_init("gaussian"))
        self.output_scheduler_sr = get_scheduler(self.output_classif_opt_sr, hyperparameters)

    if self.train_seg:
        pretrained = load_segmentation_model(hyperparameters["semantic_ckpt_path"], 19)
        last_layer = nn.Conv2d(512, 10, kernel_size=1)
        model = torch.nn.Sequential(*list(pretrained.resnet34_8s.children())[7:-1], last_layer.cuda())
        self.segmentation_head = model
        for param in self.segmentation_head.parameters():
            param.requires_grad = True
        dann_params = list(self.segmentation_head.parameters())
        self.segmentation_opt = torch.optim.Adam(
            [p for p in dann_params if p.requires_grad],
            lr=lr, betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.scheduler_seg = get_scheduler(self.segmentation_opt, hyperparameters)
def train_tom(opt, train_loader, model, board):
    # load model
    model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    # train log
    if not opt.checkpoint == '':
        train_log = open(os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'a')
    else:
        os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
        train_log = open(os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'w')
        train_log.write('=' * 30 + ' Training Option ' + '=' * 30 + '\n')
        train_log.write(str(opt) + '\n\n')
        train_log.write('=' * 30 + ' Network Architecture ' + '=' * 30 + '\n')
        print(str(model) + '\n', file=train_log)
        train_log.write('=' * 30 + ' Training Log ' + '=' * 30 + '\n')

    # train loop
    checkpoint_step = 0
    if not opt.checkpoint == '':
        checkpoint_step += int(opt.checkpoint.split('/')[-1][5:11])
    for step in range(checkpoint_step, opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        dl_iter = iter(train_loader)
        inputs = next(dl_iter)

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(), loss_vgg.item(), loss_mask.item()),
                  flush=True)
            train_log.write('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                            % (step + 1, t, loss.item(), loss_l1.item(), loss_vgg.item(), loss_mask.item()) + '\n')

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
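# The LambdaLR factor used in the training loops above keeps the learning rate
# constant for opt.keep_step steps, then decays it roughly linearly towards zero
# over opt.decay_step further steps. A quick standalone check of the schedule
# (the keep_step/decay_step values below are illustrative, not from the repo):
def lr_factor(step, keep_step=100000, decay_step=100000):
    return 1.0 - max(0, step - keep_step) / float(decay_step + 1)

print(lr_factor(0))        # 1.0
print(lr_factor(100000))   # 1.0  (decay starts after keep_step)
print(lr_factor(150000))   # ~0.5
print(lr_factor(199999))   # ~0.0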
def train_tom_gmm(opt, train_loader, model, model_module, gmm_model, gmm_model_module, board):
    model.train()
    gmm_model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(list(model.parameters()) + list(gmm_model.parameters()),
                                 lr=opt.lr, betas=(0.5, 0.999))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']
        im_c = inputs['parse_cloth'].cuda()

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        with torch.no_grad():
            grid, theta = gmm_model(agnostic, c)
            c = F.grid_sample(c, grid, padding_mode='border')
            cm = F.grid_sample(cm, grid, padding_mode='zeros')

        # grid, theta = model(agnostic, c)
        # warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        # warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        # warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss_warp = criterionL1(c, im_c)
        loss = loss_l1 + loss_vgg + loss_mask + loss_warp
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0 and single_gpu_flag(opt):
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            board.add_scalar('Warp', loss_warp.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f, warp: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(), loss_vgg.item(),
                     loss_mask.item(), loss_warp.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0 and single_gpu_flag(opt):
            save_checkpoint(model_module,
                            os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))
            save_checkpoint(gmm_model_module,
                            os.path.join(opt.checkpoint_dir, opt.name, 'step_warp_%06d.pth' % (step + 1)))
def train_tom(opt, train_loader, model, board):
    device = torch.device("cuda:0")
    model.to(device)
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1),
    )

    pbar = tqdm(range(opt.keep_step + opt.decay_step))
    for step in pbar:
        inputs = train_loader.next_batch()

        im = inputs["image"].to(device)
        im_pose = inputs["pose_image"]
        im_h = inputs["head"]
        shape = inputs["shape"]

        agnostic = inputs["agnostic"].to(device)
        c = inputs["cloth"].to(device)
        cm = inputs["cloth_mask"].to(device)

        concat_tensor = torch.cat([agnostic, c], 1).to(device)
        outputs = model(concat_tensor)
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [
            [im_h, shape, im_pose],
            [c, cm * 2 - 1, m_composite * 2 - 1],
            [p_rendered, p_tryon, im],
        ]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(
            f"loss: {loss.item():.4f}, l1: {loss_l1.item():.4f}, "
            f"vgg: {loss_vgg.item():.4f}, mask: {loss_mask.item():.4f}"
        )

        if board and (step + 1) % opt.display_count == 0:
            board_add_images(board, "combine", visuals, step + 1)
            board.add_scalar("metric", loss.item(), step + 1)
            board.add_scalar("L1", loss_l1.item(), step + 1)
            board.add_scalar("VGG", loss_vgg.item(), step + 1)
            board.add_scalar("MaskL1", loss_mask.item(), step + 1)
            print(
                f"step: {step + 1:8d}, loss: {loss.item():.4f}, l1: {loss_l1.item():.4f}, "
                f"vgg: {loss_vgg.item():.4f}, mask: {loss_mask.item():.4f}",
                flush=True,
            )

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name, "step_%06d.pth" % (step + 1)),
            )
def train_tom(opt, train_loader, model, board):
    model.cuda()
    model.train()

    dic = {}
    dic["steps"] = []
    dic["loss"] = []
    dic["l1"] = []
    dic["vgg"] = []
    dic["mask"] = []

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        pcm = inputs['parse_cloth_mask'].cuda()

        # outputs = model(torch.cat([agnostic, c], 1))  # CP-VTON
        outputs = model(torch.cat([agnostic, c, cm], 1))  # CP-VTON+
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        # visuals = [[im_h, shape, im_pose],
        #            [c, cm*2-1, m_composite*2-1],
        #            [p_rendered, p_tryon, im]]  # CP-VTON
        visuals = [[im_h, shape, im_pose],
                   [c, pcm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]  # CP-VTON+

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        # loss_mask = criterionMask(m_composite, cm)  # CP-VTON
        loss_mask = criterionMask(m_composite, pcm)  # CP-VTON+
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step + 1, t, loss.item(), loss_l1.item(), loss_vgg.item(), loss_mask.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step + 1)))

        if (step + 1) % 5000 == 0:
            dic["steps"].append(step)
            dic["loss"].append(loss.item())
            dic["l1"].append(loss_l1.item())
            dic["vgg"].append(loss_vgg.item())
            dic["mask"].append(loss_mask.item())
            with open('lossvstep/tom.pickle', 'wb') as handle:
                pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)