def visualize_stn(original, transformed, before_stn, after_stn, filenames,
                  target):
    """Save side-by-side spatial-transformer debug images, one PNG per sample.

    For each sample four panels are stacked horizontally:
    the original image, the transformed input, the tensor before the STN and
    the tensor after the STN. The result is written to
    ``<target>/stnviz/<filename>.png``.

    ``transformed``, ``before_stn`` and ``after_stn`` are detached, moved to
    the CPU and permuted with ``transpose(0, 2, 3, 1)``. They are presumably
    NCHW tensors turned into NHWC arrays (TODO confirm). Each of them is
    de-normalized with ``FaceLandmarksTrainingData.undo_normalization``.
    ``original`` is only converted from RGB to BGR.
    """
    out_dir = os.path.join(target, "stnviz")
    mkdir_if_not_exists(out_dir)

    def channels_last(batch):
        # Detach from autograd, move to host memory, put channels last.
        return batch.detach().cpu().numpy().transpose(0, 2, 3, 1)

    def denormalized_bgr(img):
        # Undo the training normalization, then switch to OpenCV's BGR order.
        restored = FaceLandmarksTrainingData.undo_normalization(img)
        return cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)

    samples = zip(original,
                  channels_last(transformed),
                  channels_last(before_stn),
                  channels_last(after_stn),
                  filenames)

    for orig_img, trans_img, pre_img, post_img, name in samples:
        panels = (
            cv2.cvtColor(orig_img, cv2.COLOR_RGB2BGR),
            denormalized_bgr(trans_img),
            denormalized_bgr(pre_img),
            denormalized_bgr(post_img),
        )
        out_path = os.path.join(out_dir, "%s.png" % name)
        cv2.imwrite(out_path, np.hstack(panels))
