def __init__(self, *args, **kwargs):
        """Build the AffNet model and keep one UMD data loader on the instance.

        All positional and keyword arguments are forwarded unchanged to the
        parent initializer.
        """
        super().__init__(*args, **kwargs)

        # Instantiate the network, then move it to the configured device.
        network = affnet.ResNetAffNet(
            pretrained=config.IS_PRETRAINED,
            num_classes=config.NUM_CLASSES,
        )
        self.model = network
        self.model.to(config.DEVICE)

        # The loader factory returns three loaders; only the third one (the
        # test split, going by its position in the returned triple) is kept.
        _train_split, _val_split, held_out_loader = (
            umd_dataset_loaders.load_umd_train_datasets())
        self.data_loader = held_out_loader
def main():
    """Train AffNet on the UMD dataset, checkpointing after every epoch.

    Seeds the Python, NumPy and PyTorch RNGs from ``config.RANDOM_SEED``,
    builds the model, the UMD data loaders and an SGD optimizer driven by a
    ``MultiStepLR`` schedule, then runs ``config.NUM_EPOCHS`` epochs of
    training followed by validation. A ``SummaryWriter`` rooted at
    ``config.TRAINED_MODELS_DIR`` is handed to the train/val helpers and is
    closed when the run ends, whether it finishes or raises.
    """
    # Init random seeds.
    random.seed(config.RANDOM_SEED)
    np.random.seed(config.RANDOM_SEED)
    torch.manual_seed(config.RANDOM_SEED)
    torch.cuda.manual_seed(config.RANDOM_SEED)

    # Setup Tensorboard.
    print('\nsaving run in .. {}'.format(config.TRAINED_MODELS_DIR))
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(config.TRAINED_MODELS_DIR, exist_ok=True)
    writer = SummaryWriter(f'{config.TRAINED_MODELS_DIR}')

    try:
        # Load the Model.
        print()
        model = affnet.ResNetAffNet(pretrained=config.IS_PRETRAINED,
                                    num_classes=config.NUM_CLASSES)
        model.to(config.DEVICE)

        # Load the dataset. test_loader is only needed by the disabled
        # per-epoch evaluation further down.
        train_loader, val_loader, test_loader = (
            umd_dataset_loaders.load_umd_train_datasets())

        # Construct the optimizer over trainable parameters only.
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params,
                                    lr=config.LEARNING_RATE,
                                    weight_decay=config.WEIGHT_DECAY,
                                    momentum=config.MOMENTUM)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config.MILESTONES, gamma=config.GAMMA)

        # Main training loop.
        num_epochs = config.NUM_EPOCHS
        # Only referenced by the disabled evaluation block below.
        best_Fwb = -np.inf

        for epoch in range(0, num_epochs):
            print()

            # Sub-sample the data until the configured epoch is reached.
            is_subsample = epoch < config.EPOCH_TO_TRAIN_FULL_DATASET

            # train & val for one epoch
            model, optimizer = train_utils.train_one_epoch(
                model,
                optimizer,
                train_loader,
                config.DEVICE,
                epoch,
                writer,
                is_subsample=is_subsample)
            model, optimizer = train_utils.val_one_epoch(
                model,
                optimizer,
                val_loader,
                config.DEVICE,
                epoch,
                writer,
                is_subsample=is_subsample)
            # update learning rate.
            lr_scheduler.step()

            # # eval Fwb
            # model, Fwb = eval_utils.affnet_eval_umd(model, test_loader)
            # writer.add_scalar('eval/Fwb', Fwb, int(epoch))
            # # save best model.
            # if Fwb > best_Fwb:
            #     best_Fwb = Fwb
            #     writer.add_scalar('eval/Best_Fwb', best_Fwb, int(epoch))
            #     checkpoint_path = config.BEST_MODEL_SAVE_PATH
            #     train_utils.save_checkpoint(model, optimizer, epoch, checkpoint_path)
            #     print("Saving best model .. best Fwb={:.5f} ..".format(best_Fwb))

            # Per-epoch checkpoint. The np.str alias was removed in
            # NumPy 1.24, so the epoch is formatted with plain Python.
            checkpoint_path = (
                f'{config.MODEL_SAVE_PATH}affnet_epoch_{epoch}.pth')
            train_utils.save_checkpoint(model, optimizer, epoch,
                                        checkpoint_path)
    finally:
        # Flush pending events and release the event file.
        writer.close()
