# Example 1
class Runner(object):
    """Evaluation runner for the size-conditioned garment encoder/decoder.

    Loads a trained ``EncDec_<res>`` network from ``ckpt``, builds the test
    dataset/loader and the SMPL body model, and exposes :meth:`eval_test`
    to score the model on a fixed number of test batches.
    """

    def __init__(self, ckpt, params):
        """Build datasets, the SMPL model and the network, then load ``ckpt``.

        Args:
            ckpt: path to a saved ``state_dict`` for the EncDec model.
            params: dict with keys 'gender', 'garment_class', 'batch_size',
                'garment_layer', 'res' and optionally 'iter_nums'.
        """
        # NOTE(review): `device` is read from module scope here — confirm it
        # is defined before this class is instantiated.
        self.device = device
        self.params = params
        self.gender = params['gender']
        self.garment_class = params['garment_class']
        self.bs = params['batch_size']
        self.garment_layer = params['garment_layer']
        self.res_name = params['res']
        # High-resolution meshes unless the low-res variant was requested.
        self.hres = self.res_name != 'lres'

        # log and backup
        LOG_DIR = '/scratch/BS/pool1/garvita/sizer'
        self.model_name = "EncDec_{}".format(self.res_name)

        log_name = os.path.join(
            self.garment_class, '{}_{}'.format(self.garment_layer,
                                               self.res_name))
        self.log_dir = os.path.join(LOG_DIR, log_name)
        if not os.path.exists(self.log_dir):
            print('making %s' % self.log_dir)
            os.makedirs(self.log_dir)

        with open(os.path.join(self.log_dir, "params.json"), 'w') as f:
            json.dump(params, f)

        self.iter_nums = params.get('iter_nums', 0)

        # load smpl data
        self.layer_size, self.smpl_size = get_res_vert(params['garment_class'],
                                                       self.hres,
                                                       params['garment_layer'])

        # Flattened xyz of every garment vertex is both input and output.
        input_dim = self.layer_size * 3
        output_dim = input_dim

        # get active vert id
        self.vert_indices = get_vid(self.garment_layer, self.garment_class,
                                    self.hres)
        # FIX: np.long was removed in NumPy 1.24; np.int64 is the portable
        # spelling for index arrays.
        self.vert_indices = torch.tensor(
            self.vert_indices.astype(np.int64)).long().cuda()

        # dataset and dataloader
        # FIX: res/gender were hard-coded to 'hres'/'male'; honour the
        # configured params instead.
        self.test_dataset = SizerData(garment_class=self.garment_class,
                                      garment_layer=self.garment_layer,
                                      mode='test',
                                      batch_size=self.bs,
                                      res=self.res_name,
                                      gender=self.gender)
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.bs,
            num_workers=12,
            shuffle=True,
            # FIX: referenced self.train_dataset, which this class never
            # creates; decide drop_last from the test set it iterates.
            drop_last=len(self.test_dataset) > self.bs)

        # create smpl
        self.smpl = TorchSMPL4Garment(gender=self.gender).to(device)
        self.smpl_faces_np = self.smpl.faces
        # FIX: face indices are integers; casting through 'float32' risks
        # precision loss on large meshes.
        self.smpl_faces = torch.tensor(self.smpl_faces_np.astype(np.int64),
                                       dtype=torch.long).cuda()

        # interpenetration loss term
        self.body_f_np = self.smpl.faces
        self.garment_f_np = Mesh(filename=os.path.join(
            DATA_DIR, 'real_{}_{}_{}.obj'.format(
                self.garment_class, self.res_name, self.garment_layer))).f
        self.garment_f_torch = torch.tensor(
            self.garment_f_np.astype(np.int64)).long().to(device)

        # model
        latent_dim = 50
        self.model = getattr(network_layers,
                             self.model_name)(input_size=input_dim,
                                              latent_size=latent_dim,
                                              output_size=output_dim)
        self.model.to(device)

        print("loading {}".format(ckpt))
        state_dict = torch.load(ckpt)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def forward(self, ):
        """Unused placeholder; evaluation happens in :meth:`eval_test`."""
        pass

    def eval_test(self):
        """Evaluate the model on 15 test batches.

        Each batch carries three scans of the same subject in sizes 0/1/2;
        they are paired in all nine (source size, target size) combinations.
        The network predicts per-vertex displacements that are skinned
        through SMPL and compared with the ground-truth garment vertices.

        Returns:
            Tuple ``(sum_val_loss, sum_val_loss, gt_verts, pred_verts,
            body_verts, garment_faces, body_faces)``; the vertex arrays come
            from the last evaluated batch.
        """
        sum_val_loss = 0.0
        num_batches = 15
        # Pure evaluation — no gradients needed.
        with torch.no_grad():
            for _ in range(num_batches):
                # FIX: iterator.next() was removed in Python 3, and the bare
                # `except:` hid every error; restart the loader only when the
                # iterator is missing or exhausted.
                try:
                    batch = next(self.test_data_iterator)
                except (AttributeError, StopIteration):
                    # FIX: was self.val_loader, which this class never
                    # creates; iterate the test loader instead.
                    self.test_data_iterator = iter(self.test_loader)
                    batch = next(self.test_data_iterator)
                gar_vert0 = batch.get('gar_vert0').to(device)
                gar_vert1 = batch.get('gar_vert1').to(device)
                gar_vert2 = batch.get('gar_vert2').to(device)

                betas0 = batch.get('betas0').to(device)

                pose0 = batch.get('pose0').to(device)
                pose1 = batch.get('pose1').to(device)
                pose2 = batch.get('pose2').to(device)

                trans0 = batch.get('trans0').to(device)
                trans1 = batch.get('trans1').to(device)
                trans2 = batch.get('trans2').to(device)

                size0 = batch.get('size0').to(device)
                size1 = batch.get('size1').to(device)
                size2 = batch.get('size2').to(device)

                # FIX: removed self.optimizer.zero_grad() — this evaluation
                # runner has no optimizer and the call raised AttributeError.

                # Source garment repeated for each of the three target sizes.
                inp_gar = torch.cat([
                    gar_vert0, gar_vert0, gar_vert0, gar_vert1, gar_vert1,
                    gar_vert1, gar_vert2, gar_vert2, gar_vert2
                ], dim=0)
                size_inp = torch.cat([
                    size0, size0, size0, size1, size1, size1, size2, size2,
                    size2
                ], dim=0)
                # Desired (target) size cycles 0,1,2 for every source size.
                size_des = torch.cat([
                    size0, size1, size2, size0, size1, size2, size0, size1,
                    size2
                ], dim=0)
                pose_all = torch.cat([
                    pose0, pose1, pose2, pose0, pose1, pose2, pose0, pose1,
                    pose2
                ], dim=0)
                trans_all = torch.cat([
                    trans0, trans1, trans2, trans0, trans1, trans2, trans0,
                    trans1, trans2
                ], dim=0)
                # A single shape code per subject, shared by all nine pairs.
                betas_feat = torch.cat([
                    betas0, betas0, betas0, betas0, betas0, betas0, betas0,
                    betas0, betas0
                ], dim=0)
                # Ground truth follows the target-size ordering of size_des.
                gt_verts = torch.cat([
                    gar_vert0, gar_vert1, gar_vert2, gar_vert0, gar_vert1,
                    gar_vert2, gar_vert0, gar_vert1, gar_vert2
                ], dim=0)

                all_dist = self.model(inp_gar, size_inp, size_des, betas_feat)
                # NOTE(review): garment_class is hard-coded to 't-shirt'
                # even though self.garment_class exists — confirm whether
                # this is intentional.
                body_verts, pred_verts = self.smpl.forward(
                    beta=betas_feat,
                    theta=pose_all,
                    trans=trans_all,
                    garment_class='t-shirt',
                    garment_d=all_dist)

                sum_val_loss += data_loss(self.garment_layer, pred_verts,
                                          gt_verts).item()
        return sum_val_loss, sum_val_loss, gt_verts.detach().cpu().numpy(
        ), pred_verts.detach().cpu().numpy(), body_verts.detach().cpu().numpy(
        ), self.garment_f_np, self.body_f_np

    def cuda(self):
        """Move the network to the default CUDA device."""
        self.model.cuda()

    def to(self, device):
        """Move the network to ``device``."""
        self.model.to(device)
