Beispiel #1
0
def export2caffe(weights, num_classes, img_size):
    """Export a trained HRNet checkpoint to Caffe format.

    The checkpoint is loaded on the CPU, its ``'model'`` entry is applied to a
    fresh ``HRNet``, the network is switched to eval mode and passed through
    ``fuse``. It is then traced by ``pytorch2caffe`` with an all-ones input.
    The results are written as ``HRNet.prototxt`` and ``HRNet.caffemodel``
    (relative paths).

    Args:
        weights: path to the checkpoint file.
        num_classes: number of outputs given to the ``HRNet`` constructor.
        img_size: pair whose first item is used as the input width and whose
            second item is used as the input height.
    """
    net_name = 'HRNet'

    net = HRNet(num_classes)
    checkpoint = torch.load(weights, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    net.eval()
    fuse(net)

    # Dummy batch of one 3-channel image with shape [1, 3, height, width].
    width, height = img_size[0], img_size[1]
    sample = torch.ones([1, 3, height, width])

    pytorch2caffe.trans_net(net, sample, net_name)
    pytorch2caffe.save_prototxt('{}.prototxt'.format(net_name))
    pytorch2caffe.save_caffemodel('{}.caffemodel'.format(net_name))
Beispiel #2
0
def run(img_dir, output_dir, img_size, num_classes, weights, show):
    """Run HRNet keypoint inference on a directory of images and save overlays.

    Every file in ``img_dir`` whose extension is in ``IMG_EXT`` is processed in
    sorted name order. Each predicted keypoint is drawn as a small red dot.
    The annotated image is written to ``output_dir`` as ``<stem>.png``.

    Args:
        img_dir: directory containing the input images.
        output_dir: destination directory. It is deleted and recreated, so any
            existing content is lost.
        img_size: input size forwarded to ``inference``.
        num_classes: number of outputs given to the ``HRNet`` constructor.
        weights: checkpoint path whose ``'model'`` entry is the state dict.
        show: accepted for interface compatibility; not used here.

    Raises:
        ValueError: if ``img_dir`` is ``output_dir`` or lies inside it, because
            the cleanup of ``output_dir`` would delete the inputs first.
    """
    out_real = osp.realpath(output_dir)
    img_real = osp.realpath(img_dir)
    if img_real == out_real or img_real.startswith(out_real.rstrip(os.sep) + os.sep):
        raise ValueError(
            'img_dir must not be output_dir or inside it: {}'.format(img_dir))

    shutil.rmtree(output_dir, ignore_errors=True)
    os.makedirs(output_dir, exist_ok=True)

    model = HRNet(num_classes)
    state_dict = torch.load(weights, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    model = model.to(device)
    model.eval()

    names = sorted(n for n in os.listdir(img_dir) if osp.splitext(n)[1] in IMG_EXT)
    for name in tqdm(names):
        path = osp.join(img_dir, name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for unreadable or corrupt files; skip them
            # instead of crashing on img.shape below.
            print('skip unreadable image: {}'.format(path))
            continue
        kps = inference(model, [img], img_size)[0]
        height, width = img.shape[0], img.shape[1]
        # Keypoints are scaled by image width/height before drawing.
        for (x, y) in kps:
            cv2.circle(img, (int(x * width), int(y * height)), 2, (0, 0, 255), -1)
        cv2.imwrite(osp.join(output_dir, osp.splitext(name)[0] + '.png'), img)
Beispiel #3
0
    data_dir=data_dir,
    file_list=train_list,
    label_list=label_list,
    transforms=train_transforms,
    shuffle=True)

# Reader for the validation list; it uses the evaluation transforms.
eval_reader = Reader(
    data_dir=data_dir,
    file_list=val_list,
    label_list=label_list,
    transforms=eval_transforms)

# Choose the network architecture from the --model_type argument.
if args.model_type == 'unet':
    model = UNet(num_classes=num_classes, input_channel=channel)
elif args.model_type == 'hrnet':
    model = HRNet(num_classes=num_classes, input_channel=channel)
else:
    raise ValueError(
        "--model_type: {} is set wrong, it should be one of ('unet', "
        "'hrnet')".format(args.model_type))

model.train(
    num_epochs=num_epochs,
    train_reader=train_reader,
    train_batch_size=train_batch_size,
    eval_reader=eval_reader,
    eval_best_metric='miou',
    save_interval_epochs=5,
    log_interval_steps=10,
    save_dir=save_dir,
    learning_rate=lr,
Beispiel #4
0
    # Evaluate an HRNet checkpoint on a COCO-style validation file and print
    # the value returned by `test`.
    parser = argparse.ArgumentParser()
    parser.add_argument('val', type=str)
    parser.add_argument('--weights', type=str, default='')
    parser.add_argument('--rect', action='store_true')
    parser.add_argument('-s', '--img_size', type=int, nargs=2, default=[416, 416])
    parser.add_argument('-bs', '--batch-size', type=int, default=32)
    parser.add_argument('--num-workers', type=int, default=4)
    opt = parser.parse_args()

    # Validation data without augmentation; --rect is forwarded as `rect`.
    val_data = CocoDataset(opt.val, img_size=opt.img_size, augments=None, rect=opt.rect)
    val_loader = DataLoader(val_data,
                            batch_size=opt.batch_size,
                            num_workers=opt.num_workers,
                            pin_memory=True)
    val_fetcher = Fetcher(val_loader, post_fetch_fn=val_data.post_fetch_fn)

    model = HRNet(len(val_data.classes))
    if opt.weights:
        # The checkpoint keeps the model state dict under the 'model' key.
        checkpoint = torch.load(opt.weights, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    metrics = test(model, val_fetcher)
    print('metrics: %8g' % metrics)
Beispiel #5
0
def _make_loader(dataset, batch_size, num_workers):
    """Build a pinned-memory DataLoader for `dataset`.

    When torch.distributed is initialized, a DistributedSampler shards the data
    across ranks and DataLoader shuffling is off. Otherwise no sampler is used
    and the DataLoader shuffles.
    """
    distributed = dist.is_initialized()
    sampler = (DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
               if distributed else None)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not distributed,
        sampler=sampler,
        pin_memory=True,
        num_workers=num_workers,
    )


def train(data_dir, epochs, img_size, batch_size, accumulate, lr, adam, resume,
          weights, num_workers, multi_scale, rect, mixed_precision, notest,
          nosave):
    """Train HRNet on the COCO-style annotations found in `data_dir`.

    Args:
        data_dir: directory containing `train.json` (and `val.json` unless
            `notest` is set).
        epochs: train until `trainer.epoch` reaches this value.
        img_size: input size passed to both datasets.
        batch_size, num_workers: DataLoader settings for both splits.
        accumulate, lr, adam, weights, resume, mixed_precision: forwarded
            unchanged to `Trainer`.
        multi_scale: forwarded to the training dataset only.
        rect: forwarded to both datasets.
        notest: if true, skip validation, so no epoch is ever marked "best".
        nosave: if true, never call `trainer.save`.
    """
    train_coco = osp.join(data_dir, 'train.json')
    val_coco = osp.join(data_dir, 'val.json')

    train_data = CocoDataset(train_coco,
                             img_size=img_size,
                             multi_scale=multi_scale,
                             rect=rect)
    train_fetcher = Fetcher(_make_loader(train_data, batch_size, num_workers),
                            train_data.post_fetch_fn)

    val_fetcher = None
    if not notest:
        val_data = CocoDataset(val_coco,
                               img_size=img_size,
                               augments=None,
                               rect=rect)
        val_fetcher = Fetcher(_make_loader(val_data, batch_size, num_workers),
                              post_fetch_fn=val_data.post_fetch_fn)

    model = HRNet(len(train_data.classes))

    trainer = Trainer(model,
                      train_fetcher,
                      loss_fn=compute_loss,
                      workdir='weights',
                      accumulate=accumulate,
                      adam=adam,
                      lr=lr,
                      weights=weights,
                      resume=resume,
                      mixed_precision=mixed_precision)
    # Best (lowest) validation metric so far, printed as NME. Starting at 1
    # means an epoch only counts as "best" once the metric drops below 1.
    trainer.metrics = 1
    while trainer.epoch < epochs:
        trainer.step()
        # Always defined, so trainer.save(best) works when validation is off.
        best = False
        if not notest:
            metrics = test(trainer.model, val_fetcher)
            if metrics < trainer.metrics:
                best = True
                print('save best, NME: %g' % metrics)
                trainer.metrics = metrics
        if not nosave:
            trainer.save(best)
Beispiel #6
0
                      transforms=train_transforms,
                      shuffle=True)

# Reader for the validation list; it uses the evaluation transforms.
eval_reader = Reader(data_dir=data_dir,
                     file_list=val_list,
                     label_list=label_list,
                     transforms=eval_transforms)

# Two-class model with both BCE and Dice losses enabled; the architecture is
# chosen from --model_type.
if args.model_type == 'unet':
    model = UNet(num_classes=2,
                 input_channel=channel,
                 use_bce_loss=True,
                 use_dice_loss=True)
elif args.model_type == 'hrnet':
    model = HRNet(num_classes=2,
                  input_channel=channel,
                  use_bce_loss=True,
                  use_dice_loss=True)
else:
    raise ValueError(
        "--model_type: {} is set wrong, it should be one of ('unet', "
        "'hrnet')".format(args.model_type))

model.train(num_epochs=num_epochs,
            train_reader=train_reader,
            train_batch_size=train_batch_size,
            eval_reader=eval_reader,
            save_interval_epochs=5,
            log_interval_steps=10,
            save_dir=save_dir,
            learning_rate=lr,
            use_vdl=True)
Beispiel #7
0
    def __init__(self,
                 exp_name,
                 ds_train,
                 ds_val,
                 epochs=210,
                 batch_size=16,
                 num_workers=4,
                 loss='JointsMSELoss',
                 lr=0.001,
                 lr_decay=True,
                 lr_decay_steps=(170, 200),
                 lr_decay_gamma=0.1,
                 optimizer='Adam',
                 weight_decay=0.,
                 momentum=0.9,
                 nesterov=False,
                 pretrained_weight_path=None,
                 checkpoint_path=None,
                 log_path='./logs',
                 use_tensorboard=True,
                 model_c=48,
                 model_nof_joints=18,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initializes a new Train object.

        The log folder is created, the HRNet model is initialized and optional pre-trained weights or saved checkpoints
        are loaded.
        The DataLoaders, the loss function, and the optimizer are defined.

        Args:
            exp_name (str):  experiment name.
            ds_train (HumanPoseEstimationDataset): train dataset.
            ds_val (HumanPoseEstimationDataset): validation dataset.
            epochs (int): number of epochs.
                Default: 210
            batch_size (int): batch size.
                Default: 16
            num_workers (int): number of workers for each DataLoader
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            lr (float): learning rate.
                Default: 0.001
            lr_decay (bool): learning rate decay.
                Default: True
            lr_decay_steps (tuple): steps for the learning rate decay scheduler.
                Default: (170, 200)
            lr_decay_gamma (float): scale factor for each learning rate decay step.
                Default: 0.1
            optimizer (str): network optimizer. Valid values are 'Adam' and 'SGD'.
                Default: "Adam"
            weight_decay (float): weight decay.
                Default: 0.
            momentum (float): momentum factor (SGD only).
                Default: 0.9
            nesterov (bool): Nesterov momentum (SGD only).
                Default: False
            pretrained_weight_path (str): path to pre-trained weights (such as weights from pre-train on imagenet).
                Default: None
            checkpoint_path (str): path to a previous checkpoint, or a directory containing 'checkpoint_last.pth'.
                Default: None
            log_path (str): path where tensorboard data and checkpoints will be saved.
                Default: "./logs"
            use_tensorboard (bool): enables tensorboard use.
                Default: True
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                If 18 and `pretrained_weight_path` is given, the pre-trained weights are expected to have one
                output channel fewer in `final_layer`; a randomly initialized channel is appended to it.
                Default: 18
            model_bn_momentum (float): hrnet parameters - `bn_momentum` passed to HRNet (presumably the
                BatchNorm momentum).
                Default: 0.1
            flip_test_images (bool): flip images during validating.
                Default: True
            device (torch.device): device passed to `load_checkpoint` when resuming (default: cuda, if available).
                The model itself is always moved to CUDA with `.cuda()`.
                Default: None
        """
        super(Train, self).__init__()

        self.exp_name = exp_name
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.lr_decay_gamma = lr_decay_gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov
        self.pretrained_weight_path = pretrained_weight_path
        self.checkpoint_path = checkpoint_path
        self.log_path = os.path.join(log_path, self.exp_name)
        self.use_tensorboard = use_tensorboard
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        # Stored because load_checkpoint() below receives self.device.
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.epoch = 0

        os.makedirs(self.log_path, 0o755,
                    exist_ok=True)  # exist_ok=False to avoid overwriting
        if self.use_tensorboard:
            self.summary_writer = tb.SummaryWriter(self.log_path)

        #
        # write all experiment parameters in parameters.txt and in tensorboard text field
        # (no other local variable exists yet, so locals() holds `self` plus the constructor arguments)
        self.parameters = [
            x + ': ' + str(y) + '\n' for x, y in locals().items()
        ]

        with open(os.path.join(self.log_path, 'parameters.txt'), 'w') as fd:
            fd.writelines(self.parameters)
        if self.use_tensorboard:
            self.summary_writer.add_text('parameters',
                                         '\n'.join(self.parameters))

        #
        # load model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).cuda()

        #
        # define loss and optimizers
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss()
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss()
        else:
            raise NotImplementedError

        if self.optimizer == 'SGD':
            self.optim = SGD(self.model.parameters(),
                             lr=self.lr,
                             weight_decay=self.weight_decay,
                             momentum=self.momentum,
                             nesterov=self.nesterov)
        elif self.optimizer == 'Adam':
            self.optim = Adam(self.model.parameters(),
                              lr=self.lr,
                              weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        # load pre-trained weights (such as those pre-trained on imagenet)
        if self.pretrained_weight_path is not None:
            if self.model_nof_joints == 18:
                # Pre-trained entries are matched to the model's entries by
                # position. `final_layer` gets one extra randomly initialized
                # output channel, created on the same device/dtype as the
                # loaded tensor so torch.cat works for CPU- or GPU-saved weights.
                # NOTE(review): positional matching assumes both state dicts list
                # their tensors in the same order - confirm.
                pretrained_dict = torch.load(self.pretrained_weight_path)
                pretrained_dict_items = list(pretrained_dict.items())
                model_dict = self.model.state_dict()
                pretrained_model = {}
                for j in range(len(model_dict)):
                    k, v = pretrained_dict_items[j]
                    if k in ('final_layer.weight', 'final_layer.bias'):
                        extra = torch.rand((1,) + tuple(v.shape[1:]),
                                           device=v.device,
                                           dtype=v.dtype)
                        v = torch.cat([v, extra], dim=0)
                    pretrained_model[k] = v
                model_dict.update(pretrained_model)
                self.model.load_state_dict(model_dict, strict=True)
            else:
                # `strict` belongs to load_state_dict, not torch.load.
                self.model.load_state_dict(
                    torch.load(self.pretrained_weight_path), strict=True)
            print('Pre-trained weights loaded.')

        self.model = nn.DataParallel(self.model.cuda())
        #
        # load previous checkpoint
        if self.checkpoint_path is not None:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path,
                                    'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, self.optim, self.params = load_checkpoint(
                path, self.model, self.optim, self.device)
        else:
            self.starting_epoch = 0

        if self.lr_decay:
            self.lr_scheduler = MultiStepLR(
                self.optim,
                list(self.lr_decay_steps),
                gamma=self.lr_decay_gamma,
                last_epoch=self.starting_epoch if self.starting_epoch else -1)

        #
        # load train and val datasets
        self.dl_train = DataLoader(self.ds_train,
                                   batch_size=self.batch_size,
                                   shuffle=True,
                                   num_workers=self.num_workers,
                                   drop_last=True)
        self.len_dl_train = len(self.dl_train)

        self.dl_val = DataLoader(self.ds_val,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=self.num_workers)
        self.len_dl_val = len(self.dl_val)

        #
        # initialize variables
        self.mean_loss_train = 0.
        self.mean_acc_train = 0.
        self.mean_loss_val = 0.
        self.mean_acc_val = 0.
        self.mean_mAP_val = 0.

        self.best_loss = None
        self.best_acc = None
        self.best_mAP = None
Beispiel #8
0
class Train(object):
    """
    Train class.

    The class provides a basic tool for training HRNet.
    Most of the training options are customizable.

    The only method supposed to be directly called is `run()`.
    """
    def __init__(self,
                 exp_name,
                 ds_train,
                 ds_val,
                 epochs=210,
                 batch_size=16,
                 num_workers=4,
                 loss='JointsMSELoss',
                 lr=0.001,
                 lr_decay=True,
                 lr_decay_steps=(170, 200),
                 lr_decay_gamma=0.1,
                 optimizer='Adam',
                 weight_decay=0.,
                 momentum=0.9,
                 nesterov=False,
                 pretrained_weight_path=None,
                 checkpoint_path=None,
                 log_path='./logs',
                 use_tensorboard=True,
                 model_c=48,
                 model_nof_joints=18,
                 model_bn_momentum=0.1,
                 flip_test_images=True,
                 device=None):
        """
        Initializes a new Train object.

        The log folder is created, the HRNet model is initialized and optional pre-trained weights or saved checkpoints
        are loaded.
        The DataLoaders, the loss function, and the optimizer are defined.

        Args:
            exp_name (str):  experiment name.
            ds_train (HumanPoseEstimationDataset): train dataset.
            ds_val (HumanPoseEstimationDataset): validation dataset.
            epochs (int): number of epochs.
                Default: 210
            batch_size (int): batch size.
                Default: 16
            num_workers (int): number of workers for each DataLoader
                Default: 4
            loss (str): loss function. Valid values are 'JointsMSELoss' and 'JointsOHKMMSELoss'.
                Default: "JointsMSELoss"
            lr (float): learning rate.
                Default: 0.001
            lr_decay (bool): learning rate decay.
                Default: True
            lr_decay_steps (tuple): steps for the learning rate decay scheduler.
                Default: (170, 200)
            lr_decay_gamma (float): scale factor for each learning rate decay step.
                Default: 0.1
            optimizer (str): network optimizer. Valid values are 'Adam' and 'SGD'.
                Default: "Adam"
            weight_decay (float): weight decay.
                Default: 0.
            momentum (float): momentum factor (SGD only).
                Default: 0.9
            nesterov (bool): Nesterov momentum (SGD only).
                Default: False
            pretrained_weight_path (str): path to pre-trained weights (such as weights from pre-train on imagenet).
                Default: None
            checkpoint_path (str): path to a previous checkpoint, or a directory containing 'checkpoint_last.pth'.
                Default: None
            log_path (str): path where tensorboard data and checkpoints will be saved.
                Default: "./logs"
            use_tensorboard (bool): enables tensorboard use.
                Default: True
            model_c (int): hrnet parameters - number of channels.
                Default: 48
            model_nof_joints (int): hrnet parameters - number of joints.
                If 18 and `pretrained_weight_path` is given, the pre-trained weights are expected to have one
                output channel fewer in `final_layer`; a randomly initialized channel is appended to it.
                Default: 18
            model_bn_momentum (float): hrnet parameters - `bn_momentum` passed to HRNet (presumably the
                BatchNorm momentum).
                Default: 0.1
            flip_test_images (bool): flip images during validating.
                Default: True
            device (torch.device): device passed to `load_checkpoint` when resuming (default: cuda, if available).
                The model itself is always moved to CUDA with `.cuda()`.
                Default: None
        """
        super(Train, self).__init__()

        self.exp_name = exp_name
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.epochs = epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = loss
        self.lr = lr
        self.lr_decay = lr_decay
        self.lr_decay_steps = lr_decay_steps
        self.lr_decay_gamma = lr_decay_gamma
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.nesterov = nesterov
        self.pretrained_weight_path = pretrained_weight_path
        self.checkpoint_path = checkpoint_path
        self.log_path = os.path.join(log_path, self.exp_name)
        self.use_tensorboard = use_tensorboard
        self.model_c = model_c
        self.model_nof_joints = model_nof_joints
        self.model_bn_momentum = model_bn_momentum
        self.flip_test_images = flip_test_images
        # Stored because load_checkpoint() below receives self.device.
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.epoch = 0

        os.makedirs(self.log_path, 0o755,
                    exist_ok=True)  # exist_ok=False to avoid overwriting
        if self.use_tensorboard:
            self.summary_writer = tb.SummaryWriter(self.log_path)

        #
        # write all experiment parameters in parameters.txt and in tensorboard text field
        # (no other local variable exists yet, so locals() holds `self` plus the constructor arguments)
        self.parameters = [
            x + ': ' + str(y) + '\n' for x, y in locals().items()
        ]

        with open(os.path.join(self.log_path, 'parameters.txt'), 'w') as fd:
            fd.writelines(self.parameters)
        if self.use_tensorboard:
            self.summary_writer.add_text('parameters',
                                         '\n'.join(self.parameters))

        #
        # load model
        self.model = HRNet(c=self.model_c,
                           nof_joints=self.model_nof_joints,
                           bn_momentum=self.model_bn_momentum).cuda()

        #
        # define loss and optimizers
        if self.loss == 'JointsMSELoss':
            self.loss_fn = JointsMSELoss()
        elif self.loss == 'JointsOHKMMSELoss':
            self.loss_fn = JointsOHKMMSELoss()
        else:
            raise NotImplementedError

        if self.optimizer == 'SGD':
            self.optim = SGD(self.model.parameters(),
                             lr=self.lr,
                             weight_decay=self.weight_decay,
                             momentum=self.momentum,
                             nesterov=self.nesterov)
        elif self.optimizer == 'Adam':
            self.optim = Adam(self.model.parameters(),
                              lr=self.lr,
                              weight_decay=self.weight_decay)
        else:
            raise NotImplementedError

        # load pre-trained weights (such as those pre-trained on imagenet)
        if self.pretrained_weight_path is not None:
            if self.model_nof_joints == 18:
                # Pre-trained entries are matched to the model's entries by
                # position. `final_layer` gets one extra randomly initialized
                # output channel, created on the same device/dtype as the
                # loaded tensor so torch.cat works for CPU- or GPU-saved weights.
                # NOTE(review): positional matching assumes both state dicts list
                # their tensors in the same order - confirm.
                pretrained_dict = torch.load(self.pretrained_weight_path)
                pretrained_dict_items = list(pretrained_dict.items())
                model_dict = self.model.state_dict()
                pretrained_model = {}
                for j in range(len(model_dict)):
                    k, v = pretrained_dict_items[j]
                    if k in ('final_layer.weight', 'final_layer.bias'):
                        extra = torch.rand((1,) + tuple(v.shape[1:]),
                                           device=v.device,
                                           dtype=v.dtype)
                        v = torch.cat([v, extra], dim=0)
                    pretrained_model[k] = v
                model_dict.update(pretrained_model)
                self.model.load_state_dict(model_dict, strict=True)
            else:
                # `strict` belongs to load_state_dict, not torch.load.
                self.model.load_state_dict(
                    torch.load(self.pretrained_weight_path), strict=True)
            print('Pre-trained weights loaded.')

        self.model = nn.DataParallel(self.model.cuda())
        #
        # load previous checkpoint
        if self.checkpoint_path is not None:
            print('Loading checkpoint %s...' % self.checkpoint_path)
            if os.path.isdir(self.checkpoint_path):
                path = os.path.join(self.checkpoint_path,
                                    'checkpoint_last.pth')
            else:
                path = self.checkpoint_path
            self.starting_epoch, self.model, self.optim, self.params = load_checkpoint(
                path, self.model, self.optim, self.device)
        else:
            self.starting_epoch = 0

        if self.lr_decay:
            self.lr_scheduler = MultiStepLR(
                self.optim,
                list(self.lr_decay_steps),
                gamma=self.lr_decay_gamma,
                last_epoch=self.starting_epoch if self.starting_epoch else -1)

        #
        # load train and val datasets
        self.dl_train = DataLoader(self.ds_train,
                                   batch_size=self.batch_size,
                                   shuffle=True,
                                   num_workers=self.num_workers,
                                   drop_last=True)
        self.len_dl_train = len(self.dl_train)

        self.dl_val = DataLoader(self.ds_val,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 num_workers=self.num_workers)
        self.len_dl_val = len(self.dl_val)

        #
        # initialize variables
        self.mean_loss_train = 0.
        self.mean_acc_train = 0.
        self.mean_loss_val = 0.
        self.mean_acc_val = 0.
        self.mean_mAP_val = 0.

        self.best_loss = None
        self.best_acc = None
        self.best_mAP = None

    def _train(self):
        """Run one training epoch and accumulate mean_loss_train / mean_acc_train."""
        self.model.train()

        for step, (image, target, target_weight, joints_data) in enumerate(
                tqdm(self.dl_train, desc='Training')):
            image = image.cuda()
            target = target.cuda()
            target_weight = target_weight.cuda()

            self.optim.zero_grad()
            output = self.model(image)
            loss = self.loss_fn(output, target, target_weight)
            loss.backward()
            self.optim.step()

            # Evaluate accuracy of the current batch predictions
            _, avg_acc, _, joints_preds, joints_target = \
                self.ds_train.evaluate_accuracy(output, target)

            self.mean_loss_train += loss.item()
            self.mean_acc_train += avg_acc.item()
            if self.use_tensorboard:
                global_step = step + self.epoch * self.len_dl_train
                self.summary_writer.add_scalar('train_loss',
                                               loss.item(),
                                               global_step=global_step)
                self.summary_writer.add_scalar('train_acc',
                                               avg_acc.item(),
                                               global_step=global_step)
                if step == 0:
                    save_images(image,
                                target,
                                joints_target,
                                output,
                                joints_preds,
                                joints_data['joints_visibility'],
                                self.summary_writer,
                                step=global_step,
                                prefix='train_')

        self.mean_loss_train /= len(self.dl_train)
        self.mean_acc_train /= len(self.dl_train)

        print('\nTrain: Loss %f - Accuracy %f' %
              (self.mean_loss_train, self.mean_acc_train))

    def _val(self):
        """Run one validation pass and accumulate mean_loss_val / mean_acc_val.

        When `flip_test_images` is set, the model output is averaged with the
        flipped-back output of the horizontally flipped input.
        """
        self.model.eval()

        with torch.no_grad():
            for step, (image, target, target_weight, joints_data) in enumerate(
                    tqdm(self.dl_val, desc='Validating')):
                image = image.cuda()
                target = target.cuda()
                target_weight = target_weight.cuda()

                output = self.model(image)

                if self.flip_test_images:
                    image_flipped = flip_tensor(image, dim=-1)
                    output_flipped = self.model(image_flipped)

                    output_flipped = flip_back(output_flipped,
                                               self.ds_val.flip_pairs)

                    output = (output + output_flipped) * 0.5

                loss = self.loss_fn(output, target, target_weight)

                # Evaluate accuracy of the current batch predictions
                _, avg_acc, _, joints_preds, joints_target = \
                    self.ds_train.evaluate_accuracy(output, target)

                # Accumulate into the *validation* statistics (these feed the
                # best-loss / best-acc checkpoint selection in _checkpoint).
                self.mean_loss_val += loss.item()
                self.mean_acc_val += avg_acc.item()
                if self.use_tensorboard:
                    # Validation steps are indexed by the validation loader
                    # length so consecutive epochs do not overlap.
                    global_step = step + self.epoch * self.len_dl_val
                    self.summary_writer.add_scalar('val_loss',
                                                   loss.item(),
                                                   global_step=global_step)
                    self.summary_writer.add_scalar('val_acc',
                                                   avg_acc.item(),
                                                   global_step=global_step)
                    if step == 0:
                        save_images(image,
                                    target,
                                    joints_target,
                                    output,
                                    joints_preds,
                                    joints_data['joints_visibility'],
                                    self.summary_writer,
                                    step=global_step,
                                    prefix='val_')

        self.mean_loss_val /= len(self.dl_val)
        self.mean_acc_val /= len(self.dl_val)

        print('\nValidation: Loss %f - Accuracy %f' %
              (self.mean_loss_val, self.mean_acc_val))

    def _save(self, filename):
        """Save model, optimizer and parameters to `log_path/filename`, storing `self.epoch + 1` as the epoch."""
        save_checkpoint(path=os.path.join(self.log_path, filename),
                        epoch=self.epoch + 1,
                        model=self.model,
                        optimizer=self.optim,
                        params=self.parameters)

    def _checkpoint(self):
        """Write the latest checkpoint and refresh the best-loss / best-acc / best-mAP ones."""
        self._save('checkpoint_last.pth')

        if self.best_loss is None or self.best_loss > self.mean_loss_val:
            self.best_loss = self.mean_loss_val
            self._save('checkpoint_best_loss.pth')
        if self.best_acc is None or self.best_acc < self.mean_acc_val:
            self.best_acc = self.mean_acc_val
            self._save('checkpoint_best_acc.pth')
        # NOTE(review): mean_mAP_val is reset to 0 every epoch and never updated
        # in this class, so this branch only fires on the first epoch.
        if self.best_mAP is None or self.best_mAP < self.mean_mAP_val:
            self.best_mAP = self.mean_mAP_val
            self._save('checkpoint_best_mAP.pth')

    def run(self):
        """
        Runs the training: for each epoch, train, validate, step the LR scheduler (if enabled) and checkpoint.
        """

        print('\nTraining started @ %s' %
              datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

        # start training
        for self.epoch in range(self.starting_epoch, self.epochs):
            self.mean_loss_train = 0.
            self.mean_loss_val = 0.
            self.mean_acc_train = 0.
            self.mean_acc_val = 0.
            self.mean_mAP_val = 0.

            self._train()
            self._val()

            if self.lr_decay:
                self.lr_scheduler.step()

            self._checkpoint()

        print('\nTraining ended @ %s' %
              datetime.now().strftime("%Y-%m-%d %H:%M:%S"))