def main():
    """Run AffNet on the UMD evaluation set and report per-object statistics.

    For every batch from the evaluation loader the restored model is run
    without gradients, the raw outputs are post-processed by the
    ``eval_utils`` helpers, and bounding boxes plus affordance masks are
    rendered. Element 0 of each batch is formatted and the last model output
    is popped, so a batch size of 1 is presumably assumed -- TODO confirm.

    When ``SAVE_AND_EVAL_PRED`` is set, the output folder is emptied first,
    ground-truth and predicted affordance masks are written to
    ``config.UMD_AFF_EVAL_SAVE_FOLDER``, and the MATLAB function
    ``evaluate_umd_affnet`` is run on that folder at the end. When
    ``SHOW_IMAGES`` is set, the rendered images are displayed and the loop
    waits for a key press after each one.
    """
    if SAVE_AND_EVAL_PRED:
        # Init folders
        print('\neval in .. {}'.format(config.UMD_AFF_EVAL_SAVE_FOLDER))
        os.makedirs(config.UMD_AFF_EVAL_SAVE_FOLDER, exist_ok=True)

        # Remove whatever a previous run left in the output folder.
        for stale_file in glob.glob(config.UMD_AFF_EVAL_SAVE_FOLDER + '*'):
            os.remove(stale_file)

    # Load the Model.
    print()
    model = affnet.ResNetAffNet(pretrained=config.IS_PRETRAINED,
                                num_classes=config.NUM_CLASSES)
    model.to(config.DEVICE)

    # Load saved weights.
    print(
        f"\nrestoring pre-trained AffNet weights for UMD: {config.RESTORE_UMD_AFFNET_WEIGHTS} .. "
    )
    checkpoint = torch.load(config.RESTORE_UMD_AFFNET_WEIGHTS,
                            map_location=config.DEVICE)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    # Load the dataset.
    test_loader = umd_dataset_loaders.load_umd_eval_datasets(
        random_images=RANDOM_IMAGES,
        num_random=NUM_RANDOM,
        shuffle_images=SHUFFLE_IMAGES)

    # run the predictions.
    for image_idx, (images, targets) in enumerate(test_loader):
        print(f'\nImage:{image_idx+1}/{len(test_loader)}')

        # Keep untouched copies for visualisation; the originals are moved to
        # the inference device.
        image, target = copy.deepcopy(images), copy.deepcopy(targets)
        images = [img.to(config.DEVICE) for img in images]

        with torch.no_grad():
            outputs = model(images)
            outputs = [{k: v.to(config.CPU_DEVICE)
                        for k, v in t.items()} for t in outputs]

        # Formatting input: channel-first tensor -> channel-last uint8 array.
        image = image[0]
        image = image.to(config.CPU_DEVICE)
        image = np.squeeze(np.array(image)).transpose(1, 2, 0)
        image = np.array(image * (2**8 - 1), dtype=np.uint8)
        H, W = image.shape[:2]

        # Formatting targets.
        target = target[0]
        target = {k: v.to(config.CPU_DEVICE) for k, v in target.items()}
        target = umd_dataset_utils.format_target_data(image.copy(),
                                                      target.copy())

        # Formatting Output.
        outputs = outputs.pop()
        outputs = eval_utils.affnet_umd_format_outputs(image.copy(),
                                                       outputs.copy())
        outputs = eval_utils.affnet_umd_threshold_outputs(
            image.copy(), outputs.copy())
        outputs, obj_mask_probabilities = eval_utils.affnet_umd_threshold_binary_masks(
            image.copy(), outputs.copy(), False)
        scores = np.array(outputs['scores'], dtype=np.float32).flatten()
        obj_ids = np.array(outputs['obj_ids'], dtype=np.int32).flatten()
        obj_boxes = np.array(outputs['obj_boxes'],
                             dtype=np.int32).reshape(-1, 4)
        aff_ids = np.array(outputs['aff_ids'], dtype=np.int32).flatten()
        aff_binary_masks = np.array(outputs['aff_binary_masks'],
                                    dtype=np.uint8).reshape(-1, H, W)

        for idx, obj_id in enumerate(obj_ids):
            # Record this image's mask probability under each detected
            # object id in the module-level table.
            OBJ_MASK_PROBABILITIES[image_idx, obj_id] = obj_mask_probabilities
            obj_name = "{:<13}".format(
                umd_dataset_utils.map_obj_id_to_name(obj_id))
            print(
                f'Object:{obj_name}'
                f'Obj id: {obj_id}, '
                f'Score:{scores[idx]:.3f}, ', )

        # Pred bbox.
        print('1')
        pred_bbox_img = umd_dataset_utils.draw_bbox_on_img(image=image,
                                                           scores=scores,
                                                           obj_ids=obj_ids,
                                                           boxes=obj_boxes)

        # Pred affordance mask.
        print('2')
        pred_aff_mask = umd_dataset_utils.get_segmentation_masks(
            image=image,
            obj_ids=aff_ids,
            binary_masks=aff_binary_masks,
        )
        color_aff_mask = umd_dataset_utils.colorize_aff_mask(pred_aff_mask)
        color_aff_mask = cv2.addWeighted(pred_bbox_img, 0.5, color_aff_mask,
                                         0.5, 0)

        # gt affordance masks.
        print('3')
        binary_mask = umd_dataset_utils.get_segmentation_masks(
            image=image,
            obj_ids=target['aff_ids'],
            binary_masks=target['aff_binary_masks'],
        )
        color_binary_mask = umd_dataset_utils.colorize_aff_mask(binary_mask)
        color_binary_mask = cv2.addWeighted(image, 0.5, color_binary_mask, 0.5,
                                            0)

        if SAVE_AND_EVAL_PRED:
            # saving predictions.
            _image_idx = target["image_id"].detach().numpy()[0]
            # Zero-pad to six digits: 1000000 + 42 -> '1000042' -> '000042'.
            _image_idx = str(1000000 + _image_idx)[1:]

            gt_name = config.UMD_AFF_EVAL_SAVE_FOLDER + _image_idx + config.TEST_GT_EXT
            pred_name = config.UMD_AFF_EVAL_SAVE_FOLDER + _image_idx + config.TEST_PRED_EXT

            cv2.imwrite(gt_name, target['aff_mask'])
            cv2.imwrite(pred_name, pred_aff_mask)

        # show plot.
        if SHOW_IMAGES:
            cv2.imshow('pred_bbox',
                       cv2.cvtColor(pred_bbox_img, cv2.COLOR_BGR2RGB))
            cv2.imshow('pred_aff_mask',
                       cv2.cvtColor(color_aff_mask, cv2.COLOR_BGR2RGB))
            cv2.imshow('gt_aff_mask',
                       cv2.cvtColor(color_binary_mask, cv2.COLOR_BGR2RGB))
            cv2.waitKey(0)

    print()
    for obj_id in range(1, config.UMD_NUM_OBJECT_CLASSES):
        obj_mask_probabilities = OBJ_MASK_PROBABILITIES[:, obj_id]

        # Average only over entries that were actually written (non-zero).
        # An object class that never appeared has no such entries; report
        # NaN directly instead of letting NumPy warn about an empty mean.
        recorded = obj_mask_probabilities[np.nonzero(obj_mask_probabilities)]
        mean = recorded.mean() if recorded.size else float('nan')

        obj_name = "{:<13}".format(
            umd_dataset_utils.map_obj_id_to_name(obj_id))
        print(f'Object:{obj_name}'
              f'Obj id: {obj_id}, '
              f'Mean: {mean: .5f}, ')

    if SAVE_AND_EVAL_PRED:
        print()
        # getting FwB.
        os.chdir(config.MATLAB_SCRIPTS_DIR)
        try:
            import matlab.engine
            eng = matlab.engine.start_matlab()
            try:
                # The returned value is not used further in this script.
                eng.evaluate_umd_affnet(config.UMD_AFF_EVAL_SAVE_FOLDER,
                                        nargout=1)
            finally:
                eng.quit()
        finally:
            # Restore the working directory even if MATLAB fails.
            os.chdir(config.ROOT_DIR_PATH)