# Example 2
class Runner(object):
    """Evaluation runner for the FC correspondence (parser) network.

    Loads a trained ``FC_correspondence_<res>`` model from ``ckpt`` and
    predicts garment vertices as softmax-weighted combinations of each
    vertex's candidate body neighbours.
    """

    def __init__(self, ckpt, params):
        """Build the test loader, SMPL model and network, then load ``ckpt``.

        Args:
            ckpt: path to a saved ``state_dict``.
            params: dict with keys 'gender', 'garment_class', 'batch_size',
                'garment_layer', 'res', 'num_neigh', 'feat'.
        """
        # NOTE(review): `device` is read from module scope — confirm it is
        # defined before instantiation.
        self.device = device
        self.params = params
        self.gender = params['gender']
        self.garment_class = params['garment_class']
        self.bs = params['batch_size']
        self.garment_layer = params['garment_layer']
        self.res_name = params['res']
        self.num_neigh = params['num_neigh']
        self.feat = params['feat']
        # High resolution unless the low-res variant was requested.
        self.hres = self.res_name != 'lres'
        self.model_name = 'FC_correspondence_{}'.format(self.res_name)
        self.layer_size, self.smpl_size = get_res_vert(params['garment_class'],
                                                       self.hres,
                                                       params['garment_layer'])

        # Precomputed per-vertex candidate body neighbours.
        layer_neigh = np.array(
            np.load(
                os.path.join(
                    DATA_DIR,
                    "real_{}_neighborheuristics_{}_{}_{}_gar_order2.npy".
                    format(self.garment_class, self.res_name,
                           self.garment_layer, self.num_neigh))))

        all_neighbors = np.array([[vid] for k in layer_neigh for vid in k])
        self.neigh_id2 = all_neighbors
        if self.garment_layer == 'Body':
            # FIX: self.body_vert was never defined in this class; build the
            # skin-vertex list (body minus clothed regions) the same way the
            # trainer does.
            vert_id_upper = get_vid('UpperClothes', self.garment_class,
                                    self.hres)
            vert_id_lower = get_vid('Pants', self.garment_class, self.hres)
            self.body_vert = [
                i for i in range(self.smpl_size)
                if i not in vert_id_upper and i not in vert_id_lower
            ]
            self.idx2 = torch.from_numpy(self.neigh_id2).view(
                len(self.body_vert), self.num_neigh).cuda()
        else:
            self.idx2 = torch.from_numpy(self.neigh_id2).view(
                self.layer_size, self.num_neigh).cuda()

        # dataset and dataloader
        self.test_dataset = ParserData(garment_class=self.garment_class,
                                       garment_layer=self.garment_layer,
                                       mode='test',
                                       batch_size=self.bs,
                                       res=self.res_name,
                                       gender=self.gender,
                                       feat=self.feat)
        self.test_loader = DataLoader(self.test_dataset,
                                      batch_size=self.bs,
                                      num_workers=12,
                                      shuffle=True,
                                      drop_last=False)

        # create smpl
        self.smpl = TorchSMPL4Garment(gender=self.gender).to(device)
        self.smpl_faces_np = self.smpl.faces
        # FIX: self.body_f_np was used below but never assigned.
        self.body_f_np = self.smpl.faces
        # FIX: face indices are integers; avoid the lossy 'float32' cast.
        self.smpl_faces = torch.tensor(self.smpl_faces_np.astype(np.int64),
                                       dtype=torch.long).cuda()

        if self.garment_layer == 'Body':
            self.garment_f_np = self.body_f_np
            self.garment_f_torch = self.smpl_faces
        else:
            self.garment_f_np = Mesh(filename=os.path.join(
                DATA_DIR, 'real_{}_{}_{}.obj'.format(
                    self.garment_class, self.res_name, self.garment_layer))).f
            # FIX: np.long was removed in NumPy 1.24; use np.int64.
            self.garment_f_torch = torch.tensor(
                self.garment_f_np.astype(np.int64)).long().to(device)

        # Softmax over the neighbour axis turns logits into blend weights.
        self.out_layer = torch.nn.Softmax(dim=2)
        input_dim = self.smpl_size * 3
        if self.feat == 'vn':
            # With vertex normals the input doubles to xyz + normal.
            input_dim = self.smpl_size * 6
        output_dim = self.layer_size * self.num_neigh

        self.model = getattr(network_layers,
                             self.model_name)(input_size=input_dim,
                                              output_size=output_dim)
        self.model.to(self.device)
        print("loading {}".format(ckpt))
        state_dict = torch.load(ckpt)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def forward(self, inp, betas, pose, trans=None, gt=None):
        """Predict garment vertices for one batch of body inputs.

        Args:
            inp: network input of shape (bs, smpl_size, 3 or 6).
            betas, pose, trans: SMPL parameters; ``trans`` defaults to zeros.
            gt: optional ground-truth vertices for a distance metric.

        Returns:
            (pred_verts, smpl_verts, garment_faces, smpl_faces, dist) as
            numpy arrays; ``dist`` is None when ``gt`` is None.
        """
        bs = inp.shape[0]
        # FIX: removed a stray ipdb.set_trace() debugging breakpoint.
        weights_from_net = self.model(inp)
        weights_from_net = weights_from_net.view(bs, self.layer_size,
                                                 self.num_neigh)
        weights_from_net = self.out_layer(weights_from_net)

        # Gather each garment vertex's neighbour positions and blend them
        # with the predicted softmax weights.
        input_copy = inp[:, self.idx2, :3]
        pred_x = weights_from_net * input_copy[:, :, :, 0]
        pred_y = weights_from_net * input_copy[:, :, :, 1]
        pred_z = weights_from_net * input_copy[:, :, :, 2]

        pred_verts = torch.sum(torch.stack((pred_x, pred_y, pred_z), dim=3),
                               dim=2)

        if trans is None:
            # FIX: use the actual batch size and the input's device; the old
            # code used self.bs and left the tensor on CPU.
            trans = torch.zeros((bs, 3), device=inp.device)

        smpl_verts = self.smpl.forward(beta=betas, theta=pose, trans=trans)

        # FIX: the old code unconditionally called dist.detach() and crashed
        # with AttributeError when gt was None.
        dist = None
        if gt is not None:
            dist = (verts_dist(gt, pred_verts, dim=1) *
                    1000.).detach().cpu().numpy()

        return (pred_verts.detach().cpu().numpy(),
                smpl_verts.detach().cpu().numpy(), self.garment_f_np,
                self.smpl_faces_np, dist)

    def eval_test(self):
        """Evaluate the model on 15 test batches.

        Returns:
            (sum_val_loss, sum_val_loss, gt_verts, pred_verts, body_xyz,
            garment_faces, smpl_faces); arrays come from the last batch.
        """
        sum_val_loss = 0.0
        num_batches = 15
        # Pure evaluation — no gradients needed.
        with torch.no_grad():
            for _ in range(num_batches):
                # FIX: iterator.next() was removed in Python 3; restart the
                # loader only when the iterator is missing or exhausted
                # instead of swallowing every exception.
                try:
                    batch = next(self.test_data_iterator)
                except (AttributeError, StopIteration):
                    self.test_data_iterator = iter(self.test_loader)
                    batch = next(self.test_data_iterator)
                inp = batch.get('inp').to(self.device)
                gt_verts = batch.get('gt_verts').to(self.device)
                betas = batch.get('betas').to(self.device)
                pose = batch.get('pose').to(self.device)
                trans = batch.get('trans').to(self.device)
                bs = inp.shape[0]
                weights_from_net = self.model(inp).view(
                    bs, self.layer_size, self.num_neigh)
                weights_from_net = self.out_layer(weights_from_net)

                # Blend neighbour positions with the predicted weights.
                input_copy = inp[:, self.idx2, :3]
                pred_x = weights_from_net * input_copy[:, :, :, 0]
                pred_y = weights_from_net * input_copy[:, :, :, 1]
                pred_z = weights_from_net * input_copy[:, :, :, 2]
                pred_verts = torch.sum(torch.stack(
                    (pred_x, pred_y, pred_z), dim=3),
                                       dim=2)

                sum_val_loss += data_loss(self.garment_layer, pred_verts,
                                          gt_verts).item()
        return sum_val_loss, sum_val_loss, gt_verts.detach().cpu().numpy(
        ), pred_verts.detach().cpu().numpy(), inp[:, :, :3].detach().cpu(
        ).numpy(), self.garment_f_np, self.smpl_faces_np

    def cuda(self):
        """Move the network to the default CUDA device."""
        self.model.cuda()

    def to(self, device):
        """Move the network to ``device``."""
        self.model.to(device)