Пример #2
0
    def run(self):
        """Fine-tune a hourglass (HG) and a PDM end-to-end and report the results.

        Loads the initial HG and PDM referenced by ``self.config`` and
        evaluates both with ``run_e2e``. It then trains them with
        ``pdm.end2end_training``, saves a loss plot plus both trained models,
        and evaluates again.

        In grid-search mode (``self.is_gridsearch``) it also writes a JSON log
        of the values collected by ``self.receive_pdm_output`` and returns a
        flat result dict.

        Returns:
            In grid-search mode: ``self.config`` merged with ``min_loss``,
            ``last_loss`` and the before/after errors of HG and PDM on the
            easy/hard splits, with and without outline. Otherwise ``None``.
        """
        torch.autograd.set_detect_anomaly(True)  # This makes debugging much easier

        self.config["model_dir"] = self.model_dir

        make_deterministic(self.config['random_seed'])

        location = 'cpu' if self.gpu_id is None else "cuda:%d" % self.gpu_id
        # Compare by value: `is not` on a string literal relies on interning
        # and is a SyntaxWarning since Python 3.8.
        if location != 'cpu':
            # This fixes the problem that pytorch is always allocating memory on GPU 0 even if this is not included
            # in the list of GPUs to use
            torch.cuda.set_device(torch.device(location))

            # cudnn.benchmark improves training speed when input sizes do not change
            # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
            # It selects the best algorithms as the training iterates over the dataset
            #cudnn.benchmark = True # but it can cause determinism problems, so disable

        hg, hg_config = self.load_hg(self.config["initial_hg"], location)
        pdm, pdm_config = self.load_pdm(self.config["initial_pdm"], location)

        pdm.verbose = not self.is_gridsearch
        pdm.print_losses = False
        pdm.listener = self.receive_pdm_output

        normMean, normStd = FaceLandmarksTrainingData.TRAIN_MEAN, FaceLandmarksTrainingData.TRAIN_STD
        normTransform = transforms.Normalize(normMean, normStd)

        jitterTransform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)

        transform = transforms.Compose([
            ImageTransform(transforms.ToPILImage()),
            ImageTransform(jitterTransform),
            ImageAndLabelTransform(RandomHorizontalFlip()),
            ImageTransform(transforms.ToTensor()),
            ImageTransform(normTransform)
        ])

        bs = self.config["bs"]
        pin_memory = location != 'cpu'
        num_workers = 8

        # NOTE(review): the loader is used after the HDF5 file is closed. This
        # assumes FaceLandmarksTrainingData copies its data in the constructor.
        # Confirm.
        with h5py.File(self.config["data"], 'r') as f:
            train_d = FaceLandmarksTrainingData(f, transform=transform)
            train_loader = DataLoader(dataset=train_d, shuffle=self.config["shuffle"], num_workers=num_workers, pin_memory=pin_memory, batch_size=bs)

        results_before = run_e2e(hg, pdm, self.config["data"], location, bs, verbose=True)
        if not self.is_gridsearch:
            print("Before training")
            for model, res in results_before.items():
                print(model, res)

        zs, nr, losses = pdm.end2end_training(hg=hg,
                                              data_loader=train_loader,
                                              hg_opt_config=self.config["hg_optimizer"],
                                              pdm_weight_opt_config=self.config["pdm_weight_optimizer"],
                                              pdm_shape_opt_config=self.config["pdm_shape_optimizer"],
                                              training_schedule=self.config["training_schedule"],
                                              detach_confidence=self.config["detach_confidence"])

        plot_path = os.path.join(self.plot_dir, "losses_%d.png" % self.config["config_id"])
        if not self.is_gridsearch:
            print("save plot to %s" % plot_path)
        fig, ax = plt.subplots()
        ax.plot(losses)
        ax.set(xlabel='epoch', ylabel='loss', title='loss per epoch')
        ax.grid()
        fig.savefig(plot_path)
        # Release the figure: grid searches call run() many times in one process.
        plt.close(fig)

        if not self.is_gridsearch:
            print("save HG")
        torch.save({
            'model': 'pe_hourglass',
            'state_dict': hg.state_dict(),
            'config': hg_config
        }, os.path.join(self.model_dir, "%d_hg_e2e.torch" % self.config["config_id"]))

        if not self.is_gridsearch:
            print("save PDM")
        pdm.save_pdm(pdm.train_epochs, os.path.join(self.model_dir, "%d_pdm_e2e.torch" % self.config["config_id"]))

        results_after = run_e2e(hg, pdm, self.config["data"], location, bs, verbose=False)

        if not self.is_gridsearch:
            print("Before training")
            for model, res in results_before.items():
                print(model, res)

            print("After training")
            for model, res in results_after.items():
                print(model, res)

        if self.is_gridsearch:
            logpath = os.path.join(self.result_dir, "%d_log.json" % self.config["config_id"])
            # Use a context manager so the log file is flushed and closed.
            with open(logpath, "w") as log_file:
                json.dump({
                    "gt": self.gts,
                    "l2d": self.l2d_log,
                    "hg": self.hg_coords_log,
                    "losses": self.loss_log
                }, log_file)

            result = {
                **self.config,
                "min_loss": min(self.loss_log),
                "last_loss": self.loss_log[-1],
            }
            # Flatten results[model][split] into "<model>_<stage>_<suffix>" keys,
            # e.g. "hg_before_easy_with" <- results_before["hg"]["easy_woutline"].
            split_keys = (("easy_with", "easy_woutline"),
                          ("easy_without", "easy_noutline"),
                          ("hard_with", "hard_woutline"),
                          ("hard_without", "hard_noutline"))
            for stage, stage_results in (("before", results_before),
                                         ("after", results_after)):
                for model_name in ("hg", "pdm"):
                    for suffix, split in split_keys:
                        key = "%s_%s_%s" % (model_name, stage, suffix)
                        result[key] = stage_results[model_name][split]

            return result
        ImageTransform(transforms.ToPILImage()),
        #ImageAndLabelTransform(RandomHorizontalFlip()),
        #ImageAndLabelTransform(NormalizeRotation()),
        ImageAndLabelTransform(
            RandomRotation(min_angle=-30,
                           max_angle=30,
                           retain_scale=False,
                           rotate_landmarks="neutral")),
        ImageTransform(transforms.ToTensor()),
        #ImageTransform(normTransform)
    ])

    with h5py.File(args.dataset, 'r') as f:
        easy_d = FaceLandmarksEasyTestData(f, transform=transform)
        hard_d = FaceLandmarksHardTestData(f, transform=transform)
        train = FaceLandmarksTrainingData(f, transform=transform)

        imgs = []
        for x in easy_d:
            imgs.append(
                (x["angle"], x["original_image"],
                 (255 * x["image"]).type(torch.uint8).permute(1, 2, 0).numpy(),
                 x["original_landmarks"], x["landmarks"]))

        cv2.namedWindow("test_rotations")  # Create a named window
        cv2.moveWindow("test_rotations", 200, 200)

        imgs = sorted(imgs, key=lambda x: abs(x[0]))
        for angle, original_image, corrected_img, origlms, lms in imgs[::
                                                                       -1][:3]:
            withlms = draw_landmarks(corrected_img[:, :, ::-1], lms)