def main():
    """Run AffNet on the ARL AffPose evaluation set and report AP metrics.

    For every batch from the evaluation loader the restored model is run
    without gradients, outputs are post-processed and matched to the ground
    truth by the ``eval_utils`` helpers, and an average precision value is
    computed per image. Ground-truth and predicted object ids are accumulated
    into a confusion matrix that is printed at the end together with the mean
    of the per-image AP values. Element 0 of each batch is formatted and the
    last model output is popped, so a batch size of 1 is presumably assumed
    -- TODO confirm.

    When ``SAVE_AND_EVAL_PRED`` is set, ground-truth, predicted affordance
    and predicted object-part masks are written to
    ``config.ARL_AFF_EVAL_SAVE_FOLDER`` and the MATLAB function
    ``evaluate_arl_affpose_affnet`` is run on that folder at the end. When
    ``SHOW_IMAGES`` is set, the rendered images are displayed and the loop
    waits for a key press after each one.
    """
    if SAVE_AND_EVAL_PRED:
        # cv2.imwrite() signals a failed write only through its return value,
        # so make sure the output folder exists before the loop starts.
        # Files from earlier runs are left in place.
        os.makedirs(config.ARL_AFF_EVAL_SAVE_FOLDER, exist_ok=True)

    # Load the Model.
    print()
    model = affnet.ResNetAffNet(pretrained=config.IS_PRETRAINED,
                                num_classes=config.NUM_CLASSES)
    model.to(config.DEVICE)

    # Load saved weights.
    print(
        f"\nrestoring pre-trained AffNet weights: {config.RESTORE_ARL_AFFNET_WEIGHTS} .. "
    )
    checkpoint = torch.load(config.RESTORE_ARL_AFFNET_WEIGHTS,
                            map_location=config.DEVICE)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    # Load the dataset.
    test_loader = arl_affpose_dataset_loaders.load_arl_affpose_eval_datasets(
        random_images=RANDOM_IMAGES,
        num_random=NUM_RANDOM,
        shuffle_images=SHUFFLE_IMAGES)

    # run the predictions.
    APs = []
    gt_obj_ids_list, pred_obj_ids_list = [], []
    for image_idx, (images, targets) in enumerate(test_loader):
        print(f'\nImage:{image_idx+1}/{len(test_loader)}')

        # Keep untouched copies for visualisation; the originals are moved to
        # the inference device.
        image, target = copy.deepcopy(images), copy.deepcopy(targets)
        images = [img.to(config.DEVICE) for img in images]

        with torch.no_grad():
            outputs = model(images)
            outputs = [{k: v.to(config.CPU_DEVICE)
                        for k, v in t.items()} for t in outputs]

        # Formatting input: channel-first tensor -> channel-last uint8 array.
        image = image[0]
        image = image.to(config.CPU_DEVICE)
        image = np.squeeze(np.array(image)).transpose(1, 2, 0)
        image = np.array(image * (2**8 - 1), dtype=np.uint8)
        H, W = image.shape[:2]

        # Formatting targets.
        target = target[0]
        target = {k: v.to(config.CPU_DEVICE) for k, v in target.items()}
        target = arl_affpose_dataset_utils.format_target_data(
            image.copy(), target.copy())
        gt_obj_ids = np.array(target['obj_ids'], dtype=np.int32).flatten()
        gt_obj_boxes = np.array(target['obj_boxes'],
                                dtype=np.int32).reshape(-1, 4)
        gt_obj_binary_masks = np.array(target['obj_binary_masks'],
                                       dtype=np.uint8).reshape(-1, H, W)

        # Formatting Output.
        outputs = outputs.pop()
        outputs = eval_utils.affnet_format_outputs(image.copy(),
                                                   outputs.copy())
        outputs = eval_utils.affnet_threshold_outputs(image.copy(),
                                                      outputs.copy())
        outputs = eval_utils.maskrcnn_match_pred_to_gt(image.copy(),
                                                       target.copy(),
                                                       outputs.copy())
        scores = np.array(outputs['scores'], dtype=np.float32).flatten()
        obj_ids = np.array(outputs['obj_ids'], dtype=np.int32).flatten()
        obj_boxes = np.array(outputs['obj_boxes'],
                             dtype=np.int32).reshape(-1, 4)
        obj_binary_masks = np.array(outputs['obj_binary_masks'],
                                    dtype=np.uint8).reshape(-1, H, W)
        obj_part_ids = np.array(outputs['obj_part_ids'],
                                dtype=np.int32).flatten()
        aff_ids = np.array(outputs['aff_ids'], dtype=np.int32).flatten()
        aff_binary_masks = np.array(outputs['aff_binary_masks'],
                                    dtype=np.uint8).reshape(-1, H, W)

        # confusion matrix.
        gt_obj_ids_list.extend(gt_obj_ids.tolist())
        pred_obj_ids_list.extend(obj_ids.tolist())

        # for obj_binary_mask, gt_obj_binary_mask in zip(obj_binary_masks, gt_obj_binary_masks):
        #     cv2.imshow('obj_binary_mask', obj_binary_mask*20)
        #     cv2.imshow('gt_obj_binary_mask', gt_obj_binary_mask*20)
        #     cv2.waitKey(0)

        # get average precision.
        # NOTE(review): the masks are laid out as (N, H, W) above, and
        # reshape(H, W, -1) reinterprets the buffer rather than moving the
        # instance axis, so the result is a true (H, W, N) stack only when
        # N == 1. If compute_ap_range expects (H, W, N) masks this should be
        # .transpose(1, 2, 0) -- confirm against eval_utils before changing.
        AP = eval_utils.compute_ap_range(
            gt_class_id=gt_obj_ids,
            gt_box=gt_obj_boxes,
            gt_mask=gt_obj_binary_masks.reshape(H, W, -1),
            pred_score=scores,
            pred_class_id=obj_ids,
            pred_box=obj_boxes,
            pred_mask=obj_binary_masks.reshape(H, W, -1),
            verbose=False,
        )
        APs.append(AP)

        # print outputs, pairing ground truth and prediction by position
        # (stops at the shorter of the two).
        for idx, (gt_obj_id, pred_obj_id) in enumerate(
                zip(gt_obj_ids, obj_ids)):
            gt_obj_name = "{:<15}".format(
                arl_affpose_dataset_utils.map_obj_id_to_name(gt_obj_id))
            pred_obj_name = "{:<15}".format(
                arl_affpose_dataset_utils.map_obj_id_to_name(pred_obj_id))
            score = scores[idx]
            bbox_iou = eval_utils.get_iou(obj_boxes[idx], gt_obj_boxes[idx])

            print(
                f'GT: {gt_obj_name}',
                f'Pred: {pred_obj_name}'
                f'Score: {score:.3f},\t\t',
                f'IoU: {bbox_iou:.3f},',
            )
        print("AP @0.5-0.95: {:.5f}".format(AP))

        # visualize bbox.
        pred_bbox_img = arl_affpose_dataset_utils.draw_bbox_on_img(
            image=image, scores=scores, obj_ids=obj_ids, boxes=obj_boxes)

        # visualize affordance masks.
        pred_aff_mask = arl_affpose_dataset_utils.get_segmentation_masks(
            image=image,
            obj_ids=aff_ids,
            binary_masks=aff_binary_masks,
        )
        color_aff_mask = arl_affpose_dataset_utils.colorize_aff_mask(
            pred_aff_mask)
        color_aff_mask = cv2.addWeighted(pred_bbox_img, 0.5, color_aff_mask,
                                         0.5, 0)

        # get obj part mask.
        pred_obj_part_mask = arl_affpose_dataset_utils.get_obj_part_mask(
            image=image,
            obj_part_ids=obj_part_ids,
            aff_binary_masks=aff_binary_masks,
        )
        # visualize object masks.
        pred_obj_mask = arl_affpose_dataset_utils.convert_obj_part_mask_to_obj_mask(
            pred_obj_part_mask)
        color_obj_mask = arl_affpose_dataset_utils.colorize_obj_mask(
            pred_obj_mask)
        color_obj_mask = cv2.addWeighted(pred_bbox_img, 0.5, color_obj_mask,
                                         0.5, 0)

        if SAVE_AND_EVAL_PRED:
            # saving predictions.
            _image_idx = target["image_id"].detach().numpy()[0]
            # Zero-pad to six digits: 1000000 + 42 -> '1000042' -> '000042'.
            _image_idx = str(1000000 + _image_idx)[1:]

            gt_name = config.ARL_AFF_EVAL_SAVE_FOLDER + _image_idx + config.TEST_GT_EXT
            pred_name = config.ARL_AFF_EVAL_SAVE_FOLDER + _image_idx + config.TEST_PRED_EXT
            obj_part_name = config.ARL_AFF_EVAL_SAVE_FOLDER + _image_idx + config.TEST_OBJ_PART_EXT

            cv2.imwrite(gt_name, target['aff_mask'])
            cv2.imwrite(pred_name, pred_aff_mask)
            cv2.imwrite(obj_part_name, pred_obj_part_mask)

        # show plot.
        if SHOW_IMAGES:
            cv2.imshow('pred_bbox',
                       cv2.cvtColor(pred_bbox_img, cv2.COLOR_BGR2RGB))
            cv2.imshow('pred_aff_mask',
                       cv2.cvtColor(color_aff_mask, cv2.COLOR_BGR2RGB))
            cv2.imshow('pred_obj_part_mask',
                       cv2.cvtColor(color_obj_mask, cv2.COLOR_BGR2RGB))
            cv2.waitKey(0)

    # Confusion Matrix.
    cm = sklearn_confusion_matrix(y_true=gt_obj_ids_list,
                                  y_pred=pred_obj_ids_list)
    print(f'\n{cm}')

    # Plot Confusion Matrix.
    # eval_utils.plot_confusion_matrix(cm, arl_affpose_dataset_utils.OBJ_NAMES)

    # mAP. With no evaluated images, report NaN directly rather than letting
    # NumPy warn about the mean of an empty list.
    mean_ap = np.mean(APs) if APs else float('nan')
    print("\nmAP @0.5-0.95: over {} test images is {:.3f}".format(
        len(APs), mean_ap))

    if SAVE_AND_EVAL_PRED:
        print()
        # getting FwB.
        os.chdir(config.MATLAB_SCRIPTS_DIR)
        try:
            import matlab.engine
            eng = matlab.engine.start_matlab()
            try:
                # The returned value is not used further in this script.
                eng.evaluate_arl_affpose_affnet(
                    config.ARL_AFF_EVAL_SAVE_FOLDER, nargout=1)
            finally:
                eng.quit()
        finally:
            # Restore the working directory even if MATLAB fails.
            os.chdir(config.ROOT_DIR_PATH)