# Example 3
class Trainer(object):
    """Trainer for the FC correspondence network.

    The network predicts, for every garment vertex, softmax weights over a
    fixed set of candidate body-vertex neighbours; the garment vertex is the
    weighted combination of those neighbours. Training combines a data term
    with Laplacian, local-neighbourhood and interpenetration regularisers.
    """

    def __init__(self, params):
        """Set up logging, data, the SMPL model, network and optimizer.

        Args:
            params: dict with keys 'gender', 'garment_class', 'batch_size',
                'garment_layer', 'res', 'num_neigh', 'feat', 'log_dir',
                'lr', 'checkpoint' and optionally 'iter_nums'.
        """
        # NOTE(review): `device` is read from module scope — confirm it is
        # defined before instantiation.
        self.device = device
        self.params = params
        self.gender = params['gender']
        self.garment_class = params['garment_class']
        self.bs = params['batch_size']
        self.garment_layer = params['garment_layer']
        self.res_name = params['res']
        self.num_neigh = params['num_neigh']
        self.feat = params['feat']
        # High resolution unless the low-res variant was requested.
        self.hres = self.res_name != 'lres'

        # log
        LOG_DIR = params['log_dir']

        self.model_name = 'FC_correspondence_{}'.format(self.res_name)
        self.note = "FC_corr_{}_{}_{}".format(self.garment_class,
                                              self.garment_layer,
                                              self.res_name)
        log_name = os.path.join(
            self.garment_class,
            '{}_{}_{}_{}'.format(self.garment_layer, self.feat, self.num_neigh,
                                 self.res_name))

        self.log_dir = os.path.join(LOG_DIR, log_name)
        if not os.path.exists(self.log_dir):
            print('making %s' % self.log_dir)
            os.makedirs(self.log_dir)

        with open(os.path.join(self.log_dir, "params.json"), 'w') as f:
            json.dump(params, f)

        self.iter_nums = params.get('iter_nums', 0)

        # load smpl and garment data
        self.layer_size, self.smpl_size = get_res_vert(params['garment_class'],
                                                       self.hres,
                                                       params['garment_layer'])
        if self.garment_layer == 'Body':
            # NOTE(review): hard-coded skin-vertex count for the Body layer
            # — confirm it matches len(self.body_vert) below.
            self.layer_size = 4448

        # Network input is flattened xyz (plus normals for feat == 'vn');
        # output is one weight per (garment vertex, neighbour) pair.
        input_dim = self.smpl_size * 3
        if self.feat == 'vn':
            input_dim = self.smpl_size * 6
        output_dim = self.layer_size * self.num_neigh

        # Precomputed per-vertex candidate body neighbours.
        layer_neigh = np.array(
            np.load(
                os.path.join(
                    DATA_DIR,
                    "real_{}_neighborheuristics_{}_{}_{}_gar_order2.npy".
                    format(self.garment_class, self.res_name,
                           self.garment_layer, self.num_neigh))))
        self.layer_neigh = torch.from_numpy(layer_neigh).cuda()

        # Skin vertices for the Body layer: all SMPL vertices not covered
        # by the upper or lower garment.
        body_vert = range(self.smpl_size)
        vert_id_upper = get_vid('UpperClothes', self.garment_class, self.hres)
        vert_id_lower = get_vid('Pants', self.garment_class, self.hres)
        body_vert2 = [i for i in body_vert if i not in vert_id_upper]
        body_vert2 = [i for i in body_vert2 if i not in vert_id_lower]
        self.body_vert = body_vert2

        all_neighbors = np.array([[vid] for k in layer_neigh for vid in k])
        self.neigh_id2 = all_neighbors
        if self.garment_layer == 'Body':
            self.idx2 = torch.from_numpy(self.neigh_id2).view(
                len(self.body_vert), self.num_neigh).cuda()
        else:
            self.idx2 = torch.from_numpy(self.neigh_id2).view(
                self.layer_size, self.num_neigh).cuda()

        # get vert indices of layer
        self.vert_indices = get_vid(self.garment_layer, self.garment_class,
                                    self.hres)
        # FIX: np.long was removed in NumPy 1.24; use np.int64.
        self.vert_indices = torch.tensor(
            self.vert_indices.astype(np.int64)).long().cuda()

        # dataset and dataloader
        self.train_dataset = ParserData(garment_class=self.garment_class,
                                        garment_layer=self.garment_layer,
                                        mode='train',
                                        batch_size=self.bs,
                                        res=self.res_name,
                                        gender=self.gender,
                                        feat=self.feat)
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.bs,
            num_workers=12,
            shuffle=True,
            drop_last=len(self.train_dataset) > self.bs)

        self.val_dataset = ParserData(garment_class=self.garment_class,
                                      garment_layer=self.garment_layer,
                                      mode='val',
                                      batch_size=self.bs,
                                      res=self.res_name,
                                      gender=self.gender,
                                      feat=self.feat)
        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=self.bs,
                                     num_workers=12,
                                     shuffle=True,
                                     drop_last=False)

        # create smpl
        self.smpl = TorchSMPL4Garment(gender=self.gender).to(device)
        self.smpl_faces_np = self.smpl.faces
        # FIX: self.body_f_np was used below but never assigned.
        self.body_f_np = self.smpl.faces
        # FIX: face indices are integers; avoid the lossy 'float32' cast.
        self.smpl_faces = torch.tensor(self.smpl_faces_np.astype(np.int64),
                                       dtype=torch.long).cuda()

        if self.garment_layer == 'Body':
            self.garment_f_np = self.body_f_np
            self.garment_f_torch = self.smpl_faces
        else:
            self.garment_f_np = Mesh(filename=os.path.join(
                DATA_DIR, 'real_{}_{}_{}.obj'.format(
                    self.garment_class, self.res_name, self.garment_layer))).f
            self.garment_f_torch = torch.tensor(
                self.garment_f_np.astype(np.int64)).long().to(device)

        self.num_faces = len(self.garment_f_np)

        # models and optimizer
        self.model = getattr(network_layers,
                             self.model_name)(input_size=input_dim,
                                              output_size=output_dim)
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=params['lr'],
                                          weight_decay=1e-6)
        # Softmax over the neighbour axis turns logits into blend weights.
        self.out_layer = torch.nn.Softmax(dim=2)
        if params['checkpoint']:
            ckpt_path = params['checkpoint']
            print('loading ckpt from {}'.format(ckpt_path))
            state_dict = torch.load(os.path.join(ckpt_path, 'lin.pth.tar'))
            self.model.load_state_dict(state_dict)
            state_dict = torch.load(
                os.path.join(ckpt_path, 'optimizer.pth.tar'))
            self.optimizer.load_state_dict(state_dict)

        # Per-vertex geodesic weights for the Body data term.
        geo_weights = np.load(os.path.join(DATA_DIR,
                                           'real_g5_geo_weights.npy'))
        self.geo_weights = torch.tensor(geo_weights[body_vert2].astype(
            np.float32)).cuda()
        self.best_error = np.inf
        self.best_epoch = -1
        self.logger = tensorboardX.SummaryWriter(os.path.join(self.log_dir))
        self.val_min = None
        # Interpenetration tolerance in metres.
        self.d_tol = 0.002

        self.sideddistance = SidedDistance()
        self.relu = nn.ReLU()

        # Weight initialiser for pretraining: a one-hot over each vertex's
        # neighbour list marking the vertex itself.
        vert_id = self.vert_indices.cpu().numpy()
        init_weights = torch.from_numpy(
            np.array([
                layer_neigh[i] == vert_id[i] for i in range(self.layer_size)
            ]).astype('int64'))
        self.init_weight = torch.stack([init_weights
                                        for _ in range(self.bs)]).cuda()

    def train(
        self,
        batch,
        pretrain=False,
    ):
        """Run one optimisation step on ``batch``.

        Args:
            batch: dict with 'inp', 'gt_verts', 'betas', 'pose', 'trans'.
            pretrain: when True, only fit the softmax weights to the one-hot
                initialiser instead of the full geometric losses.

        Returns:
            The scalar loss tensor for this step.
        """
        inp = batch.get('inp').to(self.device)
        gt_verts = batch.get('gt_verts').to(self.device)
        betas = batch.get('betas').to(self.device)
        pose = batch.get('pose').to(self.device)
        trans = batch.get('trans').to(self.device)

        self.optimizer.zero_grad()
        weights_from_net = self.model(inp).view(self.bs, self.layer_size,
                                                self.num_neigh)
        weights_from_net = self.out_layer(weights_from_net)

        if pretrain:
            # L1 between predicted weights and the one-hot initialiser.
            loss = (weights_from_net - self.init_weight).abs().sum(-1).mean()
        else:
            # Blend neighbour positions with the predicted weights.
            input_copy = inp[:, self.idx2, :3]
            pred_x = weights_from_net * input_copy[:, :, :, 0]
            pred_y = weights_from_net * input_copy[:, :, :, 1]
            pred_z = weights_from_net * input_copy[:, :, :, 2]

            pred_verts = torch.sum(torch.stack((pred_x, pred_y, pred_z),
                                               dim=3),
                                   dim=2)

            # Local neighbourhood regulariser: penalise weight mass far
            # from the currently dominant neighbour.
            current_argmax = torch.argmax(weights_from_net, dim=2)
            idx = torch.stack([
                torch.index_select(self.layer_neigh, 1, current_argmax[i])[0]
                for i in range(self.bs)
            ])
            current_argmax_verts = torch.stack([
                torch.index_select(inp[i, :, :3], 0, idx[i])
                for i in range(self.bs)
            ])
            current_argmax_verts = torch.stack(
                [current_argmax_verts for i in range(self.num_neigh)], dim=2)
            # todo: should it be input copy??
            dist_from_max = current_argmax_verts - input_copy

            dist_from_max = torch.sqrt(
                torch.sum(dist_from_max * dist_from_max, dim=3))
            local_regu = torch.sum(dist_from_max * weights_from_net) / (
                self.bs * self.num_neigh * self.layer_size)

            body_tmp = self.smpl.forward(beta=betas, theta=pose, trans=trans)
            body_mesh = [
                tm.from_tensors(vertices=v, faces=self.smpl_faces)
                for v in body_tmp
            ]

            if self.garment_layer == 'Body':
                # Update body verts with the prediction and compare with the
                # skin cutout, weighted by the geodesic weights.
                body_tmp[:, self.vert_indices, :] = pred_verts
                loss_data = data_loss(self.garment_layer, pred_verts,
                                      inp[:, self.vert_indices, :],
                                      self.geo_weights)
            else:
                loss_data = data_loss(self.garment_layer, pred_verts, gt_verts)

            # Meshes for the Laplacian and normal losses.
            pred_mesh = [
                tm.from_tensors(vertices=v, faces=self.garment_f_torch)
                for v in pred_verts
            ]
            gt_mesh = [
                tm.from_tensors(vertices=v, faces=self.garment_f_torch)
                for v in gt_verts
            ]

            loss_lap = lap_loss(pred_mesh, gt_mesh)

            # Normals for gt, prediction and body (normal loss currently
            # excluded from the total, see below).
            loss_norm, body_normals, pred_normals = normal_loss(
                self.bs, pred_mesh, gt_mesh, body_mesh, self.num_faces)

            # Penalise garment vertices inside the body.
            loss_interp = interp_loss(self.sideddistance,
                                      self.relu,
                                      pred_verts,
                                      gt_verts,
                                      body_tmp,
                                      body_normals,
                                      self.layer_size,
                                      d_tol=self.d_tol)

            # loss_norm intentionally left out of the total.
            loss = loss_data + 100. * loss_lap + local_regu + loss_interp

        loss.backward()
        self.optimizer.step()
        return loss

    def train_epoch(self, epoch, pretrain=False, train=True):
        """Run one training epoch, or one validation pass when train=False."""
        if train:
            self.model.train()
            loss_total = 0.0
            for batch in self.train_loader:
                train_loss = self.train(batch, pretrain)
                loss_total += train_loss.item()
                self.logger.add_scalar("train/loss", train_loss.item(),
                                       self.iter_nums)
                print("Iter {}, loss: {:.8f}".format(self.iter_nums,
                                                     train_loss.item()))
                self.iter_nums += 1
            # FIX: the epoch loss was divided by len(batch) (the number of
            # keys in the last batch dict); average over the batch count.
            self.logger.add_scalar("train_epoch/loss",
                                   loss_total / len(self.train_loader), epoch)
        else:  # validation
            self._save_ckpt(epoch)
            val_loss, val_dist = self.validation(epoch)
            print("epoch {}, loss: {:.8f} dist {:8f}".format(
                epoch, val_loss, val_dist))
            # FIX: the old code first set val_min = val_loss and then tested
            # the strict inequality, so the first epoch never recorded a
            # best_epoch file.
            if self.val_min is None or val_loss < self.val_min:
                self.val_min = val_loss
                with open(os.path.join(self.log_dir, 'best_epoch'), 'w') as f:
                    f.write("{:04d}".format(epoch))

            self.logger.add_scalar("val/loss", val_loss, epoch)
            self.logger.add_scalar("val/dist", val_dist, epoch)

    def validation(self, epoch):
        """Sum the data loss over 15 validation batches.

        Returns:
            (sum_val_loss, sum_val_loss) — the same value is reported as
            both loss and distance by the current implementation.
        """
        self.model.eval()

        sum_val_loss = 0.0
        num_batches = 15
        # FIX: validation should not touch gradients; removed the stray
        # optimizer.zero_grad() and run under no_grad.
        with torch.no_grad():
            for _ in range(num_batches):
                # FIX: iterator.next() was removed in Python 3; restart the
                # loader only when the iterator is missing or exhausted
                # instead of swallowing every exception.
                try:
                    batch = next(self.val_data_iterator)
                except (AttributeError, StopIteration):
                    self.val_data_iterator = iter(self.val_loader)
                    batch = next(self.val_data_iterator)
                inp = batch.get('inp').to(self.device)
                gt_verts = batch.get('gt_verts').to(self.device)
                betas = batch.get('betas').to(self.device)
                pose = batch.get('pose').to(self.device)
                trans = batch.get('trans').to(self.device)
                bs = inp.shape[0]
                weights_from_net = self.model(inp).view(
                    bs, self.layer_size, self.num_neigh)
                weights_from_net = self.out_layer(weights_from_net)

                # Blend neighbour positions with the predicted weights.
                input_copy = inp[:, self.idx2, :3]
                pred_x = weights_from_net * input_copy[:, :, :, 0]
                pred_y = weights_from_net * input_copy[:, :, :, 1]
                pred_z = weights_from_net * input_copy[:, :, :, 2]
                pred_verts = torch.sum(torch.stack(
                    (pred_x, pred_y, pred_z), dim=3),
                                       dim=2)

                sum_val_loss += data_loss(self.garment_layer, pred_verts,
                                          gt_verts).item()
        return sum_val_loss, sum_val_loss

    def _save_ckpt(self, epoch):
        """Save model and optimizer state under ``<log_dir>/<epoch>/``."""
        save_dir = os.path.join(self.log_dir, "{:04d}".format(epoch))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self.model.state_dict(),
                   os.path.join(save_dir, 'lin.pth.tar'))
        torch.save(self.optimizer.state_dict(),
                   os.path.join(save_dir, "optimizer.pth.tar"))