Example #1
0
    def __init__(
        self,
        c,
        nof_joints,
        checkpoint_path,
        resolution=(384, 288),
        interpolation=cv2.INTER_CUBIC,
        multiperson=True,
        yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
        yolo_class_path="./models/detectors/yolo/data/coco.names",
        yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
        device=torch.device("cpu"),
        max_batch_size=32):
        """Build the HRNet pose model and the matching pre-processing transform.

        Args:
            c: HRNet width parameter, forwarded to ``HRNet(c=...)``.
            nof_joints: number of joints predicted by the model.
            checkpoint_path: weights file. Either a bare state dict or a dict
                storing it under the ``'model'`` key is accepted.
            resolution: network input size as ``(height, width)``, as in the
                original implementation.
            interpolation: OpenCV interpolation flag, stored on the instance.
            multiperson: when True a YOLOv3 person detector is created and the
                transform resizes each crop to ``resolution``.
            yolo_model_def, yolo_class_path, yolo_weights_path: YOLOv3 files,
                used only when ``multiperson`` is True.
            device: torch device for the pose model and the detector.
            max_batch_size: maximum batch size, stored on the instance and
                forwarded to the detector.
        """
        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.resolution = resolution  # in the form (height, width) as in the original implementation
        self.interpolation = interpolation
        self.multiperson = multiperson
        # Must be a parameter: it is stored here and passed to YOLOv3 below.
        self.max_batch_size = max_batch_size
        self.yolo_model_def = yolo_model_def
        self.yolo_class_path = yolo_class_path
        self.yolo_weights_path = yolo_weights_path
        self.device = device

        self.model = HRNet(c=c, nof_joints=nof_joints).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Accept both a bare state dict and one wrapped as {'model': state_dict}.
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()

        if not self.multiperson:
            # No Resize step: only tensor conversion and mean/std normalization.
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        else:
            # Detector restricted to the 'person' class; each crop is then
            # converted to PIL, resized to (height, width), and normalized.
            self.detector = YOLOv3(model_def=yolo_model_def,
                                   class_path=yolo_class_path,
                                   weights_path=yolo_weights_path,
                                   classes=('person', ),
                                   max_batch_size=self.max_batch_size,
                                   device=device)
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.resolution[0],
                                   self.resolution[1])),  # (height, width)
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
Example #2
0
def convertToONNX(checkpoint_path="./weights/pose_hrnet_w48_384x288.pth",
                  onnx_path=None,
                  c=48,
                  nof_joints=17,
                  resolution=(384, 288)):
    """Export an HRNet checkpoint to ONNX and validate the written file.

    All arguments default to the previously hard-coded values, so calling
    ``convertToONNX()`` behaves as before.

    Args:
        checkpoint_path: PyTorch weights to export. Either a bare state dict
            or a dict storing it under the ``'model'`` key is accepted.
        onnx_path: destination ``.onnx`` file; ``None`` means the module-level
            ``ONNX_PATH``.
        c: HRNet width parameter.
        nof_joints: number of joints predicted by the checkpoint.
        resolution: dummy-input size as ``(height, width)``.
    """
    if onnx_path is None:
        onnx_path = ONNX_PATH

    model = HRNet(c=c, nof_joints=nof_joints)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # Accept the wrapped checkpoint format ({'model': state_dict}) as well.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint)
    # Export inference behaviour explicitly (e.g. BatchNorm running statistics).
    model.eval()
    count_parameters(model)

    # The dummy input only fixes the traced input shape; its values are irrelevant
    # and it does not need gradients.
    x = torch.randn(1, 3, resolution[0], resolution[1])

    torch.onnx.export(
        model=model,
        args=x,
        f=onnx_path,  # where the exported model is saved
        verbose=False,
        export_params=True,
        # Constant folding is disabled; set to True to fold constants at export.
        do_constant_folding=False,
        input_names=['input'],
        output_names=['output'])

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
Example #3
0
def main(args):
    """Train a model or evaluate a checkpoint, depending on ``args.task``.

    ``args.task == 'train'`` builds the datasets, the model and an Adam
    optimizer (optionally resuming from ``args.resume``) and calls ``train``.
    Any other value loads the checkpoint in ``args.resume`` and calls ``test``.

    Raises:
        ValueError: the required image path for the selected task is missing.
        FileNotFoundError: the checkpoint in ``args.resume`` is missing or
            does not exist.
    """
    if args.task == 'train':
        if not args.train_image_path:
            # Raising a plain str is itself a TypeError in Python 3.
            raise ValueError('train data path should be specified !')
        train_dataset = SumitomoCADDS(file_path=args.train_image_path)

        # Keep the name bound when no validation path is given.
        # NOTE(review): train() must tolerate val_dataset=None — confirm.
        val_dataset = None
        if args.val_image_path:
            val_dataset = SumitomoCADDS(file_path=args.val_image_path,
                                        val=True)

        model = HRNet(3, 32, 8).to(device)
        optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0003)

        if args.resume:
            if not os.path.isfile(args.resume):
                raise FileNotFoundError('=> no checkpoint found at %s' % args.resume)
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            args.best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint %s (epoch %d)' %
                  (args.resume, args.start_epoch))

        train(args, model, optimizer, train_dataset, val_dataset)

    else:  # test
        if not args.test_image_path:
            raise ValueError('=> test data path should be specified')
        if not args.resume or not os.path.isfile(args.resume):
            raise FileNotFoundError('=> resume not specified or no checkpoint found')
        test_dataset = SumitomoCADDS(file_path=args.test_image_path, test=True)
        # NOTE(review): the train branch builds HRNet(3, 32, 8) but testing
        # builds ResUNet(3, 8); a checkpoint saved by training will only load
        # here if the architectures match — confirm which model is intended.
        model = ResUNet(3, 8).to(device)
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        print(f'Successfully loaded model from {args.resume}')
        test(args, model, test_dataset)
    def __init__(self,
                 c,
                 nof_joints,
                 checkpoint_path,
                 model_name='HRNet',
                 resolution=(384, 288),
                 device=torch.device('cuda')):
        """Instantiate a pose network by name, load its weights, set eval mode.

        Args:
            c: size parameter whose meaning depends on the architecture: the
                HRNet width, the PoseResNet ``resnet_size``, or the hourglass
                ``num_stacks``.
            nof_joints: number of joints (output channels / classes).
            checkpoint_path: weights file; a bare state dict or a dict that
                stores it under ``'model'``.
            model_name: one of 'HRNet'/'hrnet', 'PoseResNet'/'poseresnet'/
                'ResNet'/'resnet', or 'hg'/'HG'.
            resolution: stored on the instance as given.
            device: device used both to map the checkpoint and to host the model.

        Raises:
            ValueError: ``model_name`` is not one of the accepted aliases.
        """
        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.model_name = model_name
        self.resolution = resolution
        self.device = device

        # Alias groups paired with lazy factories: only the matching network is built.
        factories = (
            (('HRNet', 'hrnet'),
             lambda: HRNet(c=c, nof_joints=nof_joints)),
            (('PoseResNet', 'poseresnet', 'ResNet', 'resnet'),
             lambda: PoseResNet(resnet_size=c, nof_joints=nof_joints)),
            (('hg', 'HG'),
             lambda: hg(num_stacks=c, num_blocks=1, num_classes=nof_joints)),
        )
        build = next((make for aliases, make in factories if model_name in aliases), None)
        if build is None:
            raise ValueError('Wrong model name.')
        self.model = build()

        state = torch.load(checkpoint_path, map_location=self.device)
        weights = state['model'] if 'model' in state else state
        self.model.load_state_dict(weights)

        self.model = self.model.to(device)
        self.model.eval()

        # Tensor conversion only; no mean/std normalization is applied.
        self.transform = transforms.Compose([transforms.ToTensor()])