def visualize(model,
              dataset,
              target,
              gpu=None,
              splits=("easy", "hard"),
              landmarks_in_heatmaps=True):
    """Render a trained model's predictions on selected dataset splits.

    Loads the checkpoint at ``model``, a dict with ``'state_dict'`` and
    ``'config'`` entries. It rebuilds the network via
    ``ModelTrainer.create_net`` and calls ``visualize_split`` for every
    requested split, writing to ``<target>/<split>``.

    Args:
        model: Path to a checkpoint readable by ``torch.load``.
        dataset: Path to the HDF5 file containing the splits.
        target: Output directory; created if it does not exist.
        gpu: CUDA device index, or ``None`` to run on the CPU.
        splits: Container of split names to render, checked with ``in``.
            Recognized names are ``"easy"``, ``"hard"`` and ``"train"``.
            They are always processed in that order.
        landmarks_in_heatmaps: Forwarded unchanged to ``visualize_split``.
    """
    location = 'cpu' if gpu is None else "cuda:%d" % gpu
    # Compare by value; `is not` on string literals relies on interning.
    if location != 'cpu':
        # This fixes the problem that pytorch is always allocating memory on GPU 0 even if this is not included
        # in the list of GPUs to use
        torch.cuda.set_device(torch.device(location))

        # cudnn.benchmark improves training speed when input sizes do not change
        # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        # It selects the best algorithms as the training iterates over the dataset
        cudnn.benchmark = True

    print("Location: ", location)

    data = torch.load(model, map_location=location)
    state_dict = data['state_dict']
    config = data['config']

    num_workers = multiprocessing.cpu_count()
    # On the CPU, use one sample per worker instead of the training batch size.
    batch_size = config['batch_size'] if gpu is not None else num_workers
    pin_memory = gpu is not None

    print("Workers: ", num_workers)
    print("Batchsize: ", batch_size)

    net = ModelTrainer.create_net(config, verbose=False)
    net.load_state_dict(state_dict)
    net.eval()

    net = net.to(location)

    mkdir_if_not_exists(target)

    normMean, normStd = FaceLandmarksTrainingData.TRAIN_MEAN, FaceLandmarksTrainingData.TRAIN_STD
    normTransform = transforms.Normalize(normMean, normStd)

    transform = transforms.Compose([
        ImageTransform(transforms.ToPILImage()),
        #ImageAndLabelTransform(RandomHorizontalFlip()),
        #ImageAndLabelTransform(RandomRotation(min_angle=-0, max_angle=0, retain_scale=False)),
        ImageTransform(transforms.ToTensor()),
        ImageTransform(normTransform)
    ])

    # Fixed processing order, independent of the order of `splits`.
    split_datasets = (
        ("easy", FaceLandmarksEasyTestData),
        ("hard", FaceLandmarksHardTestData),
        ("train", FaceLandmarksTrainingData),
    )

    with h5py.File(dataset, 'r') as f:
        for split_name, dataset_cls in split_datasets:
            if split_name not in splits:
                continue
            print("Run on %s" % split_name)
            split_data = dataset_cls(f, transform=transform)
            loader = DataLoader(dataset=split_data,
                                shuffle=False,
                                num_workers=num_workers,
                                pin_memory=pin_memory,
                                batch_size=batch_size)
            visualize_split(net,
                            loader,
                            os.path.join(target, split_name),
                            location,
                            landmarks_in_heatmaps=landmarks_in_heatmaps)
                        action='store_false',
                        help='Store original images instead of landmarks')

    parser.add_argument('--is_menpo',
                        dest="is_menpo",
                        default=False,
                        action="store_true",
                        help="Specify this when using menpo instead of 300-W")

    args = parser.parse_args()

    with h5py.File(args.source, 'r') as f:
        if args.is_menpo:
            splits = [Menpo(f)]
        else:
            splits = [FaceLandmarksTrainingData(f), FaceLandmarksEasyTestData(f), FaceLandmarksHardTestData(f)]

        stats = defaultdict(int)

        for data in splits:
            split = data.split
            print('Split: %s' % split)

            directory = os.path.join(args.target, split)
            mkdir_if_not_exists(directory)

            for sample in tqdm(list(data)):
                image = sample['original_image'][:,:,::-1].copy()  # RGB o BGR

                if args.draw_landmarks:
                    image = draw_landmarks(image, sample['landmarks'])
