Example #1
    def __init__(self, opt):
        '''
        Build the UNet generator (keyword defaults: start_filts=64,
        up_mode='transpose', merge_mode='concat'), the image and pair
        discriminators, and their losses and optimizers.
        :param opt: training options; opt.lr sets both Adam learning rates
        '''
        super(cyclegan, self).__init__()
        self.Generator = UNet(22, 4)
        self.Discriminator = NLayerDiscriminator()
        self.PairDis = PairDiscriminator()

        self.criterionGAN = GANLoss("lsgan")
        self.PairGAN = GANLoss("lsgan")
        self.loss_1 = nn.L1Loss()
        self.loss_2 = nn.MSELoss()
        self.optimizer_D = torch.optim.Adam(self.Discriminator.parameters(),
                                            lr=opt.lr,
                                            betas=(0.5, 0.999))

        self._optimizer_G = torch.optim.Adam(self.Generator.parameters(),
                                             lr=opt.lr,
                                             betas=(0.5, 0.999))
        self.vgg_loss = VGGLoss()
        self.content_loss = Content_loss()
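
Every example on this page instantiates VGGLoss() without showing its definition. A minimal sketch of the usual CP-VTON / pix2pixHD-style perceptual loss it refers to, i.e. a weighted L1 distance between VGG19 feature maps; the slice indices and the 1/32 ... 1 weights are the common convention, assumed here rather than taken from these snippets:

import torch
import torch.nn as nn
import torchvision

class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        vgg = torchvision.models.vgg19(pretrained=True).features.eval()
        for p in vgg.parameters():
            p.requires_grad = False
        # five stages ending at relu1_1, relu2_1, relu3_1, relu4_1, relu5_1
        self.slices = nn.ModuleList(
            [vgg[:2], vgg[2:7], vgg[7:12], vgg[12:21], vgg[21:30]])
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.l1 = nn.L1Loss()

    def forward(self, x, y):
        loss = 0.0
        for stage, w in zip(self.slices, self.weights):
            x, y = stage(x), stage(y)
            loss = loss + w * self.l1(x, y.detach())
        return loss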
Example #2
def train_tom(opt, train_loader, model, board):
    model.cuda()
    model.train()
    
    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()
    
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda step: 1.0 -
            max(0, step - opt.keep_step) / float(opt.decay_step + 1))
    
    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()
            
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        
        outputs = model(torch.cat([agnostic, c],1))
        p_rendered, m_composite = torch.split(outputs, 3,1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [ [im_h, shape, im_pose], 
                   [c, cm*2-1, m_composite*2-1], 
                   [p_rendered, p_tryon, im]]
            
        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the linear LR decay once per step
            
        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            board.add_scalar('metric', loss.item(), step+1)
            board.add_scalar('L1', loss_l1.item(), step+1)
            board.add_scalar('VGG', loss_vgg.item(), step+1)
            board.add_scalar('MaskL1', loss_mask.item(), step+1)
            #board.add_graph(model, torch.cat([agnostic, c],1))
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f' 
                    % (step+1, t, loss.item(), loss_l1.item(), 
                    loss_vgg.item(), loss_mask.item()), flush=True)

        if (step+1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))
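
train_loader.next_batch() here is not the standard DataLoader API; CP-VTON-style training scripts get it from a small wrapper that restarts its iterator whenever the dataset is exhausted. A sketch under that assumption (the opt.batch_size / opt.shuffle / opt.workers option names are illustrative):

import torch

class CPDataLoader(object):
    def __init__(self, opt, dataset):
        super(CPDataLoader, self).__init__()
        self.data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=opt.batch_size, shuffle=opt.shuffle,
            num_workers=opt.workers)
        self.data_iter = iter(self.data_loader)

    def next_batch(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            # epoch boundary: restart the iterator and keep going
            self.data_iter = iter(self.data_loader)
            batch = next(self.data_iter)
        return batch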
Example #3
    def __init__(self,
                 gen,
                 dis,
                 dataloader_train,
                 dataloader_val,
                 gpu_id,
                 log_freq,
                 save_dir,
                 n_step,
                 optimizer='adam'):
        if torch.cuda.is_available():
            self.device = torch.device('cuda:' + str(gpu_id))
        else:
            self.device = torch.device('cpu')
        self.gen = gen.to(self.device)
        self.dis = dis.to(self.device)

        self.dataloader_train = dataloader_train
        self.dataloader_val = dataloader_val

        if optimizer == 'adam':
            self.optim_g = torch.optim.Adam(gen.parameters(),
                                            lr=1e-4,
                                            betas=(0.5, 0.999))
            self.optim_d = torch.optim.Adam(dis.parameters(),
                                            lr=1e-4,
                                            betas=(0.5, 0.999))
        elif optimizer == 'ranger':
            self.optim_g = Ranger(gen.parameters())
            self.optim_d = Ranger(dis.parameters())
        else:
            raise ValueError('unknown optimizer: ' + optimizer)

        self.criterionL1 = nn.L1Loss()
        self.criterionVGG = VGGLoss()
        self.criterionAdv = torch.nn.BCELoss()
        self.log_freq = log_freq
        self.save_dir = save_dir
        self.n_step = n_step
        self.step = 0
        print('Generator Parameters:',
              sum([p.nelement() for p in self.gen.parameters()]))
        print('Discriminator Parameters:',
              sum([p.nelement() for p in self.dis.parameters()]))
Example #4
def train(opt, train_loader, G, D, board):
    human_parser = HumanParser(opt)
    human_parser.eval()
    G.train()
    D.train()

    # palette = get_palette()

    # Criterion
    criterionWarp = nn.L1Loss()
    criterionPerceptual = VGGLoss()
    criterionL1 = nn.L1Loss()
    BCE_stable = nn.BCEWithLogitsLoss()
    criterionCloth = nn.L1Loss()

    # Variables
    ya = torch.FloatTensor(opt.batch_size)
    yb = torch.FloatTensor(opt.batch_size)
    u = torch.FloatTensor(opt.batch_size, 1, 1, 1)  # (B,1,1,1) interpolation factors
    grad_outputs = torch.ones(opt.batch_size)

    # Everything cuda
    if opt.cuda:
        G.cuda()
        D.cuda()
        human_parser.cuda()
        criterionWarp = criterionWarp.cuda()
        criterionPerceptual = criterionPerceptual.cuda()
        criterionL1 = criterionL1.cuda()
        BCE_stable.cuda()
        criterionCloth = criterionCloth.cuda()

        ya = ya.cuda()
        yb = yb.cuda()
        u = u.cuda()
        grad_outputs = grad_outputs.cuda()

        # DataParallel
        G = nn.DataParallel(G)
        D = nn.DataParallel(D)
        human_parser = nn.DataParallel(human_parser)

    # Optimizers
    optimizerD = torch.optim.Adam(D.parameters(),
                                  lr=opt.lr,
                                  betas=(0.5, 0.999))
    optimizerG = torch.optim.Adam(G.parameters(),
                                  lr=opt.lr,
                                  betas=(0.5, 0.999))

    # Fitting model
    step_start_time = time.time()
    for step in range(opt.n_iter):
        ########################
        # (1) Update D network #
        ########################

        for p in D.parameters():
            p.requires_grad = True

        for t in range(opt.Diters):
            D.zero_grad()

            inputs = train_loader.next_batch()
            pa = inputs['image'].cuda()
            ap = inputs['agnostic'].cuda()
            cb = inputs['another_cloth'].cuda()
            del inputs

            current_batch_size = pa.size(0)
            ya_pred = D(pa)
            _, pb_fake = G(cb, ap)

            # Detach y_pred_fake from the neural network G and put it inside D
            yb_pred_fake = D(pb_fake.detach())
            ya.data.resize_(current_batch_size).fill_(1)
            yb.data.resize_(current_batch_size).fill_(0)

            errD = (BCE_stable(ya_pred - torch.mean(yb_pred_fake), ya) +
                    BCE_stable(yb_pred_fake - torch.mean(ya_pred), yb)) / 2.0
            errD.backward()

            # Gradient penalty
            with torch.no_grad():
                u.resize_(current_batch_size, 1, 1, 1).uniform_(0, 1)
                grad_outputs.data.resize_(current_batch_size)
            x_both = pa * u + pb_fake * (1. - u)

            # We only want the gradients with respect to x_both
            x_both = Variable(x_both, requires_grad=True)
            grad = torch.autograd.grad(outputs=D(x_both),
                                       inputs=x_both,
                                       grad_outputs=grad_outputs,
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
            # We need to norm 3 times (over n_colors x image_size x image_size) to get only a vector of size
            # "batch_size"
            grad_penalty = opt.penalty * (
                (grad.norm(2, 1).norm(2, 1).norm(2, 1) - 1)**2).mean()
            grad_penalty.backward()

            optimizerD.step()

        ########################
        # (2) Update G network #
        ########################

        for p in D.parameters():
            p.requires_grad = False

        for t in range(opt.Giters):
            inputs = train_loader.next_batch()
            pa = inputs['image'].cuda()
            ap = inputs['agnostic'].cuda()
            ca = inputs['cloth'].cuda()
            cb = inputs['another_cloth'].cuda()
            parse_cloth = inputs['parse_cloth'].cuda()
            del inputs

            current_batch_size = pa.size(0)

            # paired data
            G.zero_grad()

            warped_cloth_a, pa_fake = G(ca, ap)
            if step >= opt.human_parser_step:  # enable the human parser once generated images look realistic enough
                parse_pa_fake = human_parser(pa_fake)  # (N,H,W)
                parse_ca_fake = (parse_pa_fake == 5) + \
                                (parse_pa_fake == 6) + \
                                (parse_pa_fake == 7)  # [0,1] (N,H,W)
                parse_ca_fake = parse_ca_fake.unsqueeze(1).type_as(pa_fake)  # (N,1,H,W)
                ca_fake = pa_fake * parse_ca_fake + (1 - parse_ca_fake)  # [-1,1]
                with torch.no_grad():
                    parse_pa_fake_vis = visualize_seg(parse_pa_fake)
                l_cloth_p = criterionCloth(ca_fake, warped_cloth_a)
            else:
                with torch.no_grad():
                    ca_fake = torch.zeros_like(pa_fake)
                    parse_pa_fake_vis = torch.zeros_like(pa_fake)
                    l_cloth_p = torch.zeros(1).cuda()

            l_warp = 20 * criterionWarp(warped_cloth_a, parse_cloth)
            l_perceptual = criterionPerceptual(pa_fake, pa)
            l_L1 = criterionL1(pa_fake, pa)
            loss_p = l_warp + l_perceptual + l_L1 + l_cloth_p

            loss_p.backward()
            optimizerG.step()

            # unpaired data
            G.zero_grad()

            warped_cloth_b, pb_fake = G(cb, ap)
            if step >= opt.human_parser_step:  # enable the human parser once generated images look realistic enough
                parse_pb_fake = human_parser(pb_fake)
                parse_cb_fake = (parse_pb_fake == 5) + \
                                (parse_pb_fake == 6) + \
                                (parse_pb_fake == 7)  # [0,1] (N,H,W)
                parse_cb_fake = parse_cb_fake.unsqueeze(1).type_as(pb_fake)  # (N,1,H,W)
                cb_fake = pb_fake * parse_cb_fake + (1 - parse_cb_fake)  # [-1,1]
                with torch.no_grad():
                    parse_pb_fake_vis = visualize_seg(parse_pb_fake)
                l_cloth_up = criterionCloth(cb_fake, warped_cloth_b)
            else:
                with torch.no_grad():
                    cb_fake = torch.zeros_like(pb_fake)
                    parse_pb_fake_vis = torch.zeros_like(pb_fake)
                    l_cloth_up = torch.zeros(1).cuda()

            with torch.no_grad():
                ya.data.resize_(current_batch_size).fill_(1)
                yb.data.resize_(current_batch_size).fill_(0)
            ya_pred = D(pa)
            yb_pred_fake = D(pb_fake)

            # Non-saturating
            l_adv = 0.1 * (
                BCE_stable(ya_pred - torch.mean(yb_pred_fake), yb) +
                BCE_stable(yb_pred_fake - torch.mean(ya_pred), ya)) / 2
            loss_up = l_adv + l_cloth_up
            loss_up.backward()
            optimizerG.step()

            # visuals = [
            #     [cb, warped_cloth_b, pb_fake],
            #     [ca, warped_cloth_a, pa_fake],
            #     [ap, parse_cloth, pa]
            # ]
            visuals = [
                [cb, warped_cloth_b, pb_fake, cb_fake, parse_pb_fake_vis],
                [ca, warped_cloth_a, pa_fake, ca_fake, parse_pa_fake_vis],
                [ap, parse_cloth, pa],
            ]

            if (step + 1) % opt.display_count == 0:
                board_add_images(board, 'combine', visuals, step + 1)
                board.add_scalar('loss_p', loss_p.item(), step + 1)
                board.add_scalar('l_warp', l_warp.item(), step + 1)
                board.add_scalar('l_perceptual', l_perceptual.item(), step + 1)
                board.add_scalar('l_L1', l_L1.item(), step + 1)
                board.add_scalar('l_cloth_p', l_cloth_p.item(), step + 1)
                board.add_scalar('loss_up', loss_up.item(), step + 1)
                board.add_scalar('l_adv', l_adv.item(), step + 1)
                board.add_scalar('l_cloth_up', l_cloth_up.item(), step + 1)
                board.add_scalar('errD', errD.item(), step + 1)

                t = time.time() - step_start_time
                print(
                    'step: %8d, time: %.3f, loss_p: %.4f, loss_up: %.4f, l_adv: %.4f, errD: %.4f'
                    % (step + 1, t, loss_p.item(), loss_up.item(),
                       l_adv.item(), errD.item()),
                    flush=True)
                step_start_time = time.time()

            if (step + 1) % opt.save_count == 0:
                save_checkpoint(
                    G,
                    os.path.join(opt.checkpoint_dir, opt.name,
                                 'step_%06d.pth' % (step + 1)))
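
The chained grad.norm(2, 1) calls in the gradient penalty above reduce the (N, C, H, W) gradient to one L2 norm per sample. An equivalent, more readable formulation as a standalone helper (the name and signature are illustrative, not from the original code):

import torch

def gradient_penalty(D, real, fake, weight):
    # interpolate each real/fake pair at a random point
    u = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_both = (u * real + (1.0 - u) * fake).requires_grad_(True)
    grad = torch.autograd.grad(outputs=D(x_both).sum(),
                               inputs=x_both,
                               create_graph=True)[0]
    # one L2 norm per sample over all C*H*W entries, which is exactly
    # what norming over dim 1 three times computes
    grad_norm = grad.flatten(1).norm(2, dim=1)
    return weight * ((grad_norm - 1.0) ** 2).mean()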
Example #5
def train_tom(opt, train_loader, model, d_g, d_l, board):
    model.cuda()
    model.train()
    d_g.cuda()
    d_g.train()
    d_l.cuda()
    d_l.train()

    # reversed labels: real -> 0, fake -> 1 (the generator targets 0)
    dis_label_G = Variable(torch.FloatTensor(opt.batch_size,
                                             1)).fill_(0.).cuda()
    dis_label_real = Variable(torch.FloatTensor(opt.batch_size,
                                                1)).fill_(0.).cuda()
    dis_label_fake = Variable(torch.FloatTensor(opt.batch_size,
                                                1)).fill_(1.).cuda()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()
    criterionGAN = nn.BCELoss()  # nn.MSELoss here would give the LSGAN variant

    # optimizer
    optimizerG = torch.optim.Adam(model.parameters(),
                                  lr=opt.lr,
                                  betas=(0.5, 0.999))
    optimizerDG = torch.optim.Adam(d_g.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    optimizerDL = torch.optim.Adam(d_l.parameters(),
                                   lr=opt.lr,
                                   betas=(0.5, 0.999))
    schedulerG = torch.optim.lr_scheduler.LambdaLR(
        optimizerG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))
    schedulerDG = torch.optim.lr_scheduler.LambdaLR(
        optimizerDG,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))
    schedulerDL = torch.optim.lr_scheduler.LambdaLR(
        optimizerDL,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()

        # label smoothing: jitter the real/fake targets by up to opt.noise
        dis_label_real = dis_label_real.data.fill_(0.0 +
                                                   random.random() * opt.noise)
        dis_label_fake = dis_label_fake.data.fill_(1.0 -
                                                   random.random() * opt.noise)

        #prep
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()  #sz=b*3*256*192
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        batch_size = im.size(0)
        if batch_size != opt.batch_size: continue

        #D_real
        errDg_real = criterionGAN(d_g(torch.cat([agnostic, c, im], 1)),
                                  dis_label_real)

        #generate image
        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        real_crop, fake_crop = random_crop(im, p_tryon, opt.winsize)
        errDl_real = criterionGAN(d_l(real_crop), dis_label_real)

        #tom_train
        errGg_fake = criterionGAN(d_g(torch.cat([agnostic, c, p_tryon], 1)),
                                  dis_label_G)
        errGl_fake = criterionGAN(d_l(fake_crop), dis_label_G)

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss_GAN = (errGg_fake + errGl_fake * opt.alpha) / batch_size
        loss = loss_l1 + loss_vgg + loss_mask + loss_GAN

        #D_fake
        errDg_fake = criterionGAN(
            d_g(torch.cat([agnostic, c, p_tryon], 1).detach()), dis_label_fake)
        loss_Dg = (errDg_fake + errDg_real) / 2

        errDl_fake = criterionGAN(d_l(fake_crop.detach()), dis_label_fake)
        loss_Dl = (errDl_fake + errDl_real) / 2

        optimizerG.zero_grad()
        loss.backward()
        optimizerG.step()

        optimizerDL.zero_grad()
        loss_Dl.backward()
        optimizerDL.step()

        optimizerDG.zero_grad()
        loss_Dg.backward()
        optimizerDG.step()

        # advance the LR decay schedules once per step
        schedulerG.step()
        schedulerDG.step()
        schedulerDL.step()
        # tensorboardX
        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        if (step + 1) % opt.display_count == 0:
            t = time.time() - iter_start_time

            loss_dict = {
                "TOT": loss.item(),
                "L1": loss_l1.item(),
                "VG": loss_vgg.item(),
                "Mk": loss_mask.item(),
                "G": loss_GAN.item(),
                "DG": loss_Dg.item(),
                "DL": loss_Dl.item()
            }
            print('step: %d|time: %.3f' % (step + 1, t), end="")

            sm_image(combine_images(im, p_tryon, real_crop, fake_crop),
                     "combined%d.jpg" % step, opt.debug)
            board_add_images(board, 'combine', visuals, step + 1)
            for k, v in loss_dict.items():
                print('|%s: %.3f' % (k, v), end="")
                board.add_scalar(k, v, step + 1)
            print()

        if (step + 1) % opt.save_count == 0:
            save_checkpoints(
                model, d_g, d_l,
                os.path.join(opt.checkpoint_dir, opt.stage + '_' + opt.name,
                             "step%06d" % step, '%s.pth'))
Example #6
def train_refined_gmm(opt, train_loader, model, board):
    model.cuda()
    model.train()

    loss_weight = opt.loss_weight
    # if loss_weight > 0.01:
    #     print("Error")
    #     assert False

    # criterion
    warped_criterionL1 = nn.L1Loss()
    result_criterionL1 = nn.L1Loss()
    point_criterionL1 = nn.L1Loss()
    criterionMask = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionGram = GramLoss()
    rendered_criterionL1 = nn.L1Loss()

    center_mask_criterionL1 = nn.L1Loss()

    warped_mask_criterionL1 = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image'].cuda()
        im_h = inputs['head'].cuda()
        shape = inputs['shape'].cuda()
        densepose_shape = inputs['densepose_shape'].cuda()
        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        im_c = inputs['parse_cloth'].cuda()
        im_g = inputs['grid_image'].cuda()
        parse_cloth_mask = inputs['parse_cloth_mask'].cuda()
        target_shape = inputs['target_shape'].cuda()

        c_point_plane = inputs['cloth_points'].cuda()
        p_point_plane = inputs['person_points'].cuda()

        grid, theta, warped_cloth, outputs = model(agnostic, c)
        #warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
        compute_c_point_plane = compute_grid_point(p_point_plane, grid)

        warped_mask_loss = 0
        if opt.add_warped_mask_loss:
            warped_mask_loss += warped_mask_criterionL1(
                warped_mask, target_shape)

        c_rendered, m_composite = torch.split(outputs, 3, 1)
        c_rendered = torch.tanh(c_rendered)
        m_composite = torch.sigmoid(m_composite)
        c_result = warped_cloth * m_composite + c_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose], [c, warped_cloth, im_c],
                   [warped_grid, (warped_cloth + im) * 0.5, im],
                   [m_composite, (c_result + im) * 0.5, c_result]]

        loss_warped_cloth = warped_criterionL1(warped_cloth, im_c)
        loss_point = 0
        if opt.add_point_loss:
            loss_point = point_criterionL1(compute_c_point_plane,
                                           c_point_plane)
        loss_c_result = result_criterionL1(c_result, im_c)
        loss_mask = criterionMask(m_composite, warped_mask)
        loss_vgg = 0
        if opt.add_vgg_loss:
            loss_vgg = criterionVGG(c_result, im_c)
        loss_gram = 0
        if opt.add_gram_loss:
            loss_gram += criterionGram(c_result, im_c)

        loss_render = 0
        if opt.add_render_loss:
            loss_render += rendered_criterionL1(c_rendered, im_c)

        loss_mask_constrain = 0
        if opt.add_mask_constrain:
            center_mask = m_composite * parse_cloth_mask
            ground_mask = torch.ones_like(parse_cloth_mask, dtype=torch.float)
            ground_mask = ground_mask * warped_mask * parse_cloth_mask
            loss_mask_constrain = center_mask_criterionL1(
                center_mask, ground_mask)
            #print("long_mask_constrain = ", loss_mask_constrain)
            loss_mask_constrain = loss_mask_constrain * opt.mask_constrain_weight
            #print("long_mask_constrain = ", loss_mask_constrain)
        # print("loss cloth = ", loss_warped_cloth)
        # print("loss point = ", loss_point)
        # print("loss render = ", loss_render)
        # print("loss_c_result = ", loss_c_result)

        loss = (loss_warped_cloth + loss_weight * loss_point + loss_c_result +
                loss_mask + loss_vgg + loss_render + loss_mask_constrain +
                warped_mask_loss + loss_gram)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the linear LR decay once per step

        if (step + 1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f' %
                  (step + 1, t, loss.item()),
                  flush=True)

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
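
GramLoss is another helper this example assumes. A minimal style-loss sketch that compares Gram matrices of its two inputs directly (real implementations often compute the Gram matrices on VGG features instead):

import torch
import torch.nn as nn

class GramLoss(nn.Module):
    def __init__(self):
        super(GramLoss, self).__init__()
        self.l1 = nn.L1Loss()

    @staticmethod
    def gram(x):
        n, c, h, w = x.size()
        f = x.view(n, c, h * w)
        # (N, C, C) Gram matrix, normalized by the number of entries
        return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)

    def forward(self, x, y):
        return self.l1(self.gram(x), self.gram(y))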
Example #7
    def __init__(self):
        super(Content_loss, self).__init__()
        self.l1loss = nn.L1Loss()
        self.vgg_loss = VGGLoss()
Example #8
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters["lr"]
        self.newsize = hyperparameters["crop_image_height"]
        self.semantic_w = hyperparameters["semantic_w"] > 0
        self.recon_mask = hyperparameters["recon_mask"] == 1
        self.dann_scheduler = None
        self.full_adaptation = hyperparameters["adaptation"][
            "full_adaptation"] == 1
        dim = hyperparameters["gen"]["dim"]
        n_downsample = hyperparameters["gen"]["n_downsample"]
        latent_dim = dim * (2**n_downsample)

        if "domain_adv_w" in hyperparameters.keys():
            self.domain_classif_ab = hyperparameters["domain_adv_w"] > 0
        else:
            self.domain_classif_ab = False

        adaptation = hyperparameters["adaptation"]
        self.use_classifier_sr = adaptation["dfeat_lambda"] > 0
        self.train_seg = adaptation["sem_seg_lambda"] > 0
        self.use_output_classifier_sr = adaptation["output_classifier_lambda"] > 0

        self.gen = SpadeGen(hyperparameters["input_dim_a"],
                            hyperparameters["gen"])

        # Note: the "+1" is for the masks
        if hyperparameters["dis"]["type"] == "patchgan":
            print("Using patchgan discriminator...")
            self.dis_a = MultiscaleDiscriminator(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b = MultiscaleDiscriminator(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b
            self.instancenorm = nn.InstanceNorm2d(512, affine=False)

            self.dis_a_masked = MultiscaleDiscriminator(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b_masked = MultiscaleDiscriminator(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b

        else:
            self.dis_a = MsImageDis(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b = MsImageDis(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b
            self.instancenorm = nn.InstanceNorm2d(512, affine=False)
            self.dis_a_masked = MsImageDis(
                hyperparameters["input_dim_a"],
                hyperparameters["dis"])  # discriminator for domain a
            self.dis_b_masked = MsImageDis(
                hyperparameters["input_dim_b"],
                hyperparameters["dis"])  # discriminator for domain b

        # fix the noise used in sampling
        display_size = int(hyperparameters["display_size"])
        # Setup the optimizers
        beta1 = hyperparameters["beta1"]
        beta2 = hyperparameters["beta2"]
        dis_params = (list(self.dis_a.parameters()) +
                      list(self.dis_b.parameters()) +
                      list(self.dis_a_masked.parameters()) +
                      list(self.dis_b_masked.parameters()))

        gen_params = list(self.gen.parameters())

        self.dis_opt = torch.optim.Adam(
            [p for p in dis_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.gen_opt = torch.optim.Adam(
            [p for p in gen_params if p.requires_grad],
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=hyperparameters["weight_decay"],
        )
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters["init"]))
        self.dis_a.apply(weights_init("gaussian"))
        self.dis_b.apply(weights_init("gaussian"))
        self.dis_a_masked.apply(weights_init("gaussian"))
        self.dis_b_masked.apply(weights_init("gaussian"))

        # Load VGG model if needed
        if hyperparameters["vgg_w"] > 0:
            self.criterionVGG = VGGLoss()

        # Load semantic segmentation model if needed
        if "semantic_w" in hyperparameters and hyperparameters["semantic_w"] > 0:
            self.segmentation_model = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            self.segmentation_model.eval()
            for param in self.segmentation_model.parameters():
                param.requires_grad = False

        # Load domain classifier if needed
        if "domain_adv_w" in hyperparameters and hyperparameters["domain_adv_w"] > 0:
            self.domain_classifier_ab = domainClassifier(input_dim=latent_dim,
                                                         dim=256)
            dann_params = list(self.domain_classifier_ab.parameters())
            self.dann_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier_ab.apply(weights_init("gaussian"))
            self.dann_scheduler = get_scheduler(self.dann_opt, hyperparameters)

        # Load classifier on features for syn, real adaptation
        if self.use_classifier_sr:
            #! Hardcoded
            self.domain_classifier_sr_b = domainClassifier(
                input_dim=latent_dim, dim=256)
            self.domain_classifier_sr_a = domainClassifier(
                input_dim=latent_dim, dim=256)

            dann_params = list(
                self.domain_classifier_sr_a.parameters()) + list(
                    self.domain_classifier_sr_b.parameters())
            self.classif_opt_sr = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.domain_classifier_sr_a.apply(weights_init("gaussian"))
            self.domain_classifier_sr_b.apply(weights_init("gaussian"))
            self.classif_sr_scheduler = get_scheduler(self.classif_opt_sr,
                                                      hyperparameters)

        if self.use_output_classifier_sr:
            if hyperparameters["dis"]["type"] == "patchgan":
                self.output_classifier_sr_a = MultiscaleDiscriminator(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain a,sr
                self.output_classifier_sr_b = MultiscaleDiscriminator(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain b,sr

            else:
                self.output_classifier_sr_a = MsImageDis(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain a,sr
                self.output_classifier_sr_b = MsImageDis(
                    hyperparameters["input_dim_a"],
                    hyperparameters["dis"])  # discriminator for domain b,sr

            dann_params = list(
                self.output_classifier_sr_a.parameters()) + list(
                    self.output_classifier_sr_b.parameters())
            self.output_classif_opt_sr = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.output_classifier_sr_b.apply(weights_init("gaussian"))
            self.output_classifier_sr_a.apply(weights_init("gaussian"))
            self.output_scheduler_sr = get_scheduler(
                self.output_classif_opt_sr, hyperparameters)

        if self.train_seg:
            pretrained = load_segmentation_model(
                hyperparameters["semantic_ckpt_path"], 19)
            last_layer = nn.Conv2d(512, 10, kernel_size=1)
            model = torch.nn.Sequential(
                *list(pretrained.resnet34_8s.children())[7:-1],
                last_layer.cuda())
            self.segmentation_head = model

            for param in self.segmentation_head.parameters():
                param.requires_grad = True

            dann_params = list(self.segmentation_head.parameters())
            self.segmentation_opt = torch.optim.Adam(
                [p for p in dann_params if p.requires_grad],
                lr=lr,
                betas=(beta1, beta2),
                weight_decay=hyperparameters["weight_decay"],
            )
            self.scheduler_seg = get_scheduler(self.segmentation_opt,
                                               hyperparameters)
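
get_scheduler is imported from elsewhere; in MUNIT-style trainers it typically maps the 'lr_policy' hyperparameter to a torch scheduler (or None for a constant rate). A sketch under that assumption:

from torch.optim import lr_scheduler

def get_scheduler(optimizer, hyperparameters, iterations=-1):
    policy = hyperparameters.get('lr_policy', 'constant')
    if policy == 'constant':
        return None  # keep the learning rate fixed
    if policy == 'step':
        return lr_scheduler.StepLR(optimizer,
                                   step_size=hyperparameters['step_size'],
                                   gamma=hyperparameters['gamma'],
                                   last_epoch=iterations)
    raise NotImplementedError('lr policy %s not implemented' % policy)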
Example #9
def train_tom(opt, train_loader, model, board):
    # load model
    model.cuda()
    model.train()
    
    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()
    
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda step: 1.0 -
            max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    # train log
    if not opt.checkpoint == '':
        train_log = open(os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'a')
    else:
        os.makedirs(os.path.join(opt.checkpoint_dir, opt.name), exist_ok=True)
        train_log = open(os.path.join(opt.checkpoint_dir, opt.name, 'train_log.txt'), 'w')
        train_log.write('='*30 + ' Training Option ' + '='*30 + '\n')
        train_log.write(str(opt) + '\n\n')
        train_log.write('='*30 + ' Network Architecture ' + '='*30 + '\n')
        print(str(model) + '\n', file=train_log)
        train_log.write('='*30 + ' Training Log ' + '='*30 + '\n')

    # train loop
    checkpoint_step = 0
    if not opt.checkpoint == '':
        checkpoint_step += int(opt.checkpoint.split('/')[-1][5:11])
    for step in range(checkpoint_step, opt.keep_step + opt.decay_step):
        iter_start_time = time.time()

        dl_iter = iter(train_loader)
        inputs = next(dl_iter)
            
        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        
        outputs = model(torch.cat([agnostic, c],1))
        p_rendered, m_composite = torch.split(outputs, 3,1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite+ p_rendered * (1 - m_composite)

        visuals = [ [im_h, shape, im_pose], 
                   [c, cm*2-1, m_composite*2-1], 
                   [p_rendered, p_tryon, im]]
            
        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
            
        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            board.add_scalar('metric', loss.item(), step+1)
            board.add_scalar('L1', loss_l1.item(), step+1)
            board.add_scalar('VGG', loss_vgg.item(), step+1)
            board.add_scalar('MaskL1', loss_mask.item(), step+1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f' 
                    % (step+1, t, loss.item(), loss_l1.item(), 
                    loss_vgg.item(), loss_mask.item()), flush=True)
            train_log.write('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f' 
                    % (step+1, t, loss.item(), loss_l1.item(), loss_vgg.item(), loss_mask.item()) + '\n')

        if (step+1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))
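
save_checkpoint, used by most examples on this page, is usually a small wrapper around torch.save in CP-VTON-style repos; a sketch assuming that behavior:

import os
import torch

def save_checkpoint(model, save_path):
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    # save from CPU so the checkpoint loads on any device, then move back
    torch.save(model.cpu().state_dict(), save_path)
    model.cuda()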
Example #10
def train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                  gmm_model_module, board):
    model.train()
    gmm_model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(list(model.parameters()) +
                                 list(gmm_model.parameters()),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']
        im_c = inputs['parse_cloth'].cuda()

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()

        with torch.no_grad():
            grid, theta = gmm_model(agnostic, c)
            c = F.grid_sample(c, grid, padding_mode='border')
            cm = F.grid_sample(cm, grid, padding_mode='zeros')

        # grid, theta = model(agnostic, c)
        # warped_cloth = F.grid_sample(c, grid, padding_mode='border')
        # warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
        # warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')

        outputs = model(torch.cat([agnostic, c], 1))
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [[im_h, shape, im_pose],
                   [c, cm * 2 - 1, m_composite * 2 - 1],
                   [p_rendered, p_tryon, im]]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss_warp = criterionL1(c, im_c)

        loss = loss_l1 + loss_vgg + loss_mask + loss_warp
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % opt.display_count == 0 and single_gpu_flag(opt):
            board_add_images(board, 'combine', visuals, step + 1)
            board.add_scalar('metric', loss.item(), step + 1)
            board.add_scalar('L1', loss_l1.item(), step + 1)
            board.add_scalar('VGG', loss_vgg.item(), step + 1)
            board.add_scalar('MaskL1', loss_mask.item(), step + 1)
            board.add_scalar('Warp', loss_warp.item(), step + 1)

            t = time.time() - iter_start_time
            print(
                'step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f, warp: %.4f'
                % (step + 1, t, loss.item(), loss_l1.item(), loss_vgg.item(),
                   loss_mask.item(), loss_warp.item()),
                flush=True)

        if (step + 1) % opt.save_count == 0 and single_gpu_flag(opt):
            save_checkpoint(
                model_module,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_%06d.pth' % (step + 1)))
            save_checkpoint(
                gmm_model_module,
                os.path.join(opt.checkpoint_dir, opt.name,
                             'step_warp_%06d.pth' % (step + 1)))
Example #11
def train_tom(opt, train_loader, model, board):
    device = torch.device("cuda:0")
    model.to(device)
    #model.cuda()
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=opt.lr,
                                 betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - opt.keep_step) / float(
            opt.decay_step + 1),
    )

    pbar = tqdm(range(opt.keep_step + opt.decay_step))
    for step in pbar:
        inputs = train_loader.next_batch()

        im = inputs["image"].to(device)  #.cuda()
        im_pose = inputs["pose_image"]
        im_h = inputs["head"]
        shape = inputs["shape"]

        agnostic = inputs["agnostic"].to(device)  # .cuda()
        c = inputs["cloth"].to(device)  #.cuda()
        cm = inputs["cloth_mask"].to(device)  #.cuda()

        concat_tensor = torch.cat([agnostic, c], 1)
        concat_tensor = concat_tensor.to(device)

        outputs = model(concat_tensor)
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        visuals = [
            [im_h, shape, im_pose],
            [c, cm * 2 - 1, m_composite * 2 - 1],
            [p_rendered, p_tryon, im],
        ]

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        loss_mask = criterionMask(m_composite, cm)
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the linear LR decay once per step

        pbar.set_description(
            f"loss: {loss.item():.4f}, l1: {loss_l1.item():.4f}, vgg: {loss_vgg.item():.4f}, mask: {loss_mask.item():.4f}"
        )
        if board and (step + 1) % opt.display_count == 0:
            board_add_images(board, "combine", visuals, step + 1)
            board.add_scalar("metric", loss.item(), step + 1)
            board.add_scalar("L1", loss_l1.item(), step + 1)
            board.add_scalar("VGG", loss_vgg.item(), step + 1)
            board.add_scalar("MaskL1", loss_mask.item(), step + 1)
            print(
                f"step: {step + 1:8d}, loss: {loss.item():.4f}, l1: {loss_l1.item():.4f}, vgg: {loss_vgg.item():.4f}, mask: {loss_mask.item():.4f}",
                flush=True,
            )

        if (step + 1) % opt.save_count == 0:
            save_checkpoint(
                model,
                os.path.join(opt.checkpoint_dir, opt.name,
                             "step_%06d.pth" % (step + 1)),
            )
Example #12
def train_tom(opt, train_loader, model, board):
    model.cuda()
    model.train()

    # losses recorded every 5000 steps for the pickle dump at the end
    dic = {"steps": [], "loss": [], "l1": [], "vgg": [], "mask": []}


    # criterion
    criterionL1 = nn.L1Loss()
    criterionVGG = VGGLoss()
    criterionMask = nn.L1Loss()

    # optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0 -
                                                  max(0, step - opt.keep_step) / float(opt.decay_step + 1))

    for step in range(opt.keep_step + opt.decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        im = inputs['image'].cuda()
        im_pose = inputs['pose_image']
        im_h = inputs['head']
        shape = inputs['shape']

        agnostic = inputs['agnostic'].cuda()
        c = inputs['cloth'].cuda()
        cm = inputs['cloth_mask'].cuda()
        pcm = inputs['parse_cloth_mask'].cuda()

        # outputs = model(torch.cat([agnostic, c], 1))  # CP-VTON
        outputs = model(torch.cat([agnostic, c, cm], 1))  # CP-VTON+
        p_rendered, m_composite = torch.split(outputs, 3, 1)
        p_rendered = torch.tanh(p_rendered)
        m_composite = torch.sigmoid(m_composite)
        p_tryon = c * m_composite + p_rendered * (1 - m_composite)

        """visuals = [[im_h, shape, im_pose],
                   [c, cm*2-1, m_composite*2-1],
                   [p_rendered, p_tryon, im]]"""  # CP-VTON

        visuals = [[im_h, shape, im_pose],
                   [c, pcm*2-1, m_composite*2-1],
                   [p_rendered, p_tryon, im]]  # CP-VTON+

        loss_l1 = criterionL1(p_tryon, im)
        loss_vgg = criterionVGG(p_tryon, im)
        # loss_mask = criterionMask(m_composite, cm)  # CP-VTON
        loss_mask = criterionMask(m_composite, pcm)  # CP-VTON+
        loss = loss_l1 + loss_vgg + loss_mask
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the linear LR decay once per step

        if (step+1) % opt.display_count == 0:
            board_add_images(board, 'combine', visuals, step+1)
            board.add_scalar('metric', loss.item(), step+1)
            board.add_scalar('L1', loss_l1.item(), step+1)
            board.add_scalar('VGG', loss_vgg.item(), step+1)
            board.add_scalar('MaskL1', loss_mask.item(), step+1)
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
                  % (step+1, t, loss.item(), loss_l1.item(),
                     loss_vgg.item(), loss_mask.item()), flush=True)

        if (step+1) % opt.save_count == 0:
            save_checkpoint(model, os.path.join(
                opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))

        if (step+1) % 5000 == 0:
            dic["steps"].append(step+1)
            dic["loss"].append(loss.item())
            dic["l1"].append(loss_l1.item())
            dic["vgg"].append(loss_vgg.item())
            dic["mask"].append(loss_mask.item())

    
    with open('lossvstep/tom.pickle', 'wb') as handle:
        pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)