Example #5
0
class Train(object):
    """
    Train class.

    The class provides a basic tool for training HRNet.
    Most of the training options are customizable.

    The only method supposed to be directly called is `run()`.
    """

    def __init__(self,
                 exp_name,
                 ds_train,
                 ds_val,
                 epochs=210,
                 batch_size=16,
                 num_workers=4,
                 loss='JointsMSELoss',
                 lr=0.001,
                 lr_decay=True,
                 lr_decay_steps=(170, 200),
                 lr_decay_gamma=0.1,
                 optimizer='Adam',
                 weight_decay=0.,
                 momentum=0.9,
                 nesterov=False,
                 pretrained_weight_path=None,
                 checkpoint_path=None,
                 log_path='./logs',
                 use_tensorboard=True,
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None
                 ):
        """
        Initializes a new Train object.

        The log folder is created, the HRNet model is initialized and optional pre-trained weights or saved checkpoints
        are loaded.
        The DataLoaders, the loss function, and the optimizer are defined.

        Args:
            exp_name (str):  experiment name.
            ds_train (HumanPoseEstimationDataset): train dataset.
            ds_val (HumanPoseEstimationDataset): validation dataset.
            epochs (int): number of epochs.
                Default: 210
            batch_size (int): batch size.
                Default: 16
            num_workers (int): number of workers for each DataLoader
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            lr (float): learning rate.
                Default: 0.001
            lr_decay (bool): learning rate decay.
                Default: True
            lr_decay_steps (tuple): steps for the learning rate decay scheduler.
                Default: (170, 200)
            lr_decay_gamma (float): scale factor for each learning rate decay step.
                Default: 0.1
            optimizer (str): network optimizer. Valid values are 'Adam' and 'SGD'.
                Default: "Adam"
            weight_decay (float): weight decay.
                Default: 0.
            momentum (float): momentum factor (SGD only).
                Default: 0.9
            nesterov (bool): Nesterov momentum (SGD only).
                Default: False
            pretrained_weight_path (str): path to pre-trained weights (such as weights from pre-train on imagenet).
                Default: None
            checkpoint_path (str): path to a previous checkpoint, or to a folder containing 'checkpoint_last.pth'.
                Default: None
            log_path (str): path where tensorboard data and checkpoints will be saved.
                Default: "./logs"
            use_tensorboard (bool): enables tensorboard use.
                Default: True
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                Default: 17
            model_bn_momentum (float): hrnet parameters - momentum of the batch-normalization layers.
                Default: 0.1
            flip_test_images (bool): flip images during validating.
                Default: True
            device (torch.device): device to be used (default: cuda, if available).
                Default: None

        Raises:
            FileExistsError: the experiment log folder already exists.
            NotImplementedError: unsupported `loss` or `optimizer` value.
        """
        super(Train, self).__init__()

        self.exp_name = exp_name
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.lr_decay_gamma = lr_decay_gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov
        self.pretrained_weight_path = pretrained_weight_path
        self.checkpoint_path = checkpoint_path
        self.log_path = os.path.join(log_path, self.exp_name)
        self.use_tensorboard = use_tensorboard
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device
        if device is not None:
            self.device = device
        else:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
            else:
                self.device = torch.device('cpu')

        print(self.device)

        os.makedirs(self.log_path, 0o755, exist_ok=False)  # exist_ok=False to avoid overwriting
        if self.use_tensorboard:
            self.summary_writer = tb.SummaryWriter(self.log_path)

        #
        # write all experiment parameters in parameters.txt and in tensorboard text field
        # locals() here contains only `self` and the constructor arguments, so no
        # other local variable may be assigned before this line.
        self.parameters = [x + ': ' + str(y) + '\n' for x, y in locals().items()]
        with open(os.path.join(self.log_path, 'parameters.txt'), 'w') as fd:
            fd.writelines(self.parameters)
        if self.use_tensorboard:
            self.summary_writer.add_text('parameters', '\n'.join(self.parameters))

        #
        # load model
        self.model = HRNet(c=self.model_c, nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        #
        # define loss and optimizers
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        if optimizer == 'SGD':
            self.optim = SGD(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay,
                             momentum=self.momentum, nesterov=self.nesterov)
        elif optimizer == 'Adam':
            self.optim = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        #
        # load pre-trained weights (such as those pre-trained on imagenet)
        if self.pretrained_weight_path is not None:
            self.model.load_state_dict(torch.load(self.pretrained_weight_path, map_location=self.device), strict=True)
            print('Pre-trained weights loaded.')

        #
        # load previous checkpoint
        if self.checkpoint_path is not None:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path, 'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, self.optim, self.params = load_checkpoint(path, self.model, self.optim,
                                                                                       self.device)
        else:
            self.starting_epoch = 0

        if lr_decay:
            self.lr_scheduler = MultiStepLR(self.optim, list(self.lr_decay_steps), gamma=self.lr_decay_gamma,
                                            last_epoch=self.starting_epoch if self.starting_epoch else -1)

        #
        # load train and val datasets
        self.dl_train = DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True,
                                   num_workers=self.num_workers, drop_last=True)
        self.len_dl_train = len(self.dl_train)

        self.dl_val = DataLoader(self.ds_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        self.len_dl_val = len(self.dl_val)

        #
        # initialize variables
        self.mean_loss_train = 0.
        self.mean_acc_train = 0.
        self.mean_loss_val = 0.
        self.mean_acc_val = 0.
        self.mean_mAP_val = 0.

        self.best_loss = None
        self.best_acc = None
        self.best_mAP = None

    def _train(self):
        """Run one training epoch; accumulate and print mean train loss/accuracy."""
        self.model.train()

        for step, (image, target, target_weight, joints_data) in enumerate(tqdm(self.dl_train, desc='Training')):
            image = image.to(self.device)
            target = target.to(self.device)
            target_weight = target_weight.to(self.device)

            self.optim.zero_grad()

            output = self.model(image)

            loss = self.loss_fn(output, target, target_weight)

            loss.backward()

            self.optim.step()

            # Evaluate accuracy
            # Get predictions on the input
            accs, avg_acc, cnt, joints_preds, joints_target = self.ds_train.evaluate_accuracy(output, target)

            self.mean_loss_train += loss.item()
            self.mean_acc_train += avg_acc.item()
            if self.use_tensorboard:
                self.summary_writer.add_scalar('train_loss', loss.item(),
                                               global_step=step + self.epoch * self.len_dl_train)
                self.summary_writer.add_scalar('train_acc', avg_acc.item(),
                                               global_step=step + self.epoch * self.len_dl_train)
                if step == 0:
                    save_images(image, target, joints_target, output, joints_preds, joints_data['joints_visibility'],
                                self.summary_writer, step=step + self.epoch * self.len_dl_train, prefix='train_')

        self.mean_loss_train /= len(self.dl_train)
        self.mean_acc_train /= len(self.dl_train)

        print('\nTrain: Loss %f - Accuracy %f' % (self.mean_loss_train, self.mean_acc_train))

    def _val(self):
        """Run one validation pass; accumulate and print mean val loss/accuracy.

        When `flip_test_images` is set, the prediction is averaged with the
        prediction on the horizontally flipped image.
        """
        self.model.eval()

        with torch.no_grad():
            for step, (image, target, target_weight, joints_data) in enumerate(tqdm(self.dl_val, desc='Validating')):
                image = image.to(self.device)
                target = target.to(self.device)
                target_weight = target_weight.to(self.device)

                output = self.model(image)

                if self.flip_test_images:
                    image_flipped = flip_tensor(image, dim=-1)
                    output_flipped = self.model(image_flipped)

                    output_flipped = flip_back(output_flipped, self.ds_val.flip_pairs)

                    output = (output + output_flipped) * 0.5

                loss = self.loss_fn(output, target, target_weight)

                # Evaluate accuracy
                # Get predictions on the input
                # NOTE(review): uses the train dataset's evaluate_accuracy; assumes
                # it does not depend on dataset-specific state — confirm.
                accs, avg_acc, cnt, joints_preds, joints_target = \
                    self.ds_train.evaluate_accuracy(output, target)

                # Accumulate into the validation means: _checkpoint() selects the
                # best checkpoints from mean_loss_val / mean_acc_val.
                self.mean_loss_val += loss.item()
                self.mean_acc_val += avg_acc.item()
                if self.use_tensorboard:
                    # Validation steps are counted in validation-loader batches.
                    self.summary_writer.add_scalar('val_loss', loss.item(),
                                                   global_step=step + self.epoch * self.len_dl_val)
                    self.summary_writer.add_scalar('val_acc', avg_acc.item(),
                                                   global_step=step + self.epoch * self.len_dl_val)
                    if step == 0:
                        save_images(image, target, joints_target, output, joints_preds,
                                    joints_data['joints_visibility'], self.summary_writer,
                                    step=step + self.epoch * self.len_dl_val, prefix='val_')

        self.mean_loss_val /= len(self.dl_val)
        self.mean_acc_val /= len(self.dl_val)

        print('\nValidation: Loss %f - Accuracy %f' % (self.mean_loss_val, self.mean_acc_val))

    def _checkpoint(self):
        """Save the last checkpoint and any new best (loss / accuracy / mAP) checkpoint."""

        save_checkpoint(path=os.path.join(self.log_path, 'checkpoint_last.pth'), epoch=self.epoch + 1, model=self.model,
                        optimizer=self.optim, params=self.parameters)

        if self.best_loss is None or self.best_loss > self.mean_loss_val:
            self.best_loss = self.mean_loss_val
            print('best_loss %f at epoch %d' % (self.best_loss, self.epoch + 1))
            save_checkpoint(path=os.path.join(self.log_path, 'checkpoint_best_loss.pth'), epoch=self.epoch + 1,
                            model=self.model, optimizer=self.optim, params=self.parameters)
        if self.best_acc is None or self.best_acc < self.mean_acc_val:
            self.best_acc = self.mean_acc_val
            print('best_acc %f at epoch %d' % (self.best_acc, self.epoch + 1))
            save_checkpoint(path=os.path.join(self.log_path, 'checkpoint_best_acc.pth'), epoch=self.epoch + 1,
                            model=self.model, optimizer=self.optim, params=self.parameters)
        # NOTE(review): mean_mAP_val is reset to 0 every epoch and never updated in
        # this class, so this branch only fires on the first epoch.
        if self.best_mAP is None or self.best_mAP < self.mean_mAP_val:
            self.best_mAP = self.mean_mAP_val
            print('best_mAP %f at epoch %d' % (self.best_mAP, self.epoch + 1))
            save_checkpoint(path=os.path.join(self.log_path, 'checkpoint_best_mAP.pth'), epoch=self.epoch + 1,
                            model=self.model, optimizer=self.optim, params=self.parameters)

    def run(self):
        """
        Runs the training.
        """

        print('\nTraining started @ %s' % datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # start training
        for self.epoch in range(self.starting_epoch, self.epochs):
            print('\nEpoch %d of %d @ %s' % (self.epoch + 1, self.epochs, datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

            self.mean_loss_train = 0.
            self.mean_loss_val = 0.
            self.mean_acc_train = 0.
            self.mean_acc_val = 0.
            self.mean_mAP_val = 0.

            #
            # Train

            self._train()

            #
            # Val

            self._val()

            #
            # LR Update

            if self.lr_decay:
                self.lr_scheduler.step()

            #
            # Checkpoint

            self._checkpoint()

        print('\nTraining ended @ %s' % datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
class SimpleHRNet(object):
    """HRNet pose estimator with an optional YOLOv3 person detector.

    Input images are BGR arrays of shape (height, width, 3), or batches of
    shape (n, height, width, 3). Each predicted joint is a
    (y, x, confidence) triple in the coordinates of the input image.
    """

    def __init__(
        self,
        c,
        nof_joints,
        checkpoint_path='weights/mod_pose_hrnet_w32_256x192.pth',
        resolution=(384, 288),
        interpolation=cv2.INTER_CUBIC,
        multiperson=True,
        max_batch_size=32,
        yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
        yolo_class_path="./models/detectors/yolo/data/coco.names",
        yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
        device=torch.device("cpu")):
        """Load the HRNet weights and, in multi-person mode, the detector.

        Args:
            c: HRNet width parameter, forwarded to ``HRNet(c=...)``.
            nof_joints: number of joints predicted by the model.
            checkpoint_path: weights file; either a bare state dict or a dict
                storing it under the ``'model'`` key.
            resolution: network input size as ``(height, width)``.
            interpolation: OpenCV flag used when resizing whole images in
                single-person mode.
            multiperson: if True, people are first detected with YOLOv3 and
                each detection is processed separately.
            max_batch_size: maximum number of crops per HRNet forward pass;
                also forwarded to the detector.
            yolo_model_def, yolo_class_path, yolo_weights_path: YOLOv3 files,
                used only when ``multiperson`` is True.
            device: torch device for the pose model and the detector.
        """
        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.resolution = resolution  # in the form (height, width) as in the original implementation
        self.interpolation = interpolation
        self.multiperson = multiperson
        self.max_batch_size = max_batch_size
        self.yolo_model_def = yolo_model_def
        self.yolo_class_path = yolo_class_path
        self.yolo_weights_path = yolo_weights_path
        self.device = device

        self.model = HRNet(c=c, nof_joints=nof_joints).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Accept both a bare state dict and one wrapped as {'model': state_dict}.
        if 'model' in checkpoint:
            checkpoint = checkpoint['model']
        self.model.load_state_dict(checkpoint)
        self.model.eval()

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if not self.multiperson:
            # Whole images are resized with OpenCV in the predict methods.
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])

        else:
            self.detector = YOLOv3(model_def=yolo_model_def,
                                   class_path=yolo_class_path,
                                   weights_path=yolo_weights_path,
                                   classes=('person', ),
                                   max_batch_size=self.max_batch_size,
                                   device=device)
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.resolution[0],
                                   self.resolution[1])),  # (height, width)
                transforms.ToTensor(),
                normalize,
            ])

    def predict(self, image):
        """Predict joints for a single image (3-D array) or a batch (4-D array).

        Returns:
            Single image: array (n_people, nof_joints, 3), or an empty
            (0, 0, 3) array when nobody is found.
            Batch, multi-person mode: list with one (n_people, nof_joints, 3)
            array per image.
            Batch, single-person mode: array (n_images, 1, nof_joints, 3).

        Raises:
            ValueError: ``image`` is neither 3-D nor 4-D.
        """
        if len(image.shape) == 3:
            return self._predict_single(image)
        elif len(image.shape) == 4:
            return self._predict_batch(image)
        else:
            raise ValueError('Wrong image format.')

    def _crop_detections(self, image, detections):
        """Crop and transform every detected person in one BGR image.

        Returns a float tensor (n, 3, height, width) of network inputs and an
        int32 array of [x1, y1, x2, y2] boxes (shape (0,) when ``detections``
        is None).
        """
        if detections is None:
            empty = torch.empty((0, 3, self.resolution[0],
                                 self.resolution[1]))  # (height, width)
            return empty, np.asarray([], dtype=np.int32)

        img_h, img_w = image.shape[:2]

        def clamp(value, upper):
            # Negative indices would wrap around in numpy slicing, and a box
            # extending past the border would no longer match its crop size.
            return min(max(int(round(value.item())), 0), upper)

        crops = torch.empty((len(detections), 3, self.resolution[0],
                             self.resolution[1]))  # (height, width)
        boxes = []
        for i, (x1, y1, x2, y2, *_) in enumerate(detections):
            x1, x2 = clamp(x1, img_w), clamp(x2, img_w)
            y1, y2 = clamp(y1, img_h), clamp(y2, img_h)
            boxes.append([x1, y1, x2, y2])
            crops[i] = self.transform(image[y1:y2, x1:x2, ::-1])  # BGR -> RGB
        return crops, np.asarray(boxes, dtype=np.int32)

    def _run_model(self, images):
        """Run HRNet on ``images`` in chunks of at most ``max_batch_size``.

        Returns the output heatmaps as a numpy array.
        """
        images = images.to(self.device)
        with torch.no_grad():
            out = torch.cat([self.model(chunk)
                             for chunk in torch.split(images, self.max_batch_size)])
        return out.detach().cpu().numpy()

    def _heatmaps_to_points(self, out, boxes):
        """Decode heatmaps (n, joints, h, w) into points (n, joints, 3).

        The heatmap argmax is rescaled from the heatmap grid
        (resolution // 4) to the box size and offset by the box origin.
        """
        heatmap_h = self.resolution[0] // 4
        heatmap_w = self.resolution[1] // 4
        pts = np.empty((out.shape[0], out.shape[1], 3), dtype=np.float32)
        # For each human, for each joint: y, x, confidence
        for i, human in enumerate(out):
            x1, y1, x2, y2 = boxes[i]
            for j, joint in enumerate(human):
                pt = np.unravel_index(np.argmax(joint), (heatmap_h, heatmap_w))
                # 0: pt_y / (height // 4) * (bb_y2 - bb_y1) + bb_y1
                # 1: pt_x / (width // 4) * (bb_x2 - bb_x1) + bb_x1
                # 2: confidence (heatmap value at the peak)
                pts[i, j, 0] = pt[0] * 1. / heatmap_h * (y2 - y1) + y1
                pts[i, j, 1] = pt[1] * 1. / heatmap_w * (x2 - x1) + x1
                pts[i, j, 2] = joint[pt]
        return pts

    def _predict_single(self, image):
        """Predict joints for one BGR image; see `predict` for the output."""
        if not self.multiperson:
            old_res = image.shape
            if self.resolution is not None:
                image = cv2.resize(
                    image,
                    (self.resolution[1],
                     self.resolution[0]),  # (width, height)
                    interpolation=self.interpolation)

            images = self.transform(cv2.cvtColor(
                image, cv2.COLOR_BGR2RGB)).unsqueeze(dim=0)
            boxes = np.asarray([[0, 0, old_res[1], old_res[0]]],
                               dtype=np.float32)  # [x1, y1, x2, y2]

        else:
            detections = self.detector.predict_single(image)
            images, boxes = self._crop_detections(image, detections)

        if images.shape[0] == 0:
            return np.empty((0, 0, 3), dtype=np.float32)

        return self._heatmaps_to_points(self._run_model(images), boxes)

    def _predict_batch(self, images):
        """Predict joints for a batch of BGR images; see `predict` for the output."""
        if not self.multiperson:
            old_res = images[0].shape

            if self.resolution is not None:
                images_tensor = torch.empty(images.shape[0], 3,
                                            self.resolution[0],
                                            self.resolution[1])
            else:
                images_tensor = torch.empty(images.shape[0], 3,
                                            images.shape[1], images.shape[2])

            for i, image in enumerate(images):
                if self.resolution is not None:
                    image = cv2.resize(
                        image,
                        (self.resolution[1],
                         self.resolution[0]),  # (width, height)
                        interpolation=self.interpolation)

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                images_tensor[i] = self.transform(image)

            images = images_tensor
            boxes = np.repeat(np.asarray([[0, 0, old_res[1], old_res[0]]],
                                         dtype=np.float32),
                              len(images),
                              axis=0)  # [x1, y1, x2, y2]

        else:
            image_detections = self.detector.predict(images)

            # Flatten the crops and boxes of all images into single sequences;
            # they are split back per image at the end.
            crops = []
            boxes = []
            for d, detections in enumerate(image_detections):
                crops_image, boxes_image = self._crop_detections(images[d], detections)
                crops.append(crops_image)
                boxes.extend(boxes_image)

            if crops:
                images = torch.cat(crops)
            else:
                images = torch.empty((0, 3, self.resolution[0],
                                      self.resolution[1]))  # (height, width)
            boxes = np.asarray(boxes, dtype=np.int32)

        if images.shape[0] > 0:
            pts = self._heatmaps_to_points(self._run_model(images), boxes)
        else:
            # Nobody detected in any image: skip the network entirely.
            pts = np.empty((0, self.nof_joints, 3), dtype=np.float32)

        if self.multiperson:
            # re-add the removed batch axis (n)
            pts_batch = []
            index = 0
            for detections in image_detections:
                if detections is not None:
                    pts_batch.append(pts[index:index + len(detections)])
                    index += len(detections)
                else:
                    pts_batch.append(
                        np.zeros((0, self.nof_joints, 3), dtype=np.float32))
            pts = pts_batch

        else:
            pts = np.expand_dims(pts, axis=1)

        return pts