def run(model,
        src_300w,
        src_menpo,
        target,
        gpu=None,
        override_norm_params=False,
        bs_factor=1):
    """Evaluate a trained checkpoint on the 300-W splits and on Menpo.

    Args:
        model: Path to a checkpoint with ``'state_dict'`` and ``'config'``
            entries. ``config["n_lm"]`` selects the number of landmarks.
        src_300w: HDF5 file with the easy, hard and train splits.
        src_menpo: HDF5 file with the Menpo data.
        target: Path of the JSON file to write the results to, or ``None``
            to return them instead.
        gpu: CUDA device index, or ``None`` for the CPU.
        override_norm_params: If true, replace the training mean/std with the
            hard-coded constants below.
        bs_factor: Multiplier for the checkpoint's batch size (GPU only).

    Returns:
        ``None`` if ``target`` is given; otherwise a dict with the keys
        ``easy``, ``hard``, ``train``, ``menpo``, ``model_src`` and ``config``.
    """
    location = 'cpu' if gpu is None else "cuda:%d" % gpu
    # Compare by value; `is not` on string literals relies on interning.
    if location != 'cpu':
        # This fixes the problem that pytorch is always allocating memory on GPU 0 even if this is not included
        # in the list of GPUs to use
        torch.cuda.set_device(torch.device(location))

        # cudnn.benchmark improves training speed when input sizes do not change
        # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
        # It selects the best algorithms as the training iterates over the dataset
        #cudnn.benchmark = True # disable for deterministic behavior

    print("Location: ", location)

    data = torch.load(model, map_location=location)
    state_dict = data['state_dict']
    config = data['config']
    n_lm = config["n_lm"]

    if n_lm == 49:
        print("WARNING! THIS IS A 49 LM model!!!!", n_lm)

    num_workers = multiprocessing.cpu_count()
    batch_size = (config['batch_size'] * bs_factor
                  if gpu is not None else num_workers)
    pin_memory = gpu is not None

    print("Workers: ", num_workers)
    print("Batchsize: ", batch_size)

    net = ModelTrainer.create_net(config, verbose=False)
    net.load_state_dict(state_dict)
    net.eval()

    net.to(location)

    # Only create the output directory when there is an output file.
    # os.path.dirname(None) raises TypeError, and a bare filename yields ""
    # (the current directory, which needs no creation).
    if target is not None:
        target_dir = os.path.dirname(target)
        if target_dir:
            mkdir_if_not_exists(target_dir)

    normMean, normStd = FaceLandmarksTrainingData.TRAIN_MEAN, FaceLandmarksTrainingData.TRAIN_STD

    if override_norm_params:
        normMean = tuple(
            np.array([133.0255852472676, 101.61684197664563, 87.4134193236219])
            / 255.0)
        normStd = tuple(
            np.array([71.91047346327116, 62.94368776888253, 61.56865329427311])
            / 255.0)

    normTransform = transforms.Normalize(normMean, normStd)

    transform = transforms.Compose([
        ImageTransform(transforms.ToPILImage()),
        ImageTransform(transforms.ToTensor()),
        ImageTransform(normTransform)
    ])

    def evaluate_dataset(split_data):
        # Wrap one split in a non-shuffling loader and evaluate the network on it.
        loader = DataLoader(dataset=split_data,
                            shuffle=False,
                            num_workers=num_workers,
                            pin_memory=pin_memory,
                            batch_size=batch_size)
        return evaluate_split(net, loader, location=location, n_lm=n_lm)

    with h5py.File(src_300w, 'r') as f:
        print("Run on easy")
        easy_results = evaluate_dataset(
            FaceLandmarksEasyTestData(f, transform=transform, n_lm=n_lm))

        print("Run on hard")
        hard_results = evaluate_dataset(
            FaceLandmarksHardTestData(f, transform=transform, n_lm=n_lm))

        print("Run on train")
        train_results = evaluate_dataset(
            FaceLandmarksTrainingData(f, transform=transform, n_lm=n_lm))

    with h5py.File(src_menpo, "r") as f:
        print("Run on menpo")
        menpo_results = evaluate_dataset(
            Menpo(f, transform=transform, n_lm=n_lm))

    res = {
        "easy": easy_results,
        "hard": hard_results,
        "train": train_results,
        "menpo": menpo_results,
        "model_src": model,
        "config": config
    }

    if target is not None:
        # Context manager so the file is flushed and closed deterministically.
        with open(target, "w") as out_file:
            json.dump(res, out_file)
    else:
        return res