def main():
    """Fine-tune AffNet on ARL AffPose, tracking the best Fwb and mAP.

    Seeds the Python, NumPy and PyTorch RNGs from ``config.RANDOM_SEED``,
    builds the model and restores the weights stored at
    ``config.RESTORE_SYN_ARL_AFFNET_WEIGHTS``, then runs
    ``config.NUM_EPOCHS`` epochs of training, validation and evaluation.
    After each epoch the Fwb and mAP values returned by
    ``eval_utils.affnet_eval_arl_affpose`` are logged; a new best Fwb also
    saves a checkpoint to ``config.BEST_MODEL_SAVE_PATH``. A per-epoch
    checkpoint is always written under ``config.MODEL_SAVE_PATH``. The
    ``SummaryWriter`` is closed when the run ends, whether it finishes or
    raises.
    """
    # Init random seeds.
    random.seed(config.RANDOM_SEED)
    np.random.seed(config.RANDOM_SEED)
    torch.manual_seed(config.RANDOM_SEED)
    torch.cuda.manual_seed(config.RANDOM_SEED)

    # Setup Tensorboard.
    print('\nsaving run in .. {}'.format(config.TRAINED_MODELS_DIR))
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(config.TRAINED_MODELS_DIR, exist_ok=True)
    writer = SummaryWriter(f'{config.TRAINED_MODELS_DIR}')

    try:
        # Load the Model.
        print()
        model = affnet.ResNetAffNet(pretrained=config.IS_PRETRAINED,
                                    num_classes=config.NUM_CLASSES)
        model.to(config.DEVICE)
        torch.cuda.empty_cache()

        # # TODO: Freeze the backbone.
        # model = model_utils.freeze_backbone(model, verbose=True)

        # TODO: Load saved weights.
        print(
            f"\nrestoring pre-trained AffNet weights: {config.RESTORE_SYN_ARL_AFFNET_WEIGHTS} .. "
        )
        checkpoint = torch.load(config.RESTORE_SYN_ARL_AFFNET_WEIGHTS,
                                map_location=config.DEVICE)
        # load_state_dict copies into the parameters already placed on
        # config.DEVICE above, so no second .to() is needed.
        model.load_state_dict(checkpoint["model"])

        # Load the dataset.
        train_loader, val_loader, test_loader = (
            arl_affpose_dataset_loaders.load_arl_affpose_train_datasets())

        # Construct an optimizer over trainable parameters only.
        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(params,
                                    lr=config.LEARNING_RATE,
                                    weight_decay=config.WEIGHT_DECAY,
                                    momentum=config.MOMENTUM)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=config.MILESTONES, gamma=config.GAMMA)
        # # TODO: Load saved weights.
        # optimizer.load_state_dict(checkpoint["optimizer"])

        # Main training loop.
        num_epochs = config.NUM_EPOCHS
        best_Fwb, best_mAP = -np.inf, -np.inf

        for epoch in range(0, num_epochs):
            print()

            # Sub-sample the data until the configured epoch is reached.
            is_subsample = epoch < config.EPOCH_TO_TRAIN_FULL_DATASET

            # train & val for one epoch
            model, optimizer = train_utils.train_one_epoch(
                model,
                optimizer,
                train_loader,
                config.DEVICE,
                epoch,
                writer,
                is_subsample=is_subsample)
            model, optimizer = train_utils.val_one_epoch(
                model,
                optimizer,
                val_loader,
                config.DEVICE,
                epoch,
                writer,
                is_subsample=is_subsample)
            # update learning rate.
            lr_scheduler.step()

            model, mAP, Fwb = eval_utils.affnet_eval_arl_affpose(
                model, test_loader)
            # eval FwB
            writer.add_scalar('eval/Fwb', Fwb, int(epoch))
            if Fwb > best_Fwb:
                best_Fwb = Fwb
                writer.add_scalar('eval/Best_Fwb', best_Fwb, int(epoch))
                checkpoint_path = config.BEST_MODEL_SAVE_PATH
                train_utils.save_checkpoint(model, optimizer, epoch,
                                            checkpoint_path)
                print("Saving best model .. best Fwb={:.5f} ..".format(
                    best_Fwb))
            # eval mAP (tracked and logged only; it does not trigger a save).
            writer.add_scalar('eval/mAP', mAP, int(epoch))
            if mAP > best_mAP:
                best_mAP = mAP
                writer.add_scalar('eval/Best_mAP', best_mAP, int(epoch))

            # Per-epoch checkpoint. The np.str alias was removed in
            # NumPy 1.24, so the epoch is formatted with plain Python.
            checkpoint_path = (
                f'{config.MODEL_SAVE_PATH}affnet_epoch_{epoch}.pth')
            train_utils.save_checkpoint(model, optimizer, epoch,
                                        checkpoint_path)
    finally:
        # Flush pending events and release the event file.
        writer.close()