Example #1
    def _make_model(self):
        model_path = os.path.join(cfg.model_dir, 'snapshot_%d.pth.tar' % self.test_epoch)
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        self.logger.info('Load checkpoint from {}'.format(model_path))
        
        # prepare network
        self.logger.info("Creating graph...")
        model = get_model(self.vertex_num, self.joint_num, 'test')
        model = DataParallel(model).cuda()
        ckpt = torch.load(model_path)
        model.load_state_dict(ckpt['network'], strict=False)
        model.eval()

        self.model = model
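This load-snapshot / DataParallel / load_state_dict / eval() pattern recurs throughout the examples below. A minimal usage sketch, assuming a Tester class from the same codebase that defines _make_model as above (Tester, cfg, and the test loader are that repo's own objects):

# Hypothetical usage of the method above.
tester = Tester()
tester.test_epoch = 24             # selects snapshot_24.pth.tar
tester._make_model()               # builds and restores self.model
with torch.no_grad():
    output = tester.model(batch)   # batch: inputs from the repo's test loader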
Example #2
    def _make_model(self, test_epoch):
        self.test_epoch = test_epoch
        model_path = os.path.join(cfg.model_dir,
                                  'snapshot_%d.pth.tar' % self.test_epoch)
        assert os.path.exists(model_path), 'Cannot find model at ' + model_path
        # self.logger.info('Load checkpoint from {}'.format(model_path))

        # prepare network
        # self.logger.info("Creating graph...")
        model = get_pose_net(self.backbone, False, self.joint_num)
        model = DataParallel(model).cuda()
        ckpt = torch.load(model_path)
        model.load_state_dict(ckpt['network'])
        model.eval()

        self.model = model
Example #3
class SSRunner(object):
    def __init__(self, config):
        self.config = config

        # Data
        self.dataset_ss_train, _, self.dataset_ss_val = DatasetUtil.get_dataset_by_type(
            DatasetUtil.dataset_type_ss,
            self.config.ss_size,
            is_balance=self.config.is_balance_data,
            data_root=self.config.data_root_path,
            train_label_path=self.config.label_path,
            max_size=self.config.max_size)
        self.data_loader_ss_train = DataLoader(self.dataset_ss_train,
                                               self.config.ss_batch_size,
                                               True,
                                               num_workers=16,
                                               drop_last=True)
        self.data_loader_ss_val = DataLoader(self.dataset_ss_val,
                                             self.config.ss_batch_size,
                                             False,
                                             num_workers=16,
                                             drop_last=True)

        # Model
        self.net = self.config.Net(num_classes=self.config.ss_num_classes,
                                   output_stride=self.config.output_stride,
                                   arch=self.config.arch)

        if self.config.only_train_ss:
            self.net = BalancedDataParallel(0, self.net, dim=0).cuda()
        else:
            self.net = DataParallel(self.net).cuda()
            pass
        cudnn.benchmark = True

        # Optimizer
        self.optimizer = optim.SGD(params=[
            {
                'params': self.net.module.model.backbone.parameters(),
                'lr': self.config.ss_lr
            },
            {
                'params': self.net.module.model.classifier.parameters(),
                'lr': self.config.ss_lr * 10
            },
        ],
                                   lr=self.config.ss_lr,
                                   momentum=0.9,
                                   weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.config.ss_milestones, gamma=0.1)

        # Loss
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=255,
                                           reduction='mean').cuda()
        pass

    def train_ss(self, start_epoch=0, model_file_name=None):
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        # self.eval_ss(epoch=0)
        best_iou = 0.0

        for epoch in range(start_epoch, self.config.ss_epoch_num):
            Tools.print()
            Tools.print('Epoch:{:2d}, lr={:.6f} lr2={:.6f}'.format(
                epoch, self.optimizer.param_groups[0]['lr'],
                self.optimizer.param_groups[1]['lr']),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 1. Train the model
            all_loss = 0.0
            self.net.train()
            if self.config.is_balance_data:
                self.dataset_ss_train.reset()
                pass
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_train),
                                    total=len(self.data_loader_ss_train)):
                inputs, labels = inputs.float().cuda(), labels.long().cuda()
                self.optimizer.zero_grad()

                result = self.net(inputs)
                loss = self.ce_loss(result, labels)

                loss.backward()
                self.optimizer.step()

                all_loss += loss.item()

                if (i + 1) % (len(self.data_loader_ss_train) // 10) == 0:
                    score = self.eval_ss(epoch=epoch)
                    mean_iou = score["Mean IoU"]
                    if mean_iou > best_iou:
                        best_iou = mean_iou
                        save_file_name = Tools.new_dir(
                            os.path.join(
                                self.config.ss_model_dir,
                                "ss_{}_{}_{}.pth".format(epoch, i, best_iou)))
                        torch.save(self.net.state_dict(), save_file_name)
                        Tools.print("Save Model to {}".format(save_file_name),
                                    txt_path=self.config.ss_save_result_txt)
                        Tools.print()
                    pass
                pass
            self.scheduler.step()
            ###########################################################################

            Tools.print("[E:{:3d}/{:3d}] ss loss:{:.4f}".format(
                epoch, self.config.ss_epoch_num,
                all_loss / len(self.data_loader_ss_train)),
                        txt_path=self.config.ss_save_result_txt)

            ###########################################################################
            # 2. Save the model
            if epoch % self.config.ss_save_epoch_freq == 0:
                Tools.print()
                save_file_name = Tools.new_dir(
                    os.path.join(self.config.ss_model_dir,
                                 "ss_{}.pth".format(epoch)))
                torch.save(self.net.state_dict(), save_file_name)
                Tools.print("Save Model to {}".format(save_file_name),
                            txt_path=self.config.ss_save_result_txt)
                Tools.print()
                pass
            ###########################################################################

            ###########################################################################
            # 3. Evaluate the model
            if epoch % self.config.ss_eval_epoch_freq == 0:
                score = self.eval_ss(epoch=epoch)
                pass
            ###########################################################################

            pass

        # Final Save
        Tools.print()
        save_file_name = Tools.new_dir(
            os.path.join(self.config.ss_model_dir,
                         "ss_final_{}.pth".format(self.config.ss_epoch_num)))
        torch.save(self.net.state_dict(), save_file_name)
        Tools.print("Save Model to {}".format(save_file_name),
                    txt_path=self.config.ss_save_result_txt)
        Tools.print()

        self.eval_ss(epoch=self.config.ss_epoch_num)
        pass

    def eval_ss(self, epoch=0, model_file_name=None):
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs,
                    labels) in tqdm(enumerate(self.data_loader_ss_val),
                                    total=len(self.data_loader_ss_val)):
                inputs = inputs.float().cuda()
                labels = labels.long().cuda()
                outputs = self.net(inputs)
                preds = outputs.detach().max(dim=1)[1].cpu().numpy()
                targets = labels.cpu().numpy()

                metrics.update(targets, preds)
                pass
            pass

        score = metrics.get_results()
        Tools.print("{} {}".format(epoch, metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score

    def inference_ss(self,
                     model_file_name=None,
                     data_loader=None,
                     save_path=None):
        if model_file_name is not None:
            Tools.print("Load model form {}".format(model_file_name),
                        txt_path=self.config.ss_save_result_txt)
            self.load_model(model_file_name)
            pass

        final_save_path = Tools.new_dir("{}_final".format(save_path))

        self.net.eval()
        metrics = StreamSegMetrics(self.config.ss_num_classes)
        with torch.no_grad():
            for i, (inputs, labels,
                    image_info_list) in tqdm(enumerate(data_loader),
                                             total=len(data_loader)):
                assert len(image_info_list) == 1

                # Labels
                max_size = 1000
                size = Image.open(image_info_list[0]).size
                basename = os.path.basename(image_info_list[0])
                final_name = os.path.join(final_save_path,
                                          basename.replace(".JPEG", ".png"))
                if os.path.exists(final_name):
                    continue

                if size[0] < max_size and size[1] < max_size:
                    targets = F.interpolate(torch.unsqueeze(
                        labels[0].float().cuda(), dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest").detach().cpu()
                else:
                    targets = F.interpolate(torch.unsqueeze(labels[0].float(),
                                                            dim=0),
                                            size=(size[1], size[0]),
                                            mode="nearest")
                targets = targets[0].long().numpy()

                # Prediction
                outputs = 0
                for input_index, input_one in enumerate(inputs):
                    output_one = self.net(input_one.float().cuda())
                    if size[0] < max_size and size[1] < max_size:
                        outputs += F.interpolate(
                            output_one,
                            size=(size[1], size[0]),
                            mode="bilinear",
                            align_corners=False).detach().cpu()
                    else:
                        outputs += F.interpolate(output_one.detach().cpu(),
                                                 size=(size[1], size[0]),
                                                 mode="bilinear",
                                                 align_corners=False)
                        pass
                    pass
                outputs = outputs / len(inputs)
                preds = outputs.max(dim=1)[1].numpy()

                # Compute metrics
                metrics.update(targets, preds)

                if save_path:
                    Image.open(image_info_list[0]).save(
                        os.path.join(save_path, basename))
                    DataUtil.gray_to_color(
                        np.asarray(targets[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", "_l.png")))
                    DataUtil.gray_to_color(np.asarray(
                        preds[0], dtype=np.uint8)).save(
                            os.path.join(save_path,
                                         basename.replace(".JPEG", ".png")))
                    Image.fromarray(np.asarray(
                        preds[0], dtype=np.uint8)).save(final_name)
                    pass
                pass
            pass

        score = metrics.get_results()
        Tools.print("{}".format(metrics.to_str(score)),
                    txt_path=self.config.ss_save_result_txt)
        return score

    def load_model(self, model_file_name):
        Tools.print("Load model form {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        checkpoint = torch.load(model_file_name)

        if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) == 1:
            # checkpoint = {key.replace("module.", ""): checkpoint[key] for key in checkpoint}
            pass

        self.net.load_state_dict(checkpoint, strict=True)
        Tools.print("Restore from {}".format(model_file_name),
                    txt_path=self.config.ss_save_result_txt)
        pass

    def stat(self):
        stat(self.net, (3, self.config.ss_size, self.config.ss_size))
        pass

    pass
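A hypothetical driver for SSRunner, assuming a config object exposing the attributes referenced above (ss_size, ss_batch_size, ss_lr, ss_epoch_num, ss_model_dir, and so on):

# Sketch only: Config stands in for whatever configuration class the repo provides.
config = Config()
runner = SSRunner(config)
runner.stat()                    # print parameter/FLOP statistics
runner.train_ss(start_epoch=0)   # trains, periodically evaluates and checkpoints
score = runner.eval_ss(epoch=config.ss_epoch_num)   # or pass model_file_name= to score a saved .pth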
Example #4
joint_num = 21
joints_name = ('Head_top', 'Thorax', 'R_Shoulder', 'R_Elbow', 'R_Wrist',
               'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Hip', 'R_Knee', 'R_Ankle',
               'L_Hip', 'L_Knee', 'L_Ankle', 'Pelvis', 'Spine', 'Head',
               'R_Hand', 'L_Hand', 'R_Toe', 'L_Toe')
flip_pairs = ((2, 5), (3, 6), (4, 7), (8, 11), (9, 12), (10, 13), (17, 18),
              (19, 20))
skeleton = ((0, 16), (16, 1), (1, 15), (15, 14), (14, 8), (14, 11), (8, 9),
            (9, 10), (10, 19), (11, 12), (12, 13), (13, 20), (1, 2), (2, 3),
            (3, 4), (4, 17), (1, 5), (5, 6), (6, 7), (7, 18))

# snapshot load
model_path = './snapshot_%d.pth.tar' % int(args.test_epoch)
assert osp.exists(model_path), 'Cannot find model at ' + model_path
print('Load checkpoint from {}'.format(model_path))
model = get_pose_net(cfg, False, joint_num)
model = DataParallel(model).cuda()
ckpt = torch.load(model_path)
model.load_state_dict(ckpt['network'])
model.eval()

# prepare input image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cfg.pixel_mean, std=cfg.pixel_std)
])
img_path = 'input.jpg'
original_img = cv2.imread(img_path)
original_img_height, original_img_width = original_img.shape[:2]

# prepare bbox
bbox_list = [
    [139.41, 102.25, 222.39, 241.57],
    [287.17, 61.52, 74.88, 165.61],
    [540.04, 48.81, 99.96, 223.36],
Example #5
class RootNet(object):
    def __init__(self, weightsPath, principal_points=None, focal=(1500, 1500)):
        """

        :param weightsPath:
        :param principal_points:
        :param focal:
        """

        self.focal = focal
        self.principal_points = principal_points

        self.net = get_pose_net(cfg, False)
        self.net = DataParallel(self.net).cuda()
        weights = torch.load(weightsPath)
        self.net.load_state_dict(weights['network'])
        self.net.eval()

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=cfg.pixel_mean, std=cfg.pixel_std)
        ])

    def estimate(self, bboxes, image, tracking=False):
        """

        :param bboxes:
        :param image:
        :return:
        """
        if self.principal_points is None:
            self.principal_points = [image.shape[1] / 2, image.shape[0] / 2]

        output = []
        for bbox in bboxes:
            bbox_xywh = convertToXYWH(bbox[0], bbox[1], bbox[2], bbox[3])
            bbox_root = process_bbox(bbox_xywh, image.shape[1], image.shape[0])
            img, img2bb_trans = generate_patch_image(image, bbox_root, False,
                                                     0.0)
            img = self.transform(img).cuda()[None, :, :, :]
            k_value = np.array([
                math.sqrt(cfg.bbox_real[0] * cfg.bbox_real[1] * self.focal[0] *
                          self.focal[1] / (bbox_root[2] * bbox_root[3]))
            ]).astype(np.float32)
            k_value = torch.FloatTensor([k_value]).cuda()[None, :]

            # forward
            with torch.no_grad():
                root_3d = self.net(
                    img, k_value)  # x,y: pixel, z: root-relative depth (mm)
            root_3d = root_3d[0].cpu().numpy()

            # inverse affine transform (restore the crop and resize)
            root_3d[0] = root_3d[0] / cfg.output_shape[1] * cfg.input_shape[1]
            root_3d[1] = root_3d[1] / cfg.output_shape[0] * cfg.input_shape[0]
            root_3d_xy1 = np.concatenate(
                (root_3d[:2], np.ones_like(root_3d[:1])))
            img2bb_trans_001 = np.concatenate(
                (img2bb_trans, np.array([0, 0, 1]).reshape(1, 3)))
            root_3d[:2] = np.dot(np.linalg.inv(img2bb_trans_001),
                                 root_3d_xy1)[:2]
            # get 3D coordinates for bbox
            root_3d = pixel2cam(root_3d[None, :], self.focal,
                                self.principal_points)
            if tracking:
                pid = bbox[-1]
                output.append([bbox, root_3d, pid])
            else:
                output.append([bbox, root_3d])

        return output
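A hypothetical usage sketch; convertToXYWH implies the input boxes are (x_min, y_min, x_max, y_max), and estimate returns the root joint in camera coordinates via pixel2cam:

import cv2

net = RootNet('snapshot_18.pth.tar')   # illustrative checkpoint path
image = cv2.imread('input.jpg')
bboxes = [(139, 102, 361, 343)]        # one person box, xyxy
for bbox, root_3d in net.estimate(bboxes, image):
    print(bbox, root_3d)               # root (x, y, z); depth z in mm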
Example #6
def load_model(path='rootnet/rootnet_snapshot_18.pth.tar'):
    model = DataParallel(get_pose_net()).cuda()
    model.load_state_dict(torch.load(path)['network'])
    model.eval()
    return model
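Usage is a one-liner; the returned model expects the same (img, k_value) inputs prepared in the other RootNet examples:

model = load_model()   # default path from the example above
# with torch.no_grad():
#     root_3d = model(img, k_value)   # img and k_value built as in Example #5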
Example #7
def stage2_train(args):
    logger = init_logger(args)
    if args.summary:
        summary_writer = SummaryWriter(args.s2_summary_path)
    dataset = Birds(args.data_dir, split='train', im_size=256)
    dataloader = DataLoader(dataset, batch_size=args.s2_batch_size, shuffle=True, num_workers=8, drop_last=True)
    generator1 = Stage1Generator(args.txt_embedding_dim, args.c_dim, args.z_dim, args.gf_dim)
    print('generator1={}'.format(generator1))
    state_dict = torch.load(args.s1_checkpoint_path)
    for n, p in generator1.state_dict().items():
        if 'module.' + n in state_dict:
            p.copy_(state_dict['module.' + n])
    generator1 = generator1.cuda()
    generator2 = Stage2Generator(args.txt_embedding_dim, args.c_dim, args.gf_dim).cuda()
    print(f'generator2={generator2}')
    discriminator = Stage2Discriminator(args.df_dim, args.c_dim).cuda()
    print('discriminator={}'.format(discriminator))
    device_ids = list(range(torch.cuda.device_count()))
    generator1 = DataParallel(generator1, device_ids)
    generator2 = DataParallel(generator2, device_ids)
    discriminator = DataParallel(discriminator, device_ids)
    g2_parameters = list(filter(lambda f: f.requires_grad, generator2.parameters()))
    d_parameters = list(filter(lambda f: f.requires_grad, discriminator.parameters()))
    g2_optimizer = torch.optim.Adam(g2_parameters, args.lr, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(d_parameters, args.lr, betas=(0.5, 0.999))
    r_labels = torch.ones((args.s2_batch_size,), device='cuda:0')
    f_labels = torch.zeros((args.s2_batch_size,), device='cuda:0')
    criterion = nn.BCELoss()
    cur_lr = args.lr
    generator1.eval()
    for epoch in range(args.total_epoch):
        for idx, (r_imgs, txt_embeddings) in enumerate(dataloader):
            r_imgs = r_imgs.cuda()
            txt_embeddings = txt_embeddings.cuda()
            # discriminator
            noise = torch.zeros((args.s2_batch_size, args.z_dim), device='cuda:0').normal_()
            with torch.no_grad():
                s1_img, _, _ = generator1(txt_embeddings, noise)
            s1_img = s1_img.detach()
            s2_img, mu, logvar = generator2(txt_embeddings, s1_img)
            d_loss, r_loss, w_loss, f_loss = discriminator_loss(discriminator, r_imgs, s2_img.detach(), mu.detach(), r_labels, f_labels, criterion)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()
            # generator
            s2_img, mu, logvar = generator2(txt_embeddings, s1_img)
            logits = discriminator(mu.detach(), s2_img)
            g_loss = criterion(logits, r_labels)
            kl_loss_ = kl_loss(mu, logvar)
            g_loss += kl_loss_
            g2_optimizer.zero_grad()
            g_loss.backward()
            g2_optimizer.step()
            if args.summary and idx % args.summary_iters == 0 and idx > 0:
                summary_writer.add_scalar('d_loss', d_loss.item())
                summary_writer.add_scalar('r_loss', r_loss.item())
                summary_writer.add_scalar('w_loss', w_loss.item())
                summary_writer.add_scalar('f_loss', f_loss.item())
                summary_writer.add_scalar('g_loss', g_loss.item())
                summary_writer.add_scalar('kl_loss', kl_loss_.item())
            elif idx % args.display_iters == 0 and idx > 0:
                logger.info(f'epoch:{epoch}, lr={cur_lr}, d_loss={d_loss}, r_loss={r_loss}, w_loss={w_loss}, f_loss={f_loss}, g_loss={g_loss}, kl_loss={kl_loss_}')
        if epoch % args.lr_decay_every_epoch == 0 and epoch > 0:
            logger.info(f'lr decay: {cur_lr}')
            cur_lr *= args.lr_decay_ratio
            g2_optimizer = torch.optim.Adam(g2_parameters, cur_lr, betas=(0.5, 0.999))
            d_optimizer = torch.optim.Adam(d_parameters, cur_lr, betas=(0.5, 0.999))
        if epoch % args.display_epoch == 0 and epoch > 0:
            logger.info(f'epoch:{epoch}, lr={cur_lr}, d_loss={d_loss}, r_loss={r_loss}, w_loss={w_loss}, f_loss={f_loss}, g_loss={g_loss}, kl_loss={kl_loss_}')
        if epoch % args.checkpoint_epoch == 0 and epoch > 0:
            if not os.path.isdir(args.s2_checkpoint_dir):
                os.makedirs(args.s2_checkpoint_dir)
            logger.info(f'saving checkpoints_{epoch}')
            torch.save(generator2.state_dict(), os.path.join(args.s2_checkpoint_dir, f'generator_epoch_{epoch}.pth'))
            torch.save(discriminator.state_dict(), os.path.join(args.s2_checkpoint_dir, f'discriminator_epoch_{epoch}.pth'))
    torch.save(generator2.state_dict(), os.path.join(args.s2_checkpoint_dir, 'generator.pth'))
    torch.save(discriminator.state_dict(), os.path.join(args.s2_checkpoint_dir, 'discriminator.pth'))
    if args.summary:
        summary_writer.close()
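A hypothetical entry point; args is the argparse.Namespace the repo's CLI would build, and every attribute sketched below is one that stage2_train reads:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()   # sketch: real names/defaults live in the repo
    parser.add_argument('--data_dir', default='data/birds')
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--s2_batch_size', type=int, default=16)
    parser.add_argument('--total_epoch', type=int, default=600)
    # ...plus z_dim, c_dim, gf_dim, df_dim, txt_embedding_dim, summary flags,
    # checkpoint paths, and the lr-decay / display / checkpoint intervals used above
    stage2_train(parser.parse_args())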
Example #8
def execute(args):
    try:
        logger.info('Start person depth estimation: {0}', args.img_dir, decoration=MLogger.DECORATION_BOX)

        if not os.path.exists(args.img_dir):
            logger.error("指定された処理用ディレクトリが存在しません。: {0}", args.img_dir, decoration=MLogger.DECORATION_BOX)
            return False

        parser = get_parser()
        argv = parser.parse_args(args=[])

        if not os.path.exists(argv.model_path):
            logger.error("指定された学習モデルが存在しません。: {0}", argv.model_path, decoration=MLogger.DECORATION_BOX)
            return False

        cudnn.benchmark = True

        # snapshot load
        model = get_pose_net(argv, False)
        model = DataParallel(model).to('cuda')
        ckpt = torch.load(argv.model_path)
        model.load_state_dict(ckpt['network'])
        model.eval()
        focal = [1500, 1500] # x-axis, y-axis

        # prepare input image
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=argv.pixel_mean, std=argv.pixel_std)])

        # Per-person folders, ordered by appearance
        ordered_person_dir_pathes = sorted(glob.glob(os.path.join(args.img_dir, "ordered", "*")), key=sort_by_numeric)

        frame_pattern = re.compile(r'^(frame_(\d+)\.png)')

        for oidx, ordered_person_dir_path in enumerate(ordered_person_dir_pathes):    
            logger.info("【No.{0}】人物深度推定開始", f"{oidx:03}", decoration=MLogger.DECORATION_LINE)

            frame_json_pathes = sorted(glob.glob(os.path.join(ordered_person_dir_path, "frame_*.json")), key=sort_by_numeric)

            for frame_json_path in tqdm(frame_json_pathes, desc=f"No.{oidx:03} ... "):                
                m = frame_pattern.match(os.path.basename(frame_json_path))
                if m:
                    frame_image_name = str(m.groups()[0])
                    fno_name = str(m.groups()[1])
                    
                    # Image path for this frame
                    frame_image_path = os.path.join(args.img_dir, "frames", fno_name, frame_image_name)

                    if os.path.exists(frame_image_path):

                        frame_joints = {}
                        with open(frame_json_path, 'r') as f:
                            frame_joints = json.load(f)
                        
                        width = int(frame_joints['image']['width'])
                        height = int(frame_joints['image']['height'])

                        original_img = cv2.imread(frame_image_path)

                        bx = float(frame_joints["bbox"]["x"])
                        by = float(frame_joints["bbox"]["y"])
                        bw = float(frame_joints["bbox"]["width"])
                        bh = float(frame_joints["bbox"]["height"])

                        # Depth estimation with RootNet
                        bbox = process_bbox([bx, by, bw, bh], width, height, argv)
                        img, img2bb_trans = generate_patch_image(original_img, bbox, False, 0.0, argv)
                        img = transform(img).to('cuda')[None,:,:,:]
                        k_value = np.array([math.sqrt(argv.bbox_real[0] * argv.bbox_real[1] * focal[0] * focal[1] / (bbox[2] * bbox[3]))]).astype(np.float32)
                        k_value = torch.FloatTensor([k_value]).to('cuda')[None,:]

                        with torch.no_grad():
                            root_3d = model(img, k_value) # x,y: pixel, z: root-relative depth (mm)

                        img = img[0].to('cpu').numpy()
                        root_3d = root_3d[0].to('cpu').numpy()
                        root_3d[0] = root_3d[0] / argv.output_shape[0] * bbox[2] + bbox[0]
                        root_3d[1] = root_3d[1] / argv.output_shape[1] * bbox[3] + bbox[1]

                        frame_joints["root"] = {"x": float(root_3d[0]), "y": float(root_3d[1]), "z": float(root_3d[2]), \
                                                "input": {"x": argv.input_shape[0], "y": argv.input_shape[1]}, "output": {"x": argv.output_shape[0], "y": argv.output_shape[1]}, \
                                                "focal": {"x": focal[0], "y": focal[1]}}

                        with open(frame_json_path, 'w') as f:
                            json.dump(frame_joints, f, indent=4)

        logger.info('Finished person depth estimation: {0}', args.img_dir, decoration=MLogger.DECORATION_BOX)

        return True
    except Exception as e:
        logger.critical("人物深度で予期せぬエラーが発生しました。", e, decoration=MLogger.DECORATION_BOX)
        return False
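A hypothetical call; execute only needs an object carrying img_dir, since the model path comes from the embedded parser:

from types import SimpleNamespace

# Sketch: img_dir must already contain the "ordered" and "frames" layouts read above.
ok = execute(SimpleNamespace(img_dir='./work/video_01'))
print('done' if ok else 'failed')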