Exemplo n.º 1
0
    def one_step(self, inputs):
        """Run one forward pass and compute the L1 data loss.

        The model predicts a residual that is added on top of the smooth
        vertices to produce the final prediction.

        Args:
            inputs: A 6-element sequence; only the first three entries
                (ground-truth verts, smooth verts, thetas) are used.

        Returns:
            Tuple ``(pred_verts, data_loss)`` where ``data_loss`` is the sum
            of absolute errors over the last axis, averaged over all other
            axes.
        """
        gt, smooth, raw_thetas, _, _, _ = inputs

        # Mask before the device transfer, exactly as before.
        masked = ops.mask_thetas(raw_thetas, self.garment_class)

        # `device` is not defined in this method; presumably a module-level
        # global -- confirm against the enclosing file.
        gt = gt.to(device)
        smooth = smooth.to(device)
        masked = masked.to(device)

        residual = self.model(masked).view(gt.shape)
        pred_verts = residual + smooth

        # L1 loss: per-element |error| summed over the last axis, then mean.
        per_item_l1 = (pred_verts - gt).abs().sum(dim=-1)
        data_loss = per_item_l1.mean()
        return pred_verts, data_loss
Exemplo n.º 2
0
    def one_step(self, inputs):
        """Run one forward pass and compute the L1 data loss.

        The prediction is the output of ``self.ss2g_runner.forward`` (driven by
        betas and gammas) plus an offset predicted by ``self.model`` from the
        concatenation of thetas, betas and gammas.

        Args:
            inputs: A 5-element sequence ``(gt_verts, thetas, betas, gammas, _)``.

        Returns:
            Tuple ``(pred_verts, data_loss)`` where ``data_loss`` is the sum
            of absolute errors over the last axis, averaged over all other
            axes.
        """
        gt, raw_thetas, shape_params, style_params, _ = inputs

        # Mask before the device transfer, exactly as before.
        masked = ops.mask_thetas(raw_thetas, self.garment_class)

        # `device` is not defined in this method; presumably a module-level
        # global -- confirm against the enclosing file.
        gt = gt.to(device)
        masked = masked.to(device)
        shape_params = shape_params.to(device)
        style_params = style_params.to(device)

        target_shape = gt.shape
        base = self.ss2g_runner.forward(betas=shape_params,
                                        gammas=style_params).view(target_shape)
        model_input = torch.cat((masked, shape_params, style_params), dim=1)
        offset = self.model(model_input).view(target_shape)
        pred_verts = base + offset

        # L1 loss: per-element |error| summed over the last axis, then mean.
        per_item_l1 = (pred_verts - gt).abs().sum(dim=-1)
        data_loss = per_item_l1.mean()
        return pred_verts, data_loss
Exemplo n.º 3
0
    def one_step(self, inputs):
        """Run one forward pass and compute the L1 data loss.

        ``self.model`` maps the concatenation of masked thetas, betas and
        gammas directly to vertices.

        Args:
            inputs: A 5-element sequence ``(gt_verts, thetas, betas, gammas, _)``.

        Returns:
            Tuple ``(pred_verts, data_loss)`` where ``data_loss`` is the sum
            of absolute errors over the last axis, averaged over all other
            axes.
        """
        gt, raw_thetas, shape_params, style_params, _ = inputs

        # Mask before the device transfer, exactly as before.
        masked = ops.mask_thetas(raw_thetas, self.garment_class)

        # `device` is not defined in this method; presumably a module-level
        # global -- confirm against the enclosing file.
        gt = gt.to(device)
        masked = masked.to(device)
        shape_params = shape_params.to(device)
        style_params = style_params.to(device)

        model_input = torch.cat((masked, shape_params, style_params), dim=1)
        pred_verts = self.model(model_input).view(gt.shape)

        # L1 loss: per-element |error| summed over the last axis, then mean.
        per_item_l1 = (pred_verts - gt).abs().sum(dim=-1)
        data_loss = per_item_l1.mean()
        return pred_verts, data_loss
Exemplo n.º 4
0
def evaluate(garment_class='old-t-shirt', gender='female', batch_size=32,
             device='cuda:0'):
    """Evaluate TailorNet (or any model for that matter) on the test set.

    Runs the runner returned by ``get_best_runner(garment_class, gender)``
    over the ``'test'`` split of ``MultiStyleShape``. For each batch it
    accumulates ``ops.verts_dist(gt_verts, pred_verts) * 1000.`` into an
    ``AverageMeter``, weighted by batch size.

    Args:
        garment_class: Garment class to evaluate (previously hard-coded to
            'old-t-shirt').
        gender: Body gender to evaluate (previously hard-coded to 'female').
        batch_size: DataLoader batch size.
        device: Device (string or ``torch.device``) the batch tensors are
            moved to. Presumably this must match the device the runner's
            model lives on -- confirm against ``get_best_runner``.

    Returns:
        The batch-size-weighted average of the scaled distance. It is also
        printed, as before.
    """
    from dataset.static_pose_shape_final import MultiStyleShape
    from torch.utils.data import DataLoader
    from utils.eval import AverageMeter
    from models import ops

    dataset = MultiStyleShape(garment_class=garment_class,
                              gender=gender,
                              split='test')
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=0,
                            shuffle=False,
                            drop_last=False)
    print(len(dataset))

    val_dist = AverageMeter()
    runner = get_best_runner(garment_class, gender)
    # NOTE: to evaluate the baseline instead, use
    # ``trainer.base_trainer.get_best_runner`` with the baseline checkpoint
    # directory for this garment_class/gender.

    device = torch.device(device)
    with torch.no_grad():
        for i, inputs in enumerate(dataloader):
            gt_verts, thetas, betas, gammas, _ = inputs

            thetas = ops.mask_thetas(thetas, garment_class)
            gt_verts = gt_verts.to(device)
            thetas = thetas.to(device)
            betas = betas.to(device)
            gammas = gammas.to(device)
            pred_verts = runner.forward(thetas=thetas,
                                        betas=betas,
                                        gammas=gammas).view(gt_verts.shape)

            # The *1000 scale presumably converts metres to millimetres --
            # confirm the units of ops.verts_dist.
            dist = ops.verts_dist(gt_verts, pred_verts) * 1000.
            val_dist.update(dist.item(), gt_verts.shape[0])
            print(i, len(dataloader))
    print(val_dist.avg)
    return val_dist.avg
Exemplo n.º 5
0
 def forward(self, thetas, betas=None, gammas=None):
     """Predict vertices from pose parameters alone.

     ``betas`` and ``gammas`` are accepted so the signature matches other
     runners, but this method does not use them.

     Returns:
         The raw output of ``self.model`` for the masked thetas.
     """
     masked_thetas = ops.mask_thetas(thetas=thetas,
                                     garment_class=self.garment_class)
     return self.model(masked_thetas)
Exemplo n.º 6
0
 def forward(self, thetas, betas, gammas):
     """Predict vertices from masked thetas concatenated with betas and gammas.

     The three tensors are joined along dim 1 before being fed to
     ``self.model``.

     Returns:
         The raw output of ``self.model``.
     """
     masked_thetas = ops.mask_thetas(thetas=thetas,
                                     garment_class=self.garment_class)
     features = torch.cat((masked_thetas, betas, gammas), dim=1)
     return self.model(features)