Пример #7
0
    def run(self):
        """Train one hourglass configuration and record its results.

        What it does:
        - Builds train/validation/easy/hard loaders from ``self.data``.
        - Trains ``self.create_net(config)`` with Adam and ReduceLROnPlateau on
          the validation loss.
        - Benchmarks on the easy and hard splits after every epoch.
        - Saves a checkpoint whenever an error category reaches a new best
          that is also below a fixed storage threshold.
        - Stops early via ``EarlyStopping``.
        - Finally writes a loss plot and ``<config_id>_result.json``.

        Returns:
            dict: run summary (config, duration, last epoch, best epoch and
            best error per category, last training loss, ...).

        Raises:
            LossException: if the training loss becomes NaN or infinite.
        """
        torch.cuda.empty_cache()

        starttime = time.time()

        if self.gpu_id is not None:
            # cudnn.benchmark improves training speed when input sizes do not change
            # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
            # It selects the best algorithms as the training iterates over the dataset
            # I found no big difference between True and False, but it also doesn't hurt, so enable it
            #cudnn.benchmark = True # disable for deterministic behavior
            pass

        config = self.config
        config_id = config["config_id"]
        n_lm = config["n_lm"]

        make_deterministic(config['random_seed'])
        torch.autograd.set_detect_anomaly(
            True)  # This makes debugging much easier

        jitterTransform = transforms.ColorJitter(brightness=0.4,
                                                 contrast=0.4,
                                                 saturation=0.4,
                                                 hue=0.1)

        # TODO store these values in h5 files
        normMean, normStd = FaceLandmarksTrainingData.TRAIN_MEAN, FaceLandmarksTrainingData.TRAIN_STD
        normTransform = transforms.Normalize(normMean, normStd)

        rot_angle = float(config['augment_rotation'])
        rotation_augmentation = RandomRotation(min_angle=-1 * rot_angle,
                                               max_angle=rot_angle,
                                               retain_scale=False,
                                               rotate_landmarks="same")

        trainTransform = transforms.Compose([
            ImageTransform(transforms.ToPILImage()),
            ImageTransform(jitterTransform),
            ImageAndLabelTransform(RandomHorizontalFlip()),
            ImageAndLabelTransform(rotation_augmentation),
            ImageTransform(transforms.ToTensor()),
            ImageTransform(normTransform)
        ])

        testTransform = transforms.Compose([
            ImageTransform(transforms.ToPILImage()),
            ImageTransform(transforms.ToTensor()),
            ImageTransform(normTransform)
        ])

        # Note: Reading takes only ~0.2s, so it is okay to do this again whenever main.py is called
        # No need to read in trainer.py and pass results here
        with h5py.File(self.data, 'r') as f:
            train_dataset = FaceLandmarksTrainingData(f,
                                                      transform=trainTransform,
                                                      n_lm=n_lm)
            val_dataset = FaceLandmarksAllTestData(f,
                                                   transform=testTransform,
                                                   n_lm=n_lm)
            easy_d = FaceLandmarksEasyTestData(f,
                                               transform=testTransform,
                                               n_lm=n_lm)
            hard_d = FaceLandmarksHardTestData(f,
                                               transform=testTransform,
                                               n_lm=n_lm)

        # %s instead of %d: gpu_id may be None (CPU run, see pin_memory below),
        # and "%d" % None raises TypeError.
        print("GPU %s.%s" % (self.gpu_id, self.sub_gpu_id),
              "Data: %s" % self.data,
              "Train %d Test %d" % (len(train_dataset), len(val_dataset)))

        dataloader_params = {
            'batch_size': config['batch_size'],
            'pin_memory': self.gpu_id is not None,
            'num_workers': 8
        }

        train_loader = DataLoader(train_dataset,
                                  shuffle=True,
                                  **dataloader_params)
        val_loader = DataLoader(val_dataset,
                                shuffle=False,
                                **dataloader_params)
        easy = DataLoader(easy_d, shuffle=False, **dataloader_params)
        hard = DataLoader(hard_d, shuffle=False, **dataloader_params)

        net = self.create_net(config)
        _, trainable_parameters, _ = count_parameters(net)
        self.to_gpu(net)
        net.train()  # Put net into train mode

        params = [
            {
                "params": net.hourglass.parameters()
            },
            {
                "params": net.regressor.parameters()
            },
        ]

        predict_distances = config["predict_distances_weight"] > 0

        def pairwise_offsets(points):
            # out[b, i, j, k] = points[b, j, k] - points[b, i, k] for k in {0, 1}.
            # Computed per batch on the tensor's own device. The target then
            # matches the batch's randomly augmented landmarks, and no CPU
            # tensor is indexed with device indices.
            points = points[..., :2]
            return points.unsqueeze(1) - points.unsqueeze(2)

        optimizer = optim.Adam(params, lr=config['lr'])

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            'min',
            patience=config['lr_scheduler_patience'],
            verbose=True,
            factor=config['lr_decay_factor'])

        early_stopping_patience = config['lr_scheduler_patience'] * 2 + 1
        early_stopping_max_ratio = 0.975
        should_stop = EarlyStopping(patience=early_stopping_patience,
                                    max_ratio=early_stopping_max_ratio,
                                    verbose=False)

        loss_function = self.get_loss_function(config['regression'],
                                               config['loss_function'])

        # Each category maps the benchmark metrics to a single error value.
        category_calculator = {
            "e49":
            lambda metrics: metrics["e49"],
            "h49":
            lambda metrics: metrics["h49"],
            "e68":
            lambda metrics: metrics["e68"],
            "h68":
            lambda metrics: metrics["h68"],
            "49":
            lambda metrics: (metrics["e49"] + metrics["h49"]) / 2,
            "68":
            lambda metrics: (metrics["e68"] + metrics["h68"]) / 2,
            "e":
            lambda metrics: (metrics["e49"] + metrics["e68"]) / 2,
            "h":
            lambda metrics: (metrics["h49"] + metrics["h68"]) / 2,
            "all":
            lambda metrics: (metrics["e49"] + metrics["h49"] + metrics["e68"] +
                             metrics["h68"]) / 4
        }
        categories = category_calculator.keys()
        best_epoch = {k: 0 for k in categories}
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        lowest_error = {k: np.inf for k in categories}
        epoch_train_losses = []
        epoch_val_losses = []

        # Only store models that are better than these values to save storage
        storage_thresholds = {"e49": 2.1, "h49": 3.4, "e68": 2.7, "h68": 4.5}
        for derived in ("49", "68", "e", "h", "all"):
            storage_thresholds[derived] = category_calculator[derived](
                storage_thresholds)

        loss_history = {}
        metric_history = []

        dist_loss_fct = nn.L1Loss()

        epochs = config['n_epoch']
        for epoch in range(epochs):
            epoch_start_time = time.time()

            net.train()
            epoch_train_loss = 0
            epoch_sample_count = 0

            for sample in train_loader:
                x = self.to_gpu(sample['image'].float())
                y = self.to_gpu(sample['landmarks'].float())
                epoch_sample_count += x.shape[0]

                optimizer.zero_grad()

                coords, heatmaps, var, unnormalized_heatmaps = net(x)

                loss = loss_function(coords, heatmaps, y)
                epoch_train_loss += loss.float().data.item()
                if config["normalize_loss"]:
                    if loss.detach().data.item() > 0:
                        loss = loss / loss.detach()

                if predict_distances:
                    dist_y = pairwise_offsets(y)
                    distance_pred = pairwise_offsets(coords)
                    dist_loss = dist_loss_fct(distance_pred, dist_y)
                    loss = loss + config[
                        "predict_distances_weight"] * dist_loss / dist_loss.detach(
                        )
                else:
                    dist_loss = 0

                if torch.isnan(loss):
                    print_info(
                        "ERROR! Invalid loss (nan). Aborting training for config %d in epoch %d"
                        % (config_id, epoch))
                    raise LossException("loss was nan in config %d, epoch %d" %
                                        (config_id, epoch))
                if torch.isinf(loss):
                    print_info(
                        "ERROR! Invalid loss (inf). Aborting training for config %d in epoch %d"
                        % (config_id, epoch))
                    raise LossException("loss was inf in config %d, epoch %d" %
                                        (config_id, epoch))

                loss.backward()
                optimizer.step()

                #### end batch

            epoch_train_loss /= epoch_sample_count  # normalize loss by images that were processed

            val_loss = self.evaluate_model(val_loader, net, loss_function)
            scheduler.step(val_loss)

            epoch_train_losses.append(epoch_train_loss)
            epoch_val_losses.append(val_loss)
            loss_history[epoch] = {
                'train': epoch_train_losses[-1],
                'val': epoch_val_losses[-1]
            }

            epoch_end_time = time.time()
            epoch_duration = epoch_end_time - epoch_start_time

            metrics = benchmark(net, easy, hard, self.gpu_id)
            all_metrics = {}
            for category, calculator in category_calculator.items():
                error = calculator(metrics)
                all_metrics[category] = error

                if error < lowest_error[
                        category] and error < 1000:  # 100000 is the error for with outline when HG only has 49LM
                    lowest_error[category] = error
                    best_epoch[category] = epoch

                    if error < storage_thresholds[category]:
                        torch.save(
                            {
                                'model': 'pe_hourglass',
                                'epoch': epoch + 1,
                                'state_dict': net.state_dict(),
                                'val_loss': val_loss,
                                'config': config,
                                'category': category,
                                'metrics': all_metrics
                            },
                            os.path.join(
                                self.model_dir,
                                "%d_best_%s.torch" % (config_id, category)))
            metric_history.append(all_metrics)

            print(
                "GPU %s.%s" % (self.gpu_id, self.sub_gpu_id),
                "| conf",
                config_id,
                '| %03d/%03d' % (epoch + 1, epochs),
                '| %ds' % (int(epoch_duration)),
                '| train %0.6f' % epoch_train_losses[-1],
                '| val %0.6f' % epoch_val_losses[-1],
                '| dist %0.6f' % float(dist_loss),
                '| e68 %0.2f [B %0.2f]' %
                (metrics["e68"], lowest_error['e68']),
                '| h68 %0.2f [B %0.2f]' %
                (metrics["h68"], lowest_error['h68']),
                '| e49 %0.2f [B %0.2f]' %
                (metrics["e49"], lowest_error['e49']),
                '| h49 %0.2f [B %0.2f]' %
                (metrics["h49"], lowest_error['h49']),
            )

            if should_stop(val_loss):
                # Record the number of epochs actually run; used for the plot and results.
                epochs = epoch + 1
                print_info(
                    "EarlyStopping (patience = %d, max_ratio=%f) criterion returned true in epoch %d. Stop training"
                    % (should_stop.patience, should_stop.max_ratio, epochs))
                break

        endtime = time.time()

        # Write a loss plot to CONFIG_ID_loss_plot.png in the plot directory.
        # Use a dedicated figure so no previously open pyplot figure is drawn on.
        # TODO tensorboardX in addition to matplotlib?
        x = np.array(range(epochs))
        fig, ax = plt.subplots()
        ax.plot(x, np.array(epoch_train_losses), 'r', label='Train Loss')
        ax.plot(x, np.array(epoch_val_losses), 'b', label='Val Loss')
        ax.set_xlabel("Epochs")
        ax.set_ylabel("Avg. Train and Val Loss")
        ax.set_title("Variation of train and Val loss with epochs")
        ax.legend(loc='best')
        fig.savefig(os.path.join(self.plot_dir,
                                 "%d_loss_plot.png" % config_id))
        plt.close(fig)

        training_duration = int(endtime - starttime)

        best_epochs = {"best_%s_epoch" % k: v for k, v in best_epoch.items()}
        best_errors = {"best_%s" % k: v for k, v in lowest_error.items()}

        results = {
            "config_id": config_id,
            'dataset': self.data,
            "gpu_id": self.gpu_id,
            "duration_seconds": training_duration,
            "last_epoch":
            epochs,  # is different from n_epoch in case of early stopping
            "trainable_parameters": trainable_parameters,
            **self.config,
            "optimizer_name": optimizer.__class__.__name__,
            **best_epochs,
            "training_loss_last_epoch": epoch_train_losses[-1],
            **best_errors
        }

        # Write results to CONFIG_ID_result.json in the output directory
        with open(os.path.join(self.result_dir, "%d_result.json" % config_id),
                  "w") as f:
            to_write = {
                **results, 'loss_history': loss_history,
                'metric_history': metric_history
            }
            json.dump(to_write, f, indent=4)

        torch.cuda.empty_cache()

        return results