Example #7
0

    # Pick the training parser (data_parser) and the validation parser
    # (data_parserv) from config["dataset_name"].
    # NOTE(review): these are independent `if`s with no `else`; an unrecognised
    # dataset name leaves data_parser unbound unless it is set earlier in the
    # enclosing function (not visible here) — confirm.
    if config["dataset_name"] == "inria" :
        data_parser = Inria(config)
        data_parserv = Inria_v(config)
    if config["dataset_name"] == "ade20k" :
        data_parser = Ade20k(config)
        data_parserv = Ade20k_v(config)
    if config["dataset_name"] == "cityscape" :
        data_parser = Cityscape(config)
        data_parserv = Cityscape_v(config)

    # The model is constructed inside mirrored_strategy.scope().
    # NOTE(review): an unrecognised config["model_name"] leaves `model` unbound
    # (no `else` branch) — confirm this cannot happen.
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        if config["model_name"] == "hrnet" : 
            model = HRNet(configs=config)
        elif config["model_name"] == "vggunet" :
            model = Vggunet(configs=config)
        elif config["model_name"] == "subject4" : 
            model = Subject4(configs=config)
        elif config["model_name"] == "bisenet" : 
            model = Bisenet(configs=config)

    # Training dataset built from data_parser.generator. Each item is declared as
    # a pair of float32 tensors with shapes (?, ?, 3) and (?, ?) — presumably an
    # image and its per-pixel label map; TODO confirm. Incomplete final batches
    # are dropped (drop_remainder=True).
    with mirrored_strategy.scope():
        dataset = tf.data.Dataset.from_generator(
            data_parser.generator,
            (tf.float32, tf.float32),
            # (tf.TensorShape([config["image_size"][0], config["image_size"][1], 3]), tf.TensorShape([config["image_size"][0], config["image_size"][1], 3]))
            (tf.TensorShape([None, None, 3]), tf.TensorShape([None, None]))
        ).batch(config["batch_size"], drop_remainder=True)
Exemple #8
0
    def __init__(self,
                 exp_name,
                 ds_train,
                 ds_val,
                 epochs=210,
                 batch_size=16,
                 num_workers=4,
                 loss='JointsMSELoss',
                 lr=0.001,
                 lr_decay=True,
                 lr_decay_steps=(170, 200),
                 lr_decay_gamma=0.1,
                 optimizer='Adam',
                 weight_decay=0.00001,
                 momentum=0.9,
                 nesterov=False,
                 pretrained_weight_path=None,
                 checkpoint_path=None,
                 log_path='./logs',
                 use_tensorboard=True,
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initialize a new Train object.

        Creates the log folder ``<log_path>/<exp_name>`` (which must not exist yet), writes the
        constructor arguments to ``parameters.txt`` (and TensorBoard), builds the HRNet model, the
        loss function, the optimizer and, if ``lr_decay`` is True, a MultiStepLR scheduler.
        Optionally loads pre-trained weights and/or resumes from a previous checkpoint, then wraps
        the training and validation datasets in DataLoaders.

        Raises:
            NotImplementedError: if ``loss`` or ``optimizer`` is not a supported name.
            FileExistsError: if the log folder already exists (``exist_ok=False``).
        """
        super(Train, self).__init__()

        self.exp_name = exp_name
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.lr_decay_gamma = lr_decay_gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov
        self.pretrained_weight_path = pretrained_weight_path
        self.checkpoint_path = checkpoint_path
        self.log_path = os.path.join(log_path, self.exp_name)
        self.use_tensorboard = use_tensorboard
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device: explicit argument, else the first GPU if available, else the CPU
        if device is not None:
            self.device = device
        else:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
            else:
                self.device = torch.device('cpu')

        print(self.device)

        # exist_ok=False: refuse to overwrite the logs of an existing experiment
        os.makedirs(self.log_path, 0o755, exist_ok=False)
        if self.use_tensorboard:
            self.summary_writer = tb.SummaryWriter(self.log_path)

        # Write all experiment parameters to parameters.txt and to a TensorBoard text field.
        # NOTE: this must stay before any other local variable is created, so that locals()
        # contains only the constructor arguments.
        self.parameters = [
            x + ': ' + str(y) + '\n' for x, y in locals().items()
        ]
        with open(os.path.join(self.log_path, 'parameters.txt'), 'w') as fd:
            fd.writelines(self.parameters)
        if self.use_tensorboard:
            self.summary_writer.add_text('parameters',
                                         '\n'.join(self.parameters))

        # Build the model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        if optimizer == 'SGD':
            self.optim = SGD(self.model.parameters(),
                             lr=self.lr,
                             weight_decay=self.weight_decay,
                             momentum=self.momentum,
                             nesterov=self.nesterov)
        elif optimizer == 'Adam':
            self.optim = Adam(self.model.parameters(),
                              lr=self.lr,
                              weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        # Load pre-trained weights (strict=False: missing/unexpected keys are tolerated)
        if self.pretrained_weight_path is not None:
            self.model.load_state_dict(torch.load(self.pretrained_weight_path,
                                                  map_location=self.device),
                                       strict=False)

        # Resume from a previous checkpoint (a directory means "use its checkpoint_last.pth")
        if self.checkpoint_path is not None:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path,
                                    'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, self.optim, self.params = load_checkpoint(
                path, self.model, self.optim, self.device)
        else:
            self.starting_epoch = 0

        if lr_decay:
            # A scheduler built with last_epoch != -1 requires 'initial_lr' to already be present
            # in the optimizer param groups (only true when resuming); a fresh run must pass -1,
            # otherwise MultiStepLR raises KeyError.
            self.lr_scheduler = MultiStepLR(
                self.optim,
                list(self.lr_decay_steps),
                gamma=self.lr_decay_gamma,
                last_epoch=self.starting_epoch if self.starting_epoch else -1)

        # Training and validation data loaders
        self.dl_train = DataLoader(self.ds_train,
                                   batch_size=self.batch_size,
                                   shuffle=True,
                                   num_workers=self.num_workers,
                                   drop_last=True)
        self.len_dl_train = len(self.dl_train)

        self.dl_val = DataLoader(self.ds_val,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=self.num_workers)
        self.len_dl_val = len(self.dl_val)

        # Running metrics
        self.mean_loss_train = 0.
        self.mean_acc_train = 0.
        self.mean_loss_val = 0.
        self.mean_acc_val = 0.
        self.mean_mAP_val = 0.

        self.best_loss = None
        self.best_acc = None
        self.best_mAP = None
Exemple #9
0
class THRNet:
    """
    HRNet wrapper class.

    Provides a simple, configurable way to build the HRNet network, load the official pre-trained
    weights and predict human poses on single images or on batches of images.

    Every predicted keypoint is returned as ``(y, x, confidence)`` in the pixel coordinates of
    the input image.
    """
    def __init__(
        self,
        c,
        nof_joints,
        checkpoint_path,
        resolution=(384, 288),
        interpolation=cv2.INTER_CUBIC,
        multiperson=True,
        return_bounding_boxes=False,
        max_batch_size=32,
        yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
        yolo_class_path="./models/detectors/yolo/data/coco.names",
        yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
        device=torch.device("cpu")):
        """
        Initialize a new THRNet object.

        HRNet (and, with multiperson=True, YOLOv3) are created on ``device`` and their pre-trained
        weights are loaded from disk.

        Args:
            c (int): number of channels.
            nof_joints (int): number of joints.
            checkpoint_path (str): path to an official hrnet checkpoint or to a checkpoint whose
                dict stores the weights under the 'model' key.
            resolution (tuple): hrnet input resolution - format: (height, width).
                Default: (384, 288)
            interpolation (int): opencv interpolation algorithm used when resizing whole images.
                Default: cv2.INTER_CUBIC
            multiperson (bool): if True, people are first detected with YOLOv3 and each detection
                is processed separately. Default: True
            return_bounding_boxes (bool): stored as ``self.return_bounding_boxes``.
                NOTE: ``predict`` does not currently read it; only poses are returned.
                Default: False
            max_batch_size (int): maximum number of images/crops given to HRNet in one forward
                pass; also forwarded to the YOLOv3 detector. Default: 32
            yolo_model_def (str): path to yolo model definition file.
                Default: "./models/detectors/yolo/config/yolov3.cfg"
            yolo_class_path (str): path to yolo class definition file.
                Default: "./models/detectors/yolo/data/coco.names"
            yolo_weights_path (str): path to yolo pretrained weights file.
                Default: "./models/detectors/yolo/weights/yolov3.weights"
            device (:class:`torch.device`): the hrnet (and yolo) inference will run on this device.
                Default: torch.device("cpu")
        """
        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.resolution = resolution  # (height, width), as in the original implementation
        self.interpolation = interpolation
        self.multiperson = multiperson
        self.return_bounding_boxes = return_bounding_boxes
        self.max_batch_size = max_batch_size
        self.yolo_model_def = yolo_model_def
        self.yolo_class_path = yolo_class_path
        self.yolo_weights_path = yolo_weights_path
        self.device = device

        self.model = HRNet(c=c, nof_joints=nof_joints).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()

        if not self.multiperson:
            # whole images are resized with cv2 before the transform
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        else:
            self.detector = YOLOv3(model_def=yolo_model_def,
                                   class_path=yolo_class_path,
                                   weights_path=yolo_weights_path,
                                   classes=('person', ),
                                   max_batch_size=self.max_batch_size,
                                   device=device)
            # person crops have arbitrary sizes, so they are resized inside the transform
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.resolution[0],
                                   self.resolution[1])),  # (height, width)
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def predict(self, image):
        """
        Predict human poses on a single image (3-D array) or a batch of images (4-D array).

        Returns:
            single image: np.ndarray of shape (n_people, nof_joints, 3).
            batch, multiperson: list with one (n_people_i, nof_joints, 3) array per image.
            batch, single person: np.ndarray of shape (batch, 1, nof_joints, 3).

        Raises:
            ValueError: if ``image`` is neither 3-D nor 4-D.
        """
        if len(image.shape) == 3:
            return self._predict_single(image)
        elif len(image.shape) == 4:
            return self._predict_batch(image)
        else:
            raise ValueError('Mal formato de imagen.')

    def _fit_box_to_aspect(self, x1, y1, x2, y2, image_shape):
        """
        Round a detection box to integer pixels, clip it to the image and enlarge it along one
        axis so that its aspect ratio matches the HRNet input resolution.

        Args:
            x1, y1, x2, y2 (float): box corners in pixels.
            image_shape (tuple): shape of the image the box refers to, (height, width, ...).

        Returns:
            tuple: integer (x1, y1, x2, y2) inside the image, with x1 < x2 and y1 < y2.
        """
        height, width = image_shape[0], image_shape[1]
        x1, y1, x2, y2 = (int(round(v)) for v in (x1, y1, x2, y2))

        # Clip to the image and keep at least one pixel per side: a negative coordinate would
        # wrap around when slicing, and an empty box would divide by zero below.
        x1 = min(max(0, x1), width - 1)
        y1 = min(max(0, y1), height - 1)
        x2 = max(x1 + 1, min(width, x2))
        y2 = max(y1 + 1, min(height, y2))

        correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
        if correction_factor > 1:
            # box is wider than the network input: grow it vertically
            center = y1 + (y2 - y1) // 2
            length = int(round((y2 - y1) * correction_factor))
            y1 = max(0, center - length // 2)
            y2 = max(y1 + 1, min(height, center + length // 2))
        elif correction_factor < 1:
            # box is taller than the network input: grow it horizontally
            center = x1 + (x2 - x1) // 2
            length = int(round((x2 - x1) / correction_factor))
            x1 = max(0, center - length // 2)
            x2 = max(x1 + 1, min(width, center + length // 2))
        return x1, y1, x2, y2

    def _crop_detections(self, image, detections):
        """
        Crop every detected person from ``image`` and turn each crop into a network input.

        Args:
            image (np.ndarray): BGR image, (height, width, 3).
            detections: detector output for this image (rows of
                x1, y1, x2, y2, conf, cls_conf, cls_pred) or None.

        Returns:
            tuple: (tensor of shape (n, 3, res_h, res_w), list of n [x1, y1, x2, y2] boxes).
        """
        if detections is None:
            return torch.empty((0, 3, self.resolution[0], self.resolution[1])), []

        crops = torch.empty((len(detections), 3, self.resolution[0], self.resolution[1]))
        boxes = []
        for i, (x1, y1, x2, y2, _conf, _cls_conf, _cls_pred) in enumerate(detections):
            bx1, by1, bx2, by2 = self._fit_box_to_aspect(
                x1.item(), y1.item(), x2.item(), y2.item(), image.shape)
            boxes.append([bx1, by1, bx2, by2])
            # ::-1 flips the channel axis (BGR -> RGB)
            crops[i] = self.transform(image[by1:by2, bx1:bx2, ::-1])
        return crops, boxes

    def _run_model(self, images):
        """Run HRNet on a non-empty tensor of images in chunks of at most ``max_batch_size``
        and return the output heatmaps as a numpy array of shape (n, nof_joints, h, w)."""
        images = images.to(self.device)
        with torch.no_grad():
            out = torch.cat(
                [self.model(chunk) for chunk in torch.split(images, self.max_batch_size)],
                dim=0)
        return out.detach().cpu().numpy()

    def _heatmaps_to_points(self, heatmaps, boxes):
        """
        Convert heatmaps into keypoints in image coordinates.

        Args:
            heatmaps (np.ndarray): network output, shape (n, nof_joints, h, w).
            boxes (array-like): one [x1, y1, x2, y2] box per person, shape (n, 4).

        Returns:
            np.ndarray: float32 array of shape (n, nof_joints, 3) holding (y, x, confidence),
            where the position is the heatmap argmax rescaled into the person's box.
        """
        n, k, h, w = heatmaps.shape
        flat = heatmaps.reshape(n, k, h * w)
        peak = flat.argmax(axis=2)
        rows, cols = np.unravel_index(peak, (h, w))

        boxes = np.asarray(boxes, dtype=np.float64)
        x1, y1, x2, y2 = (boxes[:, col:col + 1] for col in range(4))

        pts = np.empty((n, k, 3), dtype=np.float32)
        pts[..., 0] = rows * 1. / h * (y2 - y1) + y1
        pts[..., 1] = cols * 1. / w * (x2 - x1) + x1
        pts[..., 2] = np.take_along_axis(flat, peak[..., None], axis=2)[..., 0]
        return pts

    def _predict_single(self, image):
        """Predict poses on one BGR image; returns an array of shape (n_people, nof_joints, 3)."""
        if not self.multiperson:
            old_res = image.shape
            if self.resolution is not None:
                image = cv2.resize(
                    image,
                    (self.resolution[1], self.resolution[0]),  # (width, height)
                    interpolation=self.interpolation)

            images = self.transform(cv2.cvtColor(
                image, cv2.COLOR_BGR2RGB)).unsqueeze(dim=0)
            # the "box" is the whole original image: [x1, y1, x2, y2]
            boxes = np.asarray([[0, 0, old_res[1], old_res[0]]], dtype=np.float32)

        else:
            detections = self.detector.predict_single(image)
            images, boxes = self._crop_detections(image, detections)
            boxes = np.asarray(boxes, dtype=np.int32)

        if images.shape[0] == 0:
            # no people: same empty shape as returned per image by _predict_batch
            return np.zeros((0, self.nof_joints, 3), dtype=np.float32)

        return self._heatmaps_to_points(self._run_model(images), boxes)

    def _predict_batch(self, images):
        """Predict poses on a batch of BGR images of shape (batch, height, width, 3)."""
        image_detections = None
        if not self.multiperson:
            old_res = images[0].shape

            if self.resolution is not None:
                images_tensor = torch.empty(images.shape[0], 3,
                                            self.resolution[0],
                                            self.resolution[1])
            else:
                images_tensor = torch.empty(images.shape[0], 3,
                                            images.shape[1], images.shape[2])

            for i, image in enumerate(images):
                if self.resolution is not None:
                    image = cv2.resize(
                        image, (self.resolution[1], self.resolution[0]),
                        interpolation=self.interpolation)
                images_tensor[i] = self.transform(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

            boxes = np.repeat(np.asarray([[0, 0, old_res[1], old_res[0]]],
                                         dtype=np.float32),
                              len(images_tensor),
                              axis=0)  # [x1, y1, x2, y2]

        else:
            image_detections = self.detector.predict(images)

            # gather the crops and boxes of every image into flat lists
            crops = []
            boxes = []
            for d, detections in enumerate(image_detections):
                image_crops, image_boxes = self._crop_detections(images[d], detections)
                crops.append(image_crops)
                boxes.extend(image_boxes)

            if crops:
                images_tensor = torch.cat(crops, dim=0)
            else:
                images_tensor = torch.empty((0, 3, self.resolution[0], self.resolution[1]))
            boxes = np.asarray(boxes, dtype=np.int32)

        if images_tensor.shape[0] > 0:
            pts = self._heatmaps_to_points(self._run_model(images_tensor), boxes)
        else:
            # e.g. no person detected in any image of the batch
            pts = np.zeros((0, self.nof_joints, 3), dtype=np.float32)

        if not self.multiperson:
            return np.expand_dims(pts, axis=1)

        # split the flat per-person results back into one array per input image
        pts_batch = []
        index = 0
        for detections in image_detections:
            if detections is not None:
                pts_batch.append(pts[index:index + len(detections)])
                index += len(detections)
            else:
                pts_batch.append(
                    np.zeros((0, self.nof_joints, 3), dtype=np.float32))
        return pts_batch
Exemple #11
0
        data_parserv = Inria_v(config)
    if config["dataset_name"] == "ade20k":
        data_parserv = Ade20k_v(config)
    if config["dataset_name"] == "cityscape":
        data_parserv = Cityscape_v(config)

    repeatv = config["epoch"] * data_parserv.steps
    datasetv = tf.data.Dataset.from_generator(
        data_parserv.generator, (tf.float32, tf.float32),
        (tf.TensorShape([None, None, 3]), tf.TensorShape([None, None]))).batch(
            config["batch_size"], drop_remainder=False)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        if config["model_name"] == "hrnet":
            the_model = HRNet(configs=config)
        elif config["model_name"] == "vggunet":
            the_model = Vggunet(configs=config)
        elif config["model_name"] == "subject4":
            the_model = Subject4(configs=config)
        elif config["model_name"] == "bisenet":
            the_model = Bisenet(configs=config)

        print(the_model.model)
        dist_datasetv = mirrored_strategy.experimental_distribute_dataset(
            datasetv)

        the_model.miou_op.reset_states()

    saving_folder = Path(config["test"]["output_folder"])
    if not saving_folder.is_dir():
Exemple #12
0
import torch
import cv2
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np
import warnings
from vidgear.gears import CamGear
from models.hrnet import HRNet

warnings.filterwarnings("ignore", category=UserWarning)

if __name__ == "__main__":
    # GPU inference is deliberately disabled; set to True to use CUDA when it is available.
    use_cuda = False

    model = HRNet(32, 17, 0.1)

    # Load onto the CPU first: a checkpoint saved from a GPU would otherwise require CUDA
    # here even though the model may run on the CPU. It is moved to `device` below.
    model.load_state_dict(
        torch.load('weights/mod_pose_hrnet_w32_256x192.pth', map_location='cpu')
    )
    print('ok!!')

    if use_cuda and torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    print(device)

    model = model.to(device)

    video = CamGear(0).start()
Exemple #13
0
import os
import sys

print(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'hrnet'))
sys.path.append(os.path.join(os.getcwd(), 'hrnet', 'models'))

print(sys.path)

print(os.getcwd())
# HRNet is imported from inside the `hrnet` folder. Compare the last path component exactly:
# `endswith("hrnet")` would also match a folder such as "myhrnet".
if os.path.basename(os.getcwd()) != "hrnet":
    os.chdir(os.path.join(os.getcwd(), "hrnet"))
    if os.getcwd() not in sys.path:
        sys.path.append(os.getcwd())

from models.hrnet import HRNet
model = HRNet(c=48, nof_joints=17)
# Leave the `hrnet` folder again by moving to its parent directory.
if os.path.basename(os.getcwd()) == "hrnet":
    os.chdir(os.path.dirname(os.getcwd()))
print(os.getcwd())

import onnx
import onnxruntime

from prettytable import PrettyTable
import torch
import json
import cv2
import numpy as np

ONNX_PATH = "./my_model.onnx"
Exemple #14
0
    def __init__(self,
                 ds_test,
                 batch_size=1,
                 num_workers=4,
                 loss='JointsMSELoss',
                 checkpoint_path=None,
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Build the evaluation harness.

        Picks a torch device, creates the HRNet model and its loss, restores the weights from a
        checkpoint, and wraps ``ds_test`` in a non-shuffled DataLoader.

        Raises:
            NotImplementedError: if ``loss`` is not a supported loss name.
            ValueError: if ``checkpoint_path`` is None.
        """
        super(Test, self).__init__()

        # configuration
        self.ds_test = ds_test
        self.batch_size, self.num_workers = batch_size, num_workers
        self.loss = loss
        self.checkpoint_path = checkpoint_path
        self.model_c, self.model_nof_joints = model_c, model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device: explicit argument, else the first GPU if present, else the CPU
        if device is None:
            device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        self.device = device

        print(self.device)

        # network
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        # loss function, selected by name
        if self.loss == 'JointsMSELoss':
            loss_cls = JointsMSELoss
        elif self.loss == 'JointsOHKMMSELoss':
            loss_cls = JointsOHKMMSELoss
        else:
            raise NotImplementedError
        self.loss_fn = loss_cls().to(self.device)

        # restore weights; a directory means "use its checkpoint_last.pth"
        if self.checkpoint_path is None:
            raise ValueError('checkpoint_path is not defined')
        print('Loading checkpoint %s...' % self.checkpoint_path)
        path = (os.path.join(self.checkpoint_path, 'checkpoint_last.pth')
                if os.path.isdir(self.checkpoint_path)
                else self.checkpoint_path)
        self.starting_epoch, self.model, _, self.params = load_checkpoint(
            path, self.model, device=self.device)

        # test data loader (ordered, so results are reproducible)
        self.dl_test = DataLoader(self.ds_test,
                                  batch_size=self.batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers)
        self.len_dl_test = len(self.dl_test)

        # running metrics
        self.mean_loss_test = 0.
        self.mean_acc_test = 0.
Exemple #15
0
class Test(object):
    """
    Test class.

    Evaluates an HRNet checkpoint on a test dataset: the loss and the accuracy
    reported by the dataset are averaged over all test batches and printed.

    The only method supposed to be directly called is `run()`.
    """
    def __init__(self,
                 ds_test,
                 batch_size=1,
                 num_workers=4,
                 loss='JointsMSELoss',
                 checkpoint_path=None,
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initializes a new Test object.

        The HRNet model is built, the loss function is selected, the checkpoint
        is loaded into the model and the test DataLoader is created.

        Args:
            ds_test: test dataset. Besides being iterable by a DataLoader, it must
                provide `flip_pairs` and `evaluate_accuracy(output, target)`.
            batch_size (int): batch size.
                Default: 1
            num_workers (int): number of DataLoader workers.
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            checkpoint_path (str): path to a checkpoint file, or to a directory
                containing 'checkpoint_last.pth'. Required.
                Default: None
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                Default: 17
            model_bn_momentum (float): hrnet parameters - batch-norm momentum.
                Default: 0.1
            flip_test_images (bool): if True, the output on the horizontally flipped
                image is averaged with the output on the original image.
                Default: True
            device (torch.device): device to be used. If None, 'cuda:0' is used when
                available, otherwise 'cpu'.
                Default: None

        Raises:
            NotImplementedError: if `loss` is not a supported loss name.
            ValueError: if `checkpoint_path` is None.
        """
        super(Test, self).__init__()

        self.ds_test = ds_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.checkpoint_path = checkpoint_path
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        print(self.device)

        # load model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        # define loss
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        # load previous checkpoint
        if self.checkpoint_path is None:
            raise ValueError('checkpoint_path is not defined')

        print('Loading checkpoint %s...' % self.checkpoint_path)
        if os.path.isdir(self.checkpoint_path):
            path = os.path.join(self.checkpoint_path, 'checkpoint_last.pth')
        else:
            path = self.checkpoint_path
        # The third value returned by load_checkpoint (the optimizer state in the
        # training code path) is not needed for testing and is discarded.
        self.starting_epoch, self.model, _, self.params = load_checkpoint(
            path, self.model, device=self.device)

        # load test dataset
        self.dl_test = DataLoader(self.ds_test,
                                  batch_size=self.batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers)
        self.len_dl_test = len(self.dl_test)

        # initialize variables
        self.mean_loss_test = 0.
        self.mean_acc_test = 0.

    def _test(self):
        """
        Runs one pass over the test DataLoader.

        Accumulates the per-batch loss and average accuracy, then stores their
        means in `self.mean_loss_test` and `self.mean_acc_test` and prints them.
        Images of the first batch are saved via `save_images`.
        """
        self.model.eval()
        with torch.no_grad():
            for step, (image, target, target_weight, joints_data) in enumerate(
                    tqdm(self.dl_test, desc='Test')):
                image = image.to(self.device)
                target = target.to(self.device)
                target_weight = target_weight.to(self.device)

                output = self.model(image)

                if self.flip_test_images:
                    # Test-time augmentation: average the heatmaps predicted on the
                    # original image and on its horizontal flip (flipped back).
                    image_flipped = flip_tensor(image, dim=-1)
                    output_flipped = self.model(image_flipped)

                    output_flipped = flip_back(output_flipped,
                                               self.ds_test.flip_pairs)

                    output = (output + output_flipped) * 0.5

                loss = self.loss_fn(output, target, target_weight)

                # Evaluate accuracy and get the predicted joints for the input
                _, avg_acc, _, joints_preds, joints_target = \
                    self.ds_test.evaluate_accuracy(output, target)

                self.mean_loss_test += loss.item()
                self.mean_acc_test += avg_acc.item()
                if step == 0:
                    save_images(image, target, joints_target, output,
                                joints_preds, joints_data['joints_visibility'])

        # An empty test set yields zero batches; dividing by it would raise
        # ZeroDivisionError, so the (zero) accumulators are left as they are.
        n_batches = max(self.len_dl_test, 1)
        self.mean_loss_test /= n_batches
        self.mean_acc_test /= n_batches

        print('\nTest: Loss %f - Accuracy %f' %
              (self.mean_loss_test, self.mean_acc_test))

    def run(self):
        """
        Runs the test.

        Resets the running means and evaluates the loaded checkpoint on the
        whole test set, printing start/end timestamps.
        """

        print('\nTest started @ %s' %
              datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # start testing
        print('\nLoaded checkpoint %s @ %s\nSaved epoch %d' %
              (self.checkpoint_path,
               datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
               self.starting_epoch))

        self.mean_loss_test = 0.
        self.mean_acc_test = 0.

        # Test
        self._test()

        print('\nTest ended @ %s' %
              datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
Exemple #16
0
    def __init__(self,
                 exp_name,
                 ds_train,
                 ds_val,
                 epochs=210,
                 batch_size=16,
                 num_workers=4,
                 loss='JointsMSELoss',
                 lr=0.001,
                 lr_decay=True,
                 lr_decay_steps=(170, 200),
                 lr_decay_gamma=0.1,
                 optimizer='Adam',
                 weight_decay=0.,
                 momentum=0.9,
                 nesterov=False,
                 pretrained_weight_path=None,
                 checkpoint_path=None,
                 log_path='./logs',
                 use_tensorboard=True,
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None
                 ):
        """
        Initializes a new Train object.

        The log folder is created, the HRNet model is initialized and optional pre-trained weights or saved checkpoints
        are loaded.
        The DataLoaders, the loss function, and the optimizer are defined.

        Args:
            exp_name (str):  experiment name. The log folder is `log_path/exp_name` and must not already exist.
            ds_train (HumanPoseEstimationDataset): train dataset.
            ds_val (HumanPoseEstimationDataset): validation dataset.
            epochs (int): number of epochs.
                Default: 210
            batch_size (int): batch size.
                Default: 16
            num_workers (int): number of workers for each DataLoader
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            lr (float): learning rate.
                Default: 0.001
            lr_decay (bool): learning rate decay. If True, a MultiStepLR scheduler is created.
                Default: True
            lr_decay_steps (tuple): steps for the learning rate decay scheduler.
                Default: (170, 200)
            lr_decay_gamma (float): scale factor for each learning rate decay step.
                Default: 0.1
            optimizer (str): network optimizer. Valid values are 'Adam' and 'SGD'.
                Default: "Adam"
            weight_decay (float): weight decay.
                Default: 0.
            momentum (float): momentum factor (SGD only).
                Default: 0.9
            nesterov (bool): Nesterov momentum (SGD only).
                Default: False
            pretrained_weight_path (str): path to pre-trained weights (such as weights from pre-train on imagenet).
                Default: None
            checkpoint_path (str): path to a previous checkpoint, or to a directory containing
                'checkpoint_last.pth'.
                Default: None
            log_path (str): path where tensorboard data and checkpoints will be saved.
                Default: "./logs"
            use_tensorboard (bool): enables tensorboard use.
                Default: True
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                Default: 17
            model_bn_momentum (float): hrnet parameters - batch-norm momentum.
                Default: 0.1
            flip_test_images (bool): flip images during validating.
                Default: True
            device (torch.device): device to be used. If None, 'cuda:0' is used when available, otherwise 'cpu'.
                Default: None

        Raises:
            FileExistsError: if the log folder already exists (it is never overwritten).
            NotImplementedError: if `loss` or `optimizer` is not a supported name.
        """
        super(Train, self).__init__()

        self.exp_name = exp_name
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.lr_decay_gamma = lr_decay_gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov
        self.pretrained_weight_path = pretrained_weight_path
        self.checkpoint_path = checkpoint_path
        self.log_path = os.path.join(log_path, self.exp_name)
        self.use_tensorboard = use_tensorboard
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        print(self.device)

        os.makedirs(self.log_path, 0o755, exist_ok=False)  # exist_ok=False to avoid overwriting
        if self.use_tensorboard:
            self.summary_writer = tb.SummaryWriter(self.log_path)

        #
        # write all experiment parameters in parameters.txt and in tensorboard text field
        # NOTE: locals() at this point holds only `self` and the constructor arguments.
        # Keep this before any other local variable is assigned, otherwise it would be logged too.
        self.parameters = [x + ': ' + str(y) + '\n' for x, y in locals().items()]
        with open(os.path.join(self.log_path, 'parameters.txt'), 'w') as fd:
            fd.writelines(self.parameters)
        if self.use_tensorboard:
            self.summary_writer.add_text('parameters', '\n'.join(self.parameters))

        #
        # load model
        self.model = HRNet(c=self.model_c, nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        #
        # define loss and optimizers
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        if self.optimizer == 'SGD':
            self.optim = SGD(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay,
                             momentum=self.momentum, nesterov=self.nesterov)
        elif self.optimizer == 'Adam':
            self.optim = Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        #
        # load pre-trained weights (such as those pre-trained on imagenet)
        if self.pretrained_weight_path is not None:
            self.model.load_state_dict(torch.load(self.pretrained_weight_path, map_location=self.device), strict=True)
            print('Pre-trained weights loaded.')

        #
        # load previous checkpoint (overrides the pre-trained weights, if both are given)
        if self.checkpoint_path is not None:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path, 'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, self.optim, self.params = load_checkpoint(path, self.model, self.optim,
                                                                                       self.device)
        else:
            self.starting_epoch = 0

        if self.lr_decay:
            # When resuming, the scheduler continues from the checkpoint epoch instead of restarting.
            self.lr_scheduler = MultiStepLR(self.optim, list(self.lr_decay_steps), gamma=self.lr_decay_gamma,
                                            last_epoch=self.starting_epoch if self.starting_epoch else -1)

        #
        # load train and val datasets
        # drop_last avoids a final, smaller training batch.
        self.dl_train = DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True,
                                   num_workers=self.num_workers, drop_last=True)
        self.len_dl_train = len(self.dl_train)

        self.dl_val = DataLoader(self.ds_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
        self.len_dl_val = len(self.dl_val)

        #
        # initialize variables
        self.mean_loss_train = 0.
        self.mean_acc_train = 0.
        self.mean_loss_val = 0.
        self.mean_acc_val = 0.
        self.mean_mAP_val = 0.

        self.best_loss = None
        self.best_acc = None
        self.best_mAP = None
Exemple #17
0
class SimpleHRNet:
    """
    SimpleHRNet class.

    The class provides a simple and customizable method to load the HRNet network, load the official pre-trained
    weights, and predict the human pose on single images.
    Multi-person support with the YOLOv3 detector is also included (and enabled by default).
    """
    def __init__(
        self,
        c,
        nof_joints,
        checkpoint_path,
        resolution=(384, 288),
        interpolation=cv2.INTER_CUBIC,
        multiperson=True,
        max_batch_size=32,
        yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
        yolo_class_path="./models/detectors/yolo/data/coco.names",
        yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
        device=torch.device("cpu")):
        """
        Initializes a new SimpleHRNet object.
        HRNet (and YOLOv3) are initialized on the torch.device(``device``) and
        its (their) pretrained weights will be loaded from disk.

        Arguments:
            c (int): number of channels.
            nof_joints (int): number of joints.
            checkpoint_path (str): hrnet checkpoint path. The file may contain the state dict
                directly or wrapped under the key ``'model'``.
            resolution (tuple): hrnet input resolution - format: (height, width).
                Default: ``(384, 288)``
            interpolation (int): opencv interpolation algorithm (single-person mode only).
                Default: ``cv2.INTER_CUBIC``
            multiperson (bool): if ``True``, multiperson detection will be enabled.
                This requires the use of a people detector (like YOLOv3).
                Default: ``True``
            max_batch_size (int): maximum batch size used in hrnet inference.
                Useless without multiperson=True.
                Default: ``32``
            yolo_model_def (str): path to yolo model definition file.
                Default: ``"./models/detectors/yolo/config/yolov3.cfg"``
            yolo_class_path (str): path to yolo class definition file.
                Default: ``"./models/detectors/yolo/data/coco.names"``
            yolo_weights_path (str): path to yolo pretrained weights file.
                Default: ``"./models/detectors/yolo/weights/yolov3.weights"``
            device (:class:`torch.device`): the hrnet (and yolo) inference will be run on this device.
                Default: ``torch.device("cpu")``
        """

        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.resolution = resolution  # in the form (height, width) as in the original implementation
        self.interpolation = interpolation
        self.multiperson = multiperson
        self.max_batch_size = max_batch_size
        self.yolo_model_def = yolo_model_def
        self.yolo_class_path = yolo_class_path
        self.yolo_weights_path = yolo_weights_path
        self.device = device

        self.model = HRNet(c=c, nof_joints=nof_joints).to(device)
        # map_location allows weights saved from a GPU to be loaded on a CPU-only host;
        # some checkpoints wrap the state dict under a 'model' key.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()

        if not self.multiperson:
            # The image is resized with OpenCV in predict(), so only tensor conversion
            # and ImageNet normalization are needed here.
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        else:
            self.detector = YOLOv3(model_def=yolo_model_def,
                                   class_path=yolo_class_path,
                                   weights_path=yolo_weights_path,
                                   classes=('person', ),
                                   device=device)
            # Each detected person crop is resized to the hrnet input resolution.
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.resolution[0],
                                   self.resolution[1])),  # (height, width)
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def predict(self, image):
        """
        Predicts the human pose on a single image.

        Arguments:
            image (:class:`np.ndarray`):
                the image on which the human pose will be estimated.

                Image must be in the opencv format, i.e. shape=(height, width, BGR color channel).

        Returns:
            `:class:np.ndarray`:
                a numpy array containing human joints for each (detected) person.

                Format: shape=(# of people, # of joints (nof_joints), 3);  dtype=(np.float32).
                If nobody is detected, shape=(0, 0, 3).

                Each joint has 3 values: (y position, x position, joint confidence)
        """
        if not self.multiperson:
            old_res = image.shape
            if self.resolution is not None:
                image = cv2.resize(
                    image,
                    (self.resolution[1],
                     self.resolution[0]),  # (width, height)
                    interpolation=self.interpolation)

            images = self.transform(cv2.cvtColor(
                image, cv2.COLOR_BGR2RGB)).unsqueeze(dim=0)
            # The whole image acts as the single "bounding box".
            boxes = np.asarray([[0, 0, old_res[1], old_res[0]]],
                               dtype=np.float32)  # [x1, y1, x2, y2]

        else:
            detections = self.detector.predict_single(image)
            img_h, img_w = image.shape[:2]

            boxes = []
            if detections is not None:
                images = torch.empty((len(detections), 3, self.resolution[0],
                                      self.resolution[1]))  # (height, width)
                for i, (x1, y1, x2, y2, conf, cls_conf,
                        cls_pred) in enumerate(detections):
                    # Clamp the box to the image: a negative coordinate would make the
                    # numpy slice below index from the end of the array (wrong or empty
                    # crop), and coordinates past the border would distort the mapping
                    # of the predicted joints back to image space.
                    x1 = min(max(int(round(x1.item())), 0), img_w)
                    x2 = min(max(int(round(x2.item())), 0), img_w)
                    y1 = min(max(int(round(y1.item())), 0), img_h)
                    y2 = min(max(int(round(y2.item())), 0), img_h)

                    boxes.append([x1, y1, x2, y2])
                    # [..., ::-1] converts the crop from BGR to RGB.
                    images[i] = self.transform(image[y1:y2, x1:x2, ::-1])

            else:
                images = torch.empty((0, 3, self.resolution[0],
                                      self.resolution[1]))  # (height, width)

            boxes = np.asarray(boxes, dtype=np.int32)

        if images.shape[0] > 0:
            images = images.to(self.device)

            with torch.no_grad():
                if len(images) <= self.max_batch_size:
                    out = self.model(images)

                else:
                    # Run the model in chunks of at most max_batch_size crops.
                    out = torch.empty(
                        (images.shape[0], self.nof_joints,
                         self.resolution[0] // 4,
                         self.resolution[1] // 4)).to(self.device)
                    for i in range(0, len(images), self.max_batch_size):
                        out[i:i + self.max_batch_size] = self.model(
                            images[i:i + self.max_batch_size])

            out = out.detach().cpu().numpy()
            pts = np.empty((out.shape[0], out.shape[1], 3), dtype=np.float32)
            # For each human, for each joint: y, x, confidence
            for i, human in enumerate(out):
                for j, joint in enumerate(human):
                    # pt is the (row, col) of the heatmap peak.
                    pt = np.unravel_index(np.argmax(joint),
                                          shape=(self.resolution[0] // 4,
                                                 self.resolution[1] // 4))
                    # 0: row / (height // 4) * (bb_y2 - bb_y1) + bb_y1  -> y in image coordinates
                    # 1: col / (width // 4) * (bb_x2 - bb_x1) + bb_x1   -> x in image coordinates
                    # 2: heatmap value at the peak (confidence)
                    pts[i, j, 0] = pt[0] * 1. / (self.resolution[0] // 4) * (
                        boxes[i][3] - boxes[i][1]) + boxes[i][1]
                    pts[i, j, 1] = pt[1] * 1. / (self.resolution[1] // 4) * (
                        boxes[i][2] - boxes[i][0]) + boxes[i][0]
                    pts[i, j, 2] = joint[pt]

        else:
            pts = np.empty((0, 0, 3), dtype=np.float32)

        return pts
Exemple #18
0
class Test(object):
    """
    Test class.

    The class provides a basic tool for testing HRNet checkpoints.

    The only method supposed to be directly called is `run()`.
    """
    def __init__(self,
                 ds_test,
                 batch_size=1,
                 num_workers=4,
                 loss='JointsMSELoss',
                 checkpoint_path="./weights/pose_hrnet_w48_384x288.pth",
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initializes a new Test object.

        The HRNet model is initialized and the saved checkpoint is loaded.
        The DataLoader and the loss function are defined.

        Args:
            ds_test (HumanPoseEstimationDataset): test dataset.
            batch_size (int): batch size.
                Default: 1
            num_workers (int): number of workers for each DataLoader
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            checkpoint_path (str): path to the official "pose_hrnet_w48_384x288.pth" weights, to a
                checkpoint file, or to a directory containing 'checkpoint_last.pth'.
                Default: "./weights/pose_hrnet_w48_384x288.pth"
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                Default: 17
            model_bn_momentum (float): hrnet parameters - batch-norm momentum.
                Default: 0.1
            flip_test_images (bool): flip images during validating.
                Default: True
            device (torch.device): device to be used. If None, 'cuda:0' is used when available,
                otherwise 'cpu'.
                Default: None

        Raises:
            NotImplementedError: if `loss` is not a supported loss name.
            ValueError: if `checkpoint_path` is None.
        """
        super(Test, self).__init__()

        self.ds_test = ds_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.checkpoint_path = checkpoint_path
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        print(self.device)

        #
        # load model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        #
        # define loss
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        #
        # load previous checkpoint
        # The None check must come first: os.path.basename(None) raises TypeError.
        if self.checkpoint_path is None:
            raise ValueError('checkpoint_path is not defined')

        print('Loading checkpoint %s...' % self.checkpoint_path)
        if os.path.basename(
                self.checkpoint_path) == "pose_hrnet_w48_384x288.pth":
            # Official weights are loaded directly into the model rather than via
            # load_checkpoint(); without this the model would be evaluated with its
            # random initial weights.
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
            if 'model' in checkpoint:
                checkpoint = checkpoint['model']
            self.model.load_state_dict(checkpoint)
            # NOTE(review): nominal epoch printed by run(); it is not read from the file.
            self.starting_epoch = 210
            self.params = None

        else:
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path,
                                    'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, _, self.params = load_checkpoint(
                path, self.model, device=self.device)

        #
        # load test dataset
        self.dl_test = DataLoader(self.ds_test,
                                  batch_size=self.batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers)
        self.len_dl_test = len(self.dl_test)
        print("print len_dl_test: ", self.len_dl_test)

        #
        # initialize variables
        self.mean_loss_test = 0.
        self.mean_acc_test = 0.

    def _test(self):
        """
        Runs one pass over the test DataLoader.

        Batches whose average accuracy is exactly 0 are skipped (not counted in
        the means). The means over the remaining batches are stored in
        `self.mean_loss_test` / `self.mean_acc_test`, and `self.len_dl_test` is
        overwritten with the number of counted batches.
        """
        self.model.eval()
        count_test = 0

        with torch.no_grad():
            for step, (image, target, target_weight, joints_data) in enumerate(
                    tqdm(self.dl_test, desc='Test')):
                image = image.to(self.device)
                target = target.to(self.device)
                target_weight = target_weight.to(self.device)

                output = self.model(image)

                if self.flip_test_images:
                    # Average the heatmaps predicted on the image and on its
                    # horizontal flip (flipped back).
                    image_flipped = flip_tensor(image, dim=-1)
                    output_flipped = self.model(image_flipped)

                    output_flipped = flip_back(output_flipped,
                                               self.ds_test.flip_pairs)

                    output = (output + output_flipped) * 0.5

                loss = self.loss_fn(output, target, target_weight)

                # Evaluate accuracy
                # Get predictions on the input
                _, avg_acc, _, joints_preds, joints_target = \
                    self.ds_test.evaluate_accuracy(output, target)

                # Zero-accuracy batches are excluded from the statistics (and,
                # for step 0, from the saved images).
                if avg_acc == 0:
                    continue
                self.mean_loss_test += loss.item()
                self.mean_acc_test += avg_acc.item()

                count_test += 1
                if step == 0:
                    save_images(image, target, joints_target, output,
                                joints_preds, joints_data['joints_visibility'])

        self.len_dl_test = count_test
        print("count_test:", count_test, " self.mean_acc_test:",
              self.mean_acc_test)
        # If every batch was skipped (or the set is empty), the accumulators are
        # still 0; dividing by count_test would raise ZeroDivisionError.
        if count_test > 0:
            self.mean_loss_test /= count_test
            self.mean_acc_test /= count_test

        print('\nTest: Loss %f - Accuracy %f' %
              (self.mean_loss_test, self.mean_acc_test))

    def run(self):
        """
        Runs the test.

        Resets the running means and evaluates the loaded checkpoint on the
        test set, printing start/end timestamps.
        """

        print('\nTest started @ %s' %
              datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # start testing
        print('\nLoaded checkpoint %s @ %s\nSaved epoch %d' %
              (self.checkpoint_path,
               datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
               self.starting_epoch))

        self.mean_loss_test = 0.
        self.mean_acc_test = 0.

        #
        # Test

        self._test()

        print('\nTest ended @ %s' %
              datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

    # NOTE(review): this second __init__ (a SimpleHRNet constructor, judging by its
    # docstring) redefines and therefore shadows the Test.__init__ above. It looks
    # like a pasted snippet from another example; verify before relying on either.
    def __init__(self,
                 c,
                 nof_joints,
                 checkpoint_path,
                 model_name='HRNet',
                 resolution=(384, 288),
                 interpolation=cv2.INTER_CUBIC,
                 return_bounding_boxes=False,
                 max_batch_size=32,
                 device=torch.device("cpu")):
        """
        Initializes a new SimpleHRNet object.
        The pose model is initialized on the torch.device("device") and its
        pre-trained weights are loaded from disk.

        Args:
            c (int): number of channels (when using HRNet model) or resnet size (when using PoseResNet model).
            nof_joints (int): number of joints.
            checkpoint_path (str): path to an official hrnet checkpoint or a checkpoint obtained with `train_coco.py`.
                The file may contain the state dict directly or wrapped under the key 'model'.
            model_name (str): model name (HRNet or PoseResNet).
                Valid names for HRNet are: `HRNet`, `hrnet`
                Valid names for PoseResNet are: `PoseResNet`, `poseresnet`, `ResNet`, `resnet`
                Default: "HRNet"
            resolution (tuple): hrnet input resolution - format: (height, width).
                Default: (384, 288)
            interpolation (int): opencv interpolation algorithm.
                Default: cv2.INTER_CUBIC
            return_bounding_boxes (bool): if True, bounding boxes will be returned along with poses by self.predict.
                Default: False
            max_batch_size (int): maximum batch size used in hrnet inference.
                Default: 32
            device (:class:`torch.device`): the inference will be run on this device.
                'cuda' uses all visible GPUs through DataParallel; 'cuda:IDS' uses only the listed ids.
                Default: torch.device("cpu")

        Raises:
            ValueError: if `model_name` or `device` is not recognized.
        """

        self.c = c
        self.nof_joints = nof_joints
        # NOTE(review): machine-specific absolute paths; consider making them parameters.
        self.detector_root = '/home/mmlab/CCTV_Server/models/detectors'
        self.checkpoint_path = checkpoint_path
        self.model_name = model_name
        self.resolution = resolution  # in the form (height, width) as in the original implementation
        self.interpolation = interpolation
        self.return_bounding_boxes = return_bounding_boxes
        self.max_batch_size = max_batch_size
        self.device = device
        self.previous_out_shape = None
        self.heatmap_club_head_cnt = 0
        self.heatmap_left_wrist_cnt = 0
        self.heatmap_club_head_dir = '/home/mmlab/CCTV_Server/golf/heatmap_club_head'
        self.heatmap_left_wrist_dir = '/home/mmlab/CCTV_Server/golf/heatmap_left_wrist'
        makedir(self.heatmap_club_head_dir)
        makedir(self.heatmap_left_wrist_dir)

        if model_name in ('HRNet', 'hrnet'):
            self.model = HRNet(c=c, nof_joints=nof_joints)
        elif model_name in ('PoseResNet', 'poseresnet', 'ResNet', 'resnet'):
            self.model = PoseResNet(resnet_size=c, nof_joints=nof_joints)
        else:
            raise ValueError('Wrong model name.')

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint)

        if 'cuda' in str(self.device):
            print("device: 'cuda' - ", end="")

            if 'cuda' == str(self.device):
                # if device is set to 'cuda', all available GPUs will be used
                print("%d GPU(s) will be used" % torch.cuda.device_count())
                device_ids = None
            else:
                # if device is set to 'cuda:IDS', only that/those device(s) will be used
                print("GPU(s) '%s' will be used" % str(self.device))
                device_ids = [int(x) for x in str(self.device)[5:].split(',')]
            print(device_ids)

            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        elif 'cpu' == str(self.device):
            print("device: 'cpu'")
        else:
            raise ValueError('Wrong device name.')

        self.model = self.model.to(device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.resolution[0], self.resolution[1])),  # (height, width)
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
Exemple #20
0
    def __init__(self,
                 ds_test,
                 batch_size=1,
                 num_workers=4,
                 loss='JointsMSELoss',
                 checkpoint_path="./weights/pose_hrnet_w48_384x288.pth",
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initializes a new Test object.

        The HRNet model is initialized and the checkpoint is loaded.
        The DataLoader and the loss function are defined.

        Args:
            ds_test (HumanPoseEstimationDataset): test dataset.
            batch_size (int): batch size.
                Default: 1
            num_workers (int): number of workers for each DataLoader
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            checkpoint_path (str): checkpoint to test. Either the official "pose_hrnet_w48_384x288.pth"
                weights, a checkpoint file readable by `load_checkpoint`, or a directory containing
                a "checkpoint_last.pth" file.
                Default: "./weights/pose_hrnet_w48_384x288.pth"
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                Default: 17
            model_bn_momentum (float): hrnet parameters - passed to HRNet as `bn_momentum`
                (batch-norm momentum).
                Default: 0.1
            flip_test_images (bool): flip images during validating.
                Default: True
            device (torch.device): device to be used (default: cuda, if available).
                Default: None

        Raises:
            NotImplementedError: if `loss` is not one of the supported loss names.
            ValueError: if `checkpoint_path` is None.
        """
        super(Test, self).__init__()

        self.ds_test = ds_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.checkpoint_path = checkpoint_path
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device: the explicit one, otherwise the first GPU if available, otherwise the CPU
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        print(self.device)

        #
        # load model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        #
        # define loss
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        #
        # load checkpoint
        # The None check must come first: os.path.basename(None) raises TypeError.
        if self.checkpoint_path is None:
            raise ValueError('checkpoint_path is not defined')

        if os.path.basename(self.checkpoint_path) == "pose_hrnet_w48_384x288.pth":
            # Official pre-trained weights: the epoch is hard-coded and only the model weights are
            # loaded (either a raw state_dict or one stored under the 'model' key).
            print('Loading checkpoint %s...' % self.checkpoint_path)
            checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
            if 'model' in checkpoint:
                checkpoint = checkpoint['model']
            self.model.load_state_dict(checkpoint)
            self.starting_epoch = 210  # hard-coded (originally annotated "1?")
            self.params = None

        else:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path,
                                    'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, _, self.params = load_checkpoint(
                path, self.model, device=self.device)

        #
        # load test dataset
        self.dl_test = DataLoader(self.ds_test,
                                  batch_size=self.batch_size,
                                  shuffle=False,
                                  num_workers=self.num_workers)
        self.len_dl_test = len(self.dl_test)
        print("print len_dl_test: ", self.len_dl_test)

        #
        # initialize variables
        self.mean_loss_test = 0.
        self.mean_acc_test = 0.
class SimpleHRNet:
    """
    SimpleHRNet class.

    The class provides a simple and customizable method to load the HRNet network, load the official pre-trained
    weights, and predict the human pose on single images.
    Multi-person support with the YOLOv3 detector is also included (and enabled by default).
    """
    def __init__(
        self,
        c,
        nof_joints,
        checkpoint_path,
        resolution=(384, 288),
        interpolation=cv2.INTER_CUBIC,
        multiperson=True,
        return_bounding_boxes=False,
        max_batch_size=32,
        yolo_model_def="./models/detectors/yolo/config/yolov3.cfg",
        yolo_class_path="./models/detectors/yolo/data/coco.names",
        yolo_weights_path="./models/detectors/yolo/weights/yolov3.weights",
        device=torch.device("cpu")):
        """
        Initializes a new SimpleHRNet object.
        HRNet (and YOLOv3) are initialized on the torch.device("device") and
        its (their) pre-trained weights will be loaded from disk.

        Args:
            c (int): number of channels.
            nof_joints (int): number of joints.
            checkpoint_path (str): path to an official hrnet checkpoint or a checkpoint obtained with `train_coco.py`.
                Both a raw state_dict and a dict with the weights under the 'model' key are accepted.
            resolution (tuple): hrnet input resolution - format: (height, width).
                Default: (384, 288)
            interpolation (int): opencv interpolation algorithm, used to resize the input image(s)
                when multiperson is False.
                Default: cv2.INTER_CUBIC
            multiperson (bool): if True, multiperson detection will be enabled.
                This requires the use of a people detector (like YOLOv3).
                Default: True
            return_bounding_boxes (bool): if True, bounding boxes will be returned along with poses by self.predict.
                Default: False
            max_batch_size (int): maximum batch size used in hrnet inference; it is also passed to the
                YOLOv3 detector when multiperson is True.
                Default: 32
            yolo_model_def (str): path to yolo model definition file.
                Default: "./models/detectors/yolo/config/yolov3.cfg"
            yolo_class_path (str): path to yolo class definition file.
                Default: "./models/detectors/yolo/data/coco.names"
            yolo_weights_path (str): path to yolo pretrained weights file.
                Default: "./models/detectors/yolo/weights/yolov3.weights"
            device (:class:`torch.device`): the hrnet (and yolo) inference will be run on this device.
                Default: torch.device("cpu")
        """

        self.c = c
        self.nof_joints = nof_joints
        self.checkpoint_path = checkpoint_path
        self.resolution = resolution  # in the form (height, width) as in the original implementation
        self.interpolation = interpolation
        self.multiperson = multiperson
        self.return_bounding_boxes = return_bounding_boxes
        self.max_batch_size = max_batch_size
        self.yolo_model_def = yolo_model_def
        self.yolo_class_path = yolo_class_path
        self.yolo_weights_path = yolo_weights_path
        self.device = device

        self.model = HRNet(c=c, nof_joints=nof_joints).to(device)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()

        if not self.multiperson:
            # single person: images are resized with opencv before the transform
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        else:
            self.detector = YOLOv3(model_def=yolo_model_def,
                                   class_path=yolo_class_path,
                                   weights_path=yolo_weights_path,
                                   classes=('person', ),
                                   max_batch_size=self.max_batch_size,
                                   device=device)
            # multi person: each detected crop is resized to the hrnet resolution by the transform
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.resolution[0],
                                   self.resolution[1])),  # (height, width)
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def predict(self, image):
        """
        Predicts the human pose on a single image or a stack of n images.

        Args:
            image (:class:`np.ndarray`):
                the image(s) on which the human pose will be estimated.

                image is expected to be in the opencv format.
                image can be:
                    - a single image with shape=(height, width, BGR color channel)
                    - a stack of n images with shape=(n, height, width, BGR color channel)

        Returns:
            :class:`np.ndarray`:
                a numpy array containing human joints for each (detected) person.

                Format:
                    if image is a single image:
                        shape=(# of people, # of joints (nof_joints), 3);  dtype=(np.float32).
                    if image is a stack of n images:
                        multiperson: list of n np.ndarrays with
                            shape=(# of people, # of joints (nof_joints), 3);  dtype=(np.float32).
                        single person: one np.ndarray with
                            shape=(n, 1, # of joints (nof_joints), 3);  dtype=(np.float32).

                Each joint has 3 values: (y position, x position, joint confidence).

                If self.return_bounding_boxes, the method returns a tuple (bounding boxes, human joints),
                with bounding boxes in the format [x1, y1, x2, y2].

        Raises:
            ValueError: if `image` is neither 3- nor 4-dimensional.
        """
        if len(image.shape) == 3:
            return self._predict_single(image)
        elif len(image.shape) == 4:
            return self._predict_batch(image)
        else:
            raise ValueError('Wrong image format.')

    def _predict_single(self, image):
        """Estimate the poses on a single (height, width, BGR) image; see `predict` for the output."""
        if not self.multiperson:
            old_res = image.shape
            if self.resolution is not None:
                image = cv2.resize(
                    image,
                    (self.resolution[1],
                     self.resolution[0]),  # (width, height)
                    interpolation=self.interpolation)

            images = self.transform(cv2.cvtColor(
                image, cv2.COLOR_BGR2RGB)).unsqueeze(dim=0)
            # the whole (original) image is the only bounding box
            boxes = np.asarray([[0, 0, old_res[1], old_res[0]]],
                               dtype=np.float32)  # [x1, y1, x2, y2]

        else:
            detections = self.detector.predict_single(image)
            images, boxes = self._crop_detections(image, detections)
            boxes = np.asarray(boxes, dtype=np.int32)

        if images.shape[0] > 0:
            pts = self._heatmaps_to_points(self._run_model(images), boxes)
        else:
            # nobody detected: same empty shape as the one produced by _predict_batch
            pts = np.empty((0, self.nof_joints, 3), dtype=np.float32)

        if self.return_bounding_boxes:
            return boxes, pts
        else:
            return pts

    def _predict_batch(self, images):
        """Estimate the poses on a stack of (n, height, width, BGR) images; see `predict` for the output."""
        if not self.multiperson:
            old_res = images[0].shape

            if self.resolution is not None:
                images_tensor = torch.empty(images.shape[0], 3,
                                            self.resolution[0],
                                            self.resolution[1])
            else:
                images_tensor = torch.empty(images.shape[0], 3,
                                            images.shape[1], images.shape[2])

            for i, image in enumerate(images):
                if self.resolution is not None:
                    image = cv2.resize(
                        image,
                        (self.resolution[1],
                         self.resolution[0]),  # (width, height)
                        interpolation=self.interpolation)

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                images_tensor[i] = self.transform(image)

            images = images_tensor
            # every (whole) image is its own bounding box
            boxes = np.repeat(np.asarray([[0, 0, old_res[1], old_res[0]]],
                                         dtype=np.float32),
                              len(images),
                              axis=0)  # [x1, y1, x2, y2]

        else:
            image_detections = self.detector.predict(images)

            crops = []
            boxes = []
            for image, detections in zip(images, image_detections):
                image_crops, image_boxes = self._crop_detections(image, detections)
                crops.append(image_crops)
                boxes.extend(image_boxes)

            # stack the crops and boxes of all the images together
            # (torch.cat also works when every image has zero detections, unlike np.stack([]))
            if crops:
                images = torch.cat(crops)
            else:
                images = torch.empty((0, 3, self.resolution[0],
                                      self.resolution[1]))  # (height, width)
            boxes = np.asarray(boxes, dtype=np.int32)

        if images.shape[0] > 0:
            pts = self._heatmaps_to_points(self._run_model(images), boxes)
        else:
            # no person in any image: skip the network
            pts = np.empty((0, self.nof_joints, 3), dtype=np.float32)

        if self.multiperson:
            # re-add the removed batch axis (n)
            pts_batch = []
            index = 0
            for detections in image_detections:
                if detections is not None:
                    pts_batch.append(pts[index:index + len(detections)])
                    index += len(detections)
                else:
                    pts_batch.append(
                        np.zeros((0, self.nof_joints, 3), dtype=np.float32))
            pts = pts_batch

        else:
            pts = np.expand_dims(pts, axis=1)

        if self.return_bounding_boxes:
            return boxes, pts
        else:
            return pts

    def _crop_detections(self, image, detections):
        """
        Crop every detected person out of `image` and build the HRNet input batch.

        Args:
            image (np.ndarray): a single (height, width, BGR) image.
            detections: detector output for `image`, iterated as rows starting with (x1, y1, x2, y2, ...);
                None when nobody was detected.

        Returns:
            tuple: (crops, boxes) where crops is a float tensor of shape
            (# of detections, 3, resolution height, resolution width) and boxes is the list of the
            aspect-ratio-adjusted boxes as [x1, y1, x2, y2] ints.
        """
        if detections is None:
            return torch.empty((0, 3, self.resolution[0],
                                self.resolution[1])), []  # (height, width)

        crops = torch.empty((len(detections), 3, self.resolution[0],
                             self.resolution[1]))  # (height, width)
        boxes = []
        for i, (x1, y1, x2, y2, *_) in enumerate(detections):
            x1, y1, x2, y2 = self._fit_box_to_aspect_ratio(
                int(round(x1.item())), int(round(y1.item())),
                int(round(x2.item())), int(round(y2.item())),
                image.shape)
            boxes.append([x1, y1, x2, y2])
            # ::-1 on the channel axis converts BGR to RGB
            crops[i] = self.transform(image[y1:y2, x1:x2, ::-1])
        return crops, boxes

    def _fit_box_to_aspect_ratio(self, x1, y1, x2, y2, image_shape):
        """
        Enlarge a box so that its aspect ratio matches the HRNet input resolution
        (as suggested by xtyDoge in issue #14), clipping it to the image borders.

        Returns:
            tuple: the adjusted (x1, y1, x2, y2) integer coordinates.
        """
        correction_factor = self.resolution[0] / self.resolution[1] * (x2 - x1) / (y2 - y1)
        if correction_factor > 1:
            # increase y side
            center = y1 + (y2 - y1) // 2
            length = int(round((y2 - y1) * correction_factor))
            y1 = max(0, center - length // 2)
            y2 = min(image_shape[0], center + length // 2)
        elif correction_factor < 1:
            # increase x side
            center = x1 + (x2 - x1) // 2
            length = int(round((x2 - x1) * 1 / correction_factor))
            x1 = max(0, center - length // 2)
            x2 = min(image_shape[1], center + length // 2)
        return x1, y1, x2, y2

    def _run_model(self, images):
        """
        Run HRNet on `images` in chunks of at most `self.max_batch_size` samples.

        Returns:
            np.ndarray: the concatenated network output (heatmaps), moved to the CPU.
        """
        images = images.to(self.device)
        with torch.no_grad():
            out = torch.cat([self.model(chunk)
                             for chunk in torch.split(images, self.max_batch_size)])
        return out.detach().cpu().numpy()

    def _heatmaps_to_points(self, heatmaps, boxes):
        """
        Convert HRNet heatmaps into joint coordinates in the original image.

        Args:
            heatmaps (np.ndarray): shape (# of people, nof_joints, heatmap height, heatmap width).
            boxes: one [x1, y1, x2, y2] box per person, the region each heatmap stack was computed on.

        Returns:
            np.ndarray: shape (# of people, nof_joints, 3), dtype np.float32, with
            (y position, x position, confidence) per joint.
        """
        nof_people, nof_joints, height, width = heatmaps.shape
        flat = heatmaps.reshape(nof_people, nof_joints, -1)
        # location (row, column) of the maximum of every heatmap
        rows, cols = np.unravel_index(flat.argmax(axis=2), (height, width))

        boxes = np.asarray(boxes)
        box_x1, box_y1, box_x2, box_y2 = (boxes[:, k:k + 1] for k in range(4))

        pts = np.empty((nof_people, nof_joints, 3), dtype=np.float32)
        # 0: row / heatmap height * (bb_y2 - bb_y1) + bb_y1
        pts[..., 0] = rows * 1. / height * (box_y2 - box_y1) + box_y1
        # 1: column / heatmap width * (bb_x2 - bb_x1) + bb_x1
        pts[..., 1] = cols * 1. / width * (box_x2 - box_x1) + box_x1
        # 2: confidence = heatmap value at its maximum
        pts[..., 2] = flat.max(axis=2)
        return pts
Exemple #22
0
class Train(object):

    # Training class.
    # Provides the basic tools to train HRNet.

    def __init__(self,
                 exp_name,
                 ds_train,
                 ds_val,
                 epochs=210,
                 batch_size=16,
                 num_workers=4,
                 loss='JointsMSELoss',
                 lr=0.001,
                 lr_decay=True,
                 lr_decay_steps=(170, 200),
                 lr_decay_gamma=0.1,
                 optimizer='Adam',
                 weight_decay=0.00001,
                 momentum=0.9,
                 nesterov=False,
                 pretrained_weight_path=None,
                 checkpoint_path=None,
                 log_path='./logs',
                 use_tensorboard=True,
                 model_c=48,
                 model_nof_joints=17,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initializes a new Train object.

        The log folder is created, the HRNet model is initialized and the pre-trained weights
        and/or a saved checkpoint are loaded.

        Args:
            exp_name (str): experiment name; logs are written to `os.path.join(log_path, exp_name)`.
            ds_train: training dataset.
            ds_val: validation dataset.
            epochs (int): total number of epochs. Default: 210
            batch_size (int): batch size of both DataLoaders. Default: 16
            num_workers (int): number of workers of each DataLoader. Default: 4
            loss (str): 'JointsMSELoss' or 'JointsOHKMMSELoss'. Default: 'JointsMSELoss'
            lr (float): initial learning rate. Default: 0.001
            lr_decay (bool): if True, a MultiStepLR scheduler is created. Default: True
            lr_decay_steps (tuple): milestones (epochs) of the MultiStepLR scheduler. Default: (170, 200)
            lr_decay_gamma (float): multiplicative factor of the MultiStepLR scheduler. Default: 0.1
            optimizer (str): 'SGD' or 'Adam'. Default: 'Adam'
            weight_decay (float): optimizer weight decay. Default: 0.00001
            momentum (float): SGD momentum (SGD only). Default: 0.9
            nesterov (bool): SGD Nesterov flag (SGD only). Default: False
            pretrained_weight_path (str): weights loaded with strict=False before training. Default: None
            checkpoint_path (str): checkpoint file, or directory containing "checkpoint_last.pth",
                to resume from. Default: None
            log_path (str): base log directory. Default: './logs'
            use_tensorboard (bool): if True, a tensorboard SummaryWriter is created. Default: True
            model_c (int): hrnet parameters - number of channels. Default: 48
            model_nof_joints (int): hrnet parameters - number of joints. Default: 17
            model_bn_momentum (float): hrnet parameters - passed to HRNet as `bn_momentum`. Default: 0.1
            flip_test_images (bool): average outputs with flipped images during validation. Default: True
            device (torch.device): device to be used; if None, 'cuda:0' when available, else 'cpu'.
                Default: None

        Raises:
            FileExistsError: (from os.makedirs) if the experiment log folder already exists.
            NotImplementedError: if `loss` or `optimizer` is not supported.
        """
        super(Train, self).__init__()

        self.exp_name = exp_name
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.lr_decay_gamma = lr_decay_gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov
        self.pretrained_weight_path = pretrained_weight_path
        self.checkpoint_path = checkpoint_path
        self.log_path = os.path.join(log_path, self.exp_name)
        self.use_tensorboard = use_tensorboard
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        self.epoch = 0

        # torch device
        if device is not None:
            self.device = device
        else:
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
            else:
                self.device = torch.device('cpu')

        print(self.device)

        # exist_ok=False: refuse to overwrite the logs of an existing experiment
        os.makedirs(self.log_path, 0o755, exist_ok=False)
        if self.use_tensorboard:
            self.summary_writer = tb.SummaryWriter(self.log_path)

        # Write all the experiment parameters to parameters.txt and to a tensorboard text field.
        # locals() here holds the constructor arguments (along with `self` and any other names
        # bound so far), so no new local variables must be introduced above this line.
        self.parameters = [
            x + ': ' + str(y) + '\n' for x, y in locals().items()
        ]
        with open(os.path.join(self.log_path, 'parameters.txt'), 'w') as fd:
            fd.writelines(self.parameters)
        if self.use_tensorboard:
            self.summary_writer.add_text('parameters',
                                         '\n'.join(self.parameters))

        #
        # Load the model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).to(self.device)

        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss().to(self.device)
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss().to(self.device)
        else:
            raise NotImplementedError

        if optimizer == 'SGD':
            self.optim = SGD(self.model.parameters(),
                             lr=self.lr,
                             weight_decay=self.weight_decay,
                             momentum=self.momentum,
                             nesterov=self.nesterov)
        elif optimizer == 'Adam':
            self.optim = Adam(self.model.parameters(),
                              lr=self.lr,
                              weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        # Load the pre-trained weights (missing/unexpected keys are tolerated)
        if self.pretrained_weight_path is not None:
            self.model.load_state_dict(torch.load(self.pretrained_weight_path,
                                                  map_location=self.device),
                                       strict=False)

        #
        # Load a previous checkpoint
        if self.checkpoint_path is not None:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path,
                                    'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, self.optim, self.params = load_checkpoint(
                path, self.model, self.optim, self.device)
        else:
            self.starting_epoch = 0

        if lr_decay:
            # A scheduler created with last_epoch != -1 requires 'initial_lr' in the optimizer
            # param groups (present only when resuming), so a fresh run must pass -1.
            self.lr_scheduler = MultiStepLR(self.optim,
                                            list(self.lr_decay_steps),
                                            gamma=self.lr_decay_gamma,
                                            last_epoch=self.starting_epoch if self.starting_epoch else -1)

        # Training and validation data loaders
        self.dl_train = DataLoader(self.ds_train,
                                   batch_size=self.batch_size,
                                   shuffle=True,
                                   num_workers=self.num_workers,
                                   drop_last=True)
        self.len_dl_train = len(self.dl_train)

        self.dl_val = DataLoader(self.ds_val,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=self.num_workers)
        self.len_dl_val = len(self.dl_val)

        #
        # Initialize the running statistics
        self.mean_loss_train = 0.
        self.mean_acc_train = 0.
        self.mean_loss_val = 0.
        self.mean_acc_val = 0.
        self.mean_mAP_val = 0.

        self.best_loss = None
        self.best_acc = None
        self.best_mAP = None

    def _train(self):
        """
        Run one training epoch over `self.dl_train`.

        Each batch goes through forward pass, loss, backward pass and optimizer step; then its
        accuracy is computed with `self.ds_train.evaluate_accuracy`. Batch loss and accuracy are
        summed into `self.mean_loss_train` / `self.mean_acc_train`, which are divided by the
        number of batches at the end. With tensorboard enabled, both values are logged per step
        and sample images of the first batch are saved.
        """
        self.model.train()
        first_global_step = self.epoch * self.len_dl_train

        progress = tqdm(self.dl_train, desc='Training')
        for step, (image, target, target_weight, joints_data) in enumerate(progress):
            image, target, target_weight = (
                tensor.to(self.device) for tensor in (image, target, target_weight))

            self.optim.zero_grad()
            output = self.model(image)
            loss = self.loss_fn(output, target, target_weight)
            loss.backward()
            self.optim.step()

            # batch accuracy, plus predicted/target joints used by the image summary
            _, avg_acc, _, joints_preds, joints_target = self.ds_train.evaluate_accuracy(
                output, target)

            batch_loss = loss.item()
            batch_acc = avg_acc.item()
            self.mean_loss_train += batch_loss
            self.mean_acc_train += batch_acc

            if not self.use_tensorboard:
                continue

            global_step = first_global_step + step
            self.summary_writer.add_scalar('train_loss', batch_loss, global_step=global_step)
            self.summary_writer.add_scalar('train_acc', batch_acc, global_step=global_step)
            if step == 0:
                save_images(image, target, joints_target, output, joints_preds,
                            joints_data['joints_visibility'], self.summary_writer,
                            step=global_step, prefix='train_')

        nof_batches = len(self.dl_train)
        self.mean_loss_train /= nof_batches
        self.mean_acc_train /= nof_batches

        print('\nTrain: Loss %f - Accuracy %f' % (self.mean_loss_train, self.mean_acc_train))

    def _val(self):
        """
        Run one validation pass over `self.dl_val` without gradients.

        If `self.flip_test_images`, the output is averaged with the (flipped back) output computed
        on the images flipped along their last dimension. Batch loss and accuracy are summed into
        `self.mean_loss_val` / `self.mean_acc_val` and divided by the number of validation batches.
        With tensorboard enabled, both values are logged per step (global step counted in
        validation batches) and sample images of the first batch are saved.
        """
        self.model.eval()
        first_global_step = self.epoch * self.len_dl_val

        with torch.no_grad():
            for step, (image, target, target_weight, joints_data) in enumerate(
                    tqdm(self.dl_val, desc='Validating')):
                image = image.to(self.device)
                target = target.to(self.device)
                target_weight = target_weight.to(self.device)

                output = self.model(image)

                if self.flip_test_images:
                    image_flipped = flip_tensor(image, dim=-1)
                    output_flipped = self.model(image_flipped)

                    output_flipped = flip_back(output_flipped,
                                               self.ds_val.flip_pairs)

                    output = (output + output_flipped) * 0.5

                loss = self.loss_fn(output, target, target_weight)

                # batch accuracy, plus predicted/target joints used by the image summary
                accs, avg_acc, cnt, joints_preds, joints_target = \
                    self.ds_val.evaluate_accuracy(output, target)

                # validation statistics, kept separate from the training ones
                self.mean_loss_val += loss.item()
                self.mean_acc_val += avg_acc.item()
                if self.use_tensorboard:
                    global_step = first_global_step + step
                    self.summary_writer.add_scalar(
                        'val_loss',
                        loss.item(),
                        global_step=global_step)
                    self.summary_writer.add_scalar(
                        'val_acc',
                        avg_acc.item(),
                        global_step=global_step)
                    if step == 0:
                        save_images(image,
                                    target,
                                    joints_target,
                                    output,
                                    joints_preds,
                                    joints_data['joints_visibility'],
                                    self.summary_writer,
                                    step=global_step,
                                    prefix='val_')

        self.mean_loss_val /= len(self.dl_val)
        self.mean_acc_val /= len(self.dl_val)

        print('\nValidation: Loss %f - Accuracy %f' %
              (self.mean_loss_val, self.mean_acc_val))

    def _checkpoint(self):
        """
        Save checkpoints for the current epoch into `self.log_path`.

        'checkpoint_last.pth' is always written. 'checkpoint_best_loss.pth',
        'checkpoint_best_acc.pth' and 'checkpoint_best_mAP.pth' are written (and the
        corresponding best value updated) when the current validation loss is the lowest,
        or accuracy / mAP the highest, seen so far. Every checkpoint stores `self.epoch + 1`.
        """
        epoch_number = self.epoch + 1

        def save(filename):
            save_checkpoint(path=os.path.join(self.log_path, filename),
                            epoch=epoch_number,
                            model=self.model,
                            optimizer=self.optim,
                            params=self.parameters)

        save('checkpoint_last.pth')

        if self.best_loss is None or self.mean_loss_val < self.best_loss:
            self.best_loss = self.mean_loss_val
            print('best_loss %f at epoch %d' % (self.best_loss, epoch_number))
            save('checkpoint_best_loss.pth')

        if self.best_acc is None or self.mean_acc_val > self.best_acc:
            self.best_acc = self.mean_acc_val
            print('best_acc %f at epoch %d' % (self.best_acc, epoch_number))
            save('checkpoint_best_acc.pth')

        if self.best_mAP is None or self.mean_mAP_val > self.best_mAP:
            self.best_mAP = self.mean_mAP_val
            print('best_mAP %f at epoch %d' % (self.best_mAP, epoch_number))
            save('checkpoint_best_mAP.pth')

    def run(self):
        """
        Runs the training.

        Iterates ``self.epoch`` over ``range(self.starting_epoch, self.epochs)``.
        Each epoch does the following, in order:

        - reset the per-epoch metric accumulators;
        - call ``self._train()`` and then ``self._val()``;
        - step ``self.lr_scheduler`` when ``self.lr_decay`` is truthy;
        - call ``self._checkpoint()``.

        Start, per-epoch and end timestamps are printed to stdout.
        """

        def _timestamp():
            # Wall-clock time used in the progress messages.
            return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        print('\nTraining started @ %s' % _timestamp())

        # Main epoch loop. self.epoch is kept up to date because the helper
        # methods (e.g. _checkpoint) read it.
        for epoch in range(self.starting_epoch, self.epochs):
            self.epoch = epoch
            print('\nEpoch %d of %d @ %s' %
                  (epoch + 1, self.epochs, _timestamp()))

            # Reset the running metric sums for this epoch.
            self.mean_loss_train = 0.
            self.mean_loss_val = 0.
            self.mean_acc_train = 0.
            self.mean_acc_val = 0.
            self.mean_mAP_val = 0.

            # Training pass, then validation pass.
            self._train()
            self._val()

            # Learning-rate update (only when decay is enabled).
            if self.lr_decay:
                self.lr_scheduler.step()

            # Checkpointing.
            self._checkpoint()

        print('\nTraining ended @ %s' % _timestamp())