def main():
    """ Command-line entry point running a hyperopt hyperparameter search for one of the two models.

    ``--model`` selects the model to tune ("detect": ball detector, "forecast": ball position forecasting).
    Each trial calls the selected module's ``train`` function with sampled hyperparameters and minimizes the
    returned validation loss. The best hyperparameters are printed, and all trials are pickled to
    ``../hp_trials_<model>.pkl`` relative to this source file's directory.
    """
    from hyperopt import tpe  # TPE search algorithm (hyperopt's `hp`, `fmin`, `Trials`, `space_eval` are used below)

    # Parse arguments
    _DETECT, _FORECAST = 'detect', 'forecast'
    parser = argparse.ArgumentParser(description='Train ball detector or ball position forecasting model(s).')
    parser.add_argument('--model', type=str, required=True, choices=[_DETECT, _FORECAST],
                        help=f'Determines model to train ("{_DETECT}" or "{_FORECAST}").')
    # NOTE(review): the original `max_evals` value passed to fmin was lost; 100 is a placeholder default
    parser.add_argument('--max-evals', type=int, default=100,
                        help='Number of hyperparameter trials to run (default: 100).')
    args = parser.parse_args()
    train_model = args.model
    TRIALS_FILEPATH = tu.source_dir(__file__) / f'../hp_trials_{train_model}.pkl'

    # Ball detector Conv2d backbone layers
    conv_backbone = (
        ('conv2d', {'out_channels': 4, 'kernel_size': (3, 3), 'padding': 0}),
        ('conv2d', {'out_channels': 4, 'kernel_size': (3, 3), 'padding': 0}),
        ('conv2d', {'out_channels': 4, 'kernel_size': (3, 3), 'padding': 0}),
        ('avg_pooling', {'kernel_size': (2, 2), 'stride': (2, 2)}),
        ('conv2d', {'out_channels': 16, 'kernel_size': (5, 5), 'padding': 0}),
        ('conv2d', {'out_channels': 16, 'kernel_size': (5, 5), 'padding': 0}),
        ('avg_pooling', {'kernel_size': (2, 2), 'stride': (2, 2)}),
        ('conv2d', {'out_channels': 32, 'kernel_size': (5, 5), 'padding': 2}),
        ('conv2d', {'out_channels': 32, 'kernel_size': (7, 7), 'padding': 3}),
        ('avg_pooling', {'kernel_size': (2, 2), 'stride': (2, 2)}),
        ('conv2d', {'out_channels': 64, 'kernel_size': (5, 5), 'padding': 2}),
        ('flatten', {}))

    # Define hyperparameter search space (second hp search space iteration) for ball detector (task 1)
    detect_hp_space = {
        'optimizer_params': {'lr': hp.uniform('lr', 1e-6, 1e-3), 'betas': (0.9, 0.999), 'eps': 1e-8,
                             'weight_decay': hp.loguniform('weight_decay', math.log(1e-7), math.log(3e-3)), 'amsgrad': False},
        'scheduler_params': {'step_size': 40, 'gamma': .3},
        # 'scheduler_params': {'max_lr': 1e-2, 'pct_start': 0.3, 'anneal_strategy': 'cos'},
        'batch_size': hp.choice('batch_size', [16, 32, 64]),
        'bce_loss_scale': 0.1,
        'early_stopping': 12,
        'epochs': 90,
        'architecture': {
            'act_fn': nn.ReLU,
            'batch_norm': {'eps': 1e-05, 'momentum': hp.uniform('momentum', 0.05, 0.15), 'affine': True},
            'dropout_prob': hp.choice('dropout_prob', [0., hp.uniform('nonzero_dropout_prob', 0.1, 0.45)]),
            'layers_param': hp.choice('layers_param', [(*conv_backbone, ('fully_connected', {'out_features': 64}),
                                                        ('fully_connected', {})),
                                                       (*conv_backbone, ('fully_connected', {'out_features': 64}),
                                                        ('fully_connected', {'out_features': 128}),
                                                        ('fully_connected', {})),
                                                       (*conv_backbone, ('fully_connected', {'out_features': 128}),
                                                        ('fully_connected', {'out_features': 128}),
                                                        ('fully_connected', {})),
                                                       (*conv_backbone, ('fully_connected', {}))])}}

    # Define hyperparameter search space for ball position forecasting (task 2)
    forecast_hp_space = {
        'optimizer_params': {'lr': hp.uniform('lr', 5e-6, 1e-4), 'betas': (0.9, 0.999), 'eps': 1e-8,
                             'weight_decay': hp.loguniform('weight_decay', math.log(1e-7), math.log(1e-2)), 'amsgrad': False},
        'scheduler_params': {'step_size': 30, 'gamma': .3},
        # 'scheduler_params': {'max_lr': 1e-2, 'pct_start': 0.3, 'anneal_strategy': 'cos'},
        'batch_size': hp.choice('batch_size', [16, 32, 64]),
        'early_stopping': 12,
        'epochs': 90,
        'architecture': {
            'act_fn': nn.Tanh,
            'dropout_prob': hp.choice('dropout_prob', [0., hp.uniform('nonzero_dropout_prob', 0.1, 0.45)]),
            # Fully connected network hyperparameters (a final FC inference layer with no dropout nor batchnorm will be added when ball position predictor model is instantiated)
            'fc_params': hp.choice('fc_params', [[{'out_features': 512}, {'out_features': 256}] + [{'out_features': 128}] * 2,
                                                 [{'out_features': 128}] + [{'out_features': 256}] * 2 + [{'out_features': 512}],
                                                 [{'out_features': 128}] + [{'out_features': 256}] * 3,
                                                 [{'out_features': 128}] * 2 + [{'out_features': 256}] * 3,
                                                 [{'out_features': 128}] * 2 + [{'out_features': 256}] * 4,
                                                 [{'out_features': 128}] * 3 + [{'out_features': 256}] * 4])}}

    if train_model == _DETECT:
        hp_space, model_module = detect_hp_space, ball_detector
    elif train_model == _FORECAST:
        hp_space, model_module = forecast_hp_space, seq_prediction
    else:
        # Unreachable through argparse `choices`; kept as a guard (exits with a usage error)
        parser.error(f'bad model_name provided: "{train_model}"')

    # Define hp search objective (runs one hyperparameter trial)
    def _objective(params: dict) -> float:
        """ Trains the selected model with `params` and returns its best validation loss (minimized by hyperopt). """
        print('\n' + '#' * 20 + f' {train_model.upper()} HYPERPARAMETERS TRIAL  ' + '#' * 20 + f'\n{params}')
        # TODO: set seeds for better reproducibility
        _, valid_loss, _ = model_module.train(**params, pbar=False)
        return valid_loss

    print(f'Running hyperparameter search for "{train_model}" model (mini_balls_seq dataset)...')
    trials = Trials()
    best_parameters = fmin(_objective, space=hp_space, algo=tpe.suggest, max_evals=args.max_evals, trials=trials)

    # `fmin` returns hp.choice indices: `space_eval` maps them back to actual hyperparameter values
    print('\n\n' + '#' * 20 + f'  BEST HYPERPARAMETERS ({train_model.upper()})  ' + '#' * 20)
    print(space_eval(hp_space, best_parameters))

    print('\n\n' + '#' * 20 + f'  TRIALS  ({train_model.upper()})  ' + '#' * 20)
    print(trials.trials)

    print('Saving trials with pickle...')
    with open(TRIALS_FILEPATH, 'wb') as f:
        pickle.dump(trials, f)
def train(batch_size: int, architecture: dict, optimizer_params: dict, scheduler_params: dict, bce_loss_scale: float, epochs: int, early_stopping: Optional[int] = None, pbar: bool = True) -> Tuple[float, float, int]:
    """ Initializes dataset, dataloaders, model, optimizer and lr_scheduler, then trains a ball detector model.

    Args:
        batch_size: Batch size of the train/valid dataloaders (the model is wrapped with ``tu.parrallelize``
            when it is greater than 64).
        architecture: Keyword arguments forwarded to ``BallDetector`` (e.g. ``layers_param``, ``act_fn``,
            ``dropout_prob``, ``batch_norm``).
        optimizer_params: Keyword arguments forwarded to ``torch.optim.Adam``.
        scheduler_params: Keyword arguments forwarded to ``StepLR``. ``step_size`` is given in epochs and is
            converted to iterations here; the caller's dict is not modified.
        bce_loss_scale: Weight of the BCE (ball colors) loss term relative to the bounding boxes MSE loss term.
        epochs: Maximum number of training epochs.
        early_stopping: Stop training after this many consecutive epochs without validation loss improvement
            (disabled when None or <= 0).
        pbar: Whether to display progress bars.

    Returns:
        ``(best_train_loss, best_valid_loss, best_run_epoch)``, taken at the epoch with the lowest validation loss.
    """
    import numpy as np  # Only used to compute flattened target sizes

    # TODO: refactor this to avoid some duplicated code with seq_prediction.init_training()
    # TODO: add path parameter for dataset dir
    # Create balls dataset
    dataset = datasets.BallsCFDetection(tu.source_dir() / r'../../datasets/mini_balls')

    # Create ball detector model and dataloaders
    trainset, validset = datasets.create_dataloaders(dataset, batch_size)
    # One sample is needed to retrieve input image resolution and target sizes (assumes all samples share the same shapes)
    dummy_img, p, bb = dataset[0]
    # NOTE(review): output sizes reconstructed as the flattened sizes of both targets, matching the
    # `tu.flatten_batch` applied to targets in the training loop below — confirm against BallDetector
    output_sizes = (int(np.prod(p.shape)), int(np.prod(bb.shape)))
    model = BallDetector(dummy_img.shape, output_sizes, **architecture).to(DEVICE)
    # Log before any parallel wrapping, as a wrapper may not expose BallDetector's attributes
    print(f'> MODEL ARCHITECTURE:\n{model!r}')
    print(f'> MODEL CONVOLUTION FEATURE SIZES: {model._conv_features_shapes}')
    if batch_size > 64:
        model = tu.parrallelize(model)

    # Define optimizer, losses and LR scheduler
    optimizer = torch.optim.Adam(model.parameters(), **optimizer_params)
    bb_metric, pos_metric = torch.nn.MSELoss(), torch.nn.BCEWithLogitsLoss()
    # The scheduler is stepped once per batch: convert `step_size` from epochs to iterations (on a copy)
    scheduler_params = {**scheduler_params, 'step_size': scheduler_params['step_size'] * len(trainset)}
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_params)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, steps_per_epoch = len(trainset), epochs = hp['epochs'], **scheduler_params)

    # Remove previous results visualization images
    if VIS_DIR is not None:
        shutil.rmtree(VIS_DIR, ignore_errors=True)

    best_valid_loss, best_train_loss = float("inf"), float("inf")
    best_run_epoch = -1
    epochs_since_best_loss = 0

    # Main training loop
    for epoch in range(1, epochs + 1):
        print("\nEpoch %03d/%03d\n" % (epoch, epochs) + '-' * 15)
        train_loss = 0.
        model.train()  # evaluate() may switch the model to eval mode: re-enable dropout/batchnorm training behavior

        trange, update_bar = tu.progess_bar(trainset, '> Training on trainset', min(
            len(trainset.dataset), trainset.batch_size), custom_vars=True, disable=not pbar)
        for batch_x, colors, bbs in trange:
            batch_x = batch_x.to(DEVICE)
            # Targets are flattened per sample and cast to float (required by BCEWithLogitsLoss and MSELoss)
            colors = tu.flatten_batch(colors.to(DEVICE)).float()
            bbs = tu.flatten_batch(bbs.to(DEVICE)).float()

            optimizer.zero_grad()
            output_colors, output_bbs = model(batch_x)
            loss = bce_loss_scale * pos_metric(output_colors, colors) + bb_metric(output_bbs, bbs)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item() / len(trainset)
            update_bar(trainLoss=f'{len(trainset) * train_loss / (trange.n + 1):.7f}', lr=f'{optimizer.param_groups[0]["lr"]:.3E}')

        print(f'>\tDone: TRAIN_LOSS = {train_loss:.7f}')
        valid_loss = evaluate(epoch, model, validset, bce_loss_scale, best_valid_loss, pbar=pbar)
        print(f'>\tDone: VALID_LOSS = {valid_loss:.7f}')
        if best_valid_loss > valid_loss:
            print('>\tBest valid_loss found so far, saving model...')  # TODO: save model
            best_valid_loss, best_train_loss = valid_loss, train_loss
            best_run_epoch = epoch
            epochs_since_best_loss = 0
        else:
            epochs_since_best_loss += 1
            if early_stopping is not None and early_stopping > 0 and epochs_since_best_loss >= early_stopping:
                print(f'>\tModel not improving: Ran {epochs_since_best_loss} training epochs without improvement. Early stopping training loop...')
                break

    print(f'>\tBest training results obtained at {best_run_epoch}nth epoch (best_valid_loss = {best_valid_loss:.7f}, best_train_loss = {best_train_loss:.7f}).')
    return best_train_loss, best_valid_loss, best_run_epoch
def train(batch_size: int,
          architecture: dict,
          optimizer_params: dict,
          scheduler_params: dict,
          epochs: int,
          early_stopping: Optional[int] = None,
          pbar: bool = True) -> Tuple[float, float, int]:
    """ Initializes and trains a ball position forecasting (sequence prediction) model.

    Args:
        batch_size: Batch size of the train/valid dataloaders (the model is wrapped with ``tu.parrallelize``
            when it is greater than 64).
        architecture: Keyword arguments forwarded to ``SeqPredictor`` (e.g. ``act_fn``, ``dropout_prob``,
            ``fc_params``).
        optimizer_params: Keyword arguments forwarded to ``torch.optim.Adam``.
        scheduler_params: Keyword arguments forwarded to ``StepLR``. ``step_size`` is given in epochs and is
            converted to iterations here; the caller's dict is not modified.
        epochs: Maximum number of training epochs.
        early_stopping: Stop training after this many consecutive epochs without validation MSE improvement
            (disabled when None or <= 0).
        pbar: Whether to display progress bars.

    Returns:
        ``(best_train_mse, best_valid_mse, best_run_epoch)``, taken at the epoch with the lowest validation MSE.
    """
    import numpy as np  # Only used to compute flattened feature sizes

    # TODO: refactor this to avoid some duplicated code with ball_detector.init_training()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Create balls dataset
    # NOTE(review): the dataset path was lost from this call; 'mini_balls_seq' is assumed — confirm
    dataset = datasets.BallsCFSeq(tu.source_dir() / r'../../datasets/mini_balls_seq')

    # Create forecasting model and dataloaders
    trainset, validset = datasets.create_dataloaders(dataset, batch_size)
    # One sample is needed to retrieve model input/output sizes (assumes all samples share the same shapes)
    input_bb_sequence, colors, target_bb = dataset[0]
    # NOTE(review): constructor arguments reconstructed: input = flattened bb sequence concatenated with colors,
    # output = flattened target bounding boxes (matches the batch building below) — confirm against SeqPredictor
    input_size = int(np.prod(input_bb_sequence.shape) + np.prod(colors.shape))
    output_size = int(np.prod(target_bb.shape))
    model = SeqPredictor(input_size, output_size, **architecture)

    # Xavier weight initialization of fully connected layers, with a gain matching the activation function
    # NOTE(review): nn.Tanh is assumed when `architecture` has no 'act_fn' — confirm SeqPredictor's default
    xavier_gain = nn.init.calculate_gain(tu.get_gain_name(architecture.get('act_fn', nn.Tanh)))

    def _initialize_weights(module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight, gain=xavier_gain)

    model.apply(_initialize_weights)
    model = model.to(device)
    if batch_size > 64:
        model = tu.parrallelize(model)

    # Define optimizer, loss and LR scheduler
    optimizer = torch.optim.Adam(model.parameters(), **optimizer_params)
    mse = torch.nn.MSELoss()
    # The scheduler is stepped once per batch: convert `step_size` from epochs to iterations (on a copy)
    scheduler_params = {**scheduler_params, 'step_size': scheduler_params['step_size'] * len(trainset)}
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_params)
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, steps_per_epoch=len(trainset), epochs=hp['epochs'], **scheduler_params)

    best_valid_mse, best_train_mse = float("inf"), float("inf")
    best_run_epoch = -1
    epochs_since_best_loss = 0

    # Main training loop
    for epoch in range(1, epochs + 1):
        print("\nEpoch %03d/%03d\n" % (epoch, epochs) + '-' * 15)
        train_mse = 0.
        model.train()  # evaluate() may switch the model to eval mode: re-enable dropout training behavior

        trange, update_bar = tu.progess_bar(trainset, '> Training on trainset', min(
            len(trainset.dataset), trainset.batch_size), custom_vars=True, disable=not pbar)
        for input_bb_sequence, colors, target_bb in trange:
            # Model input: per-sample flattened bb sequence concatenated with flattened colors
            batch_x = torch.cat((tu.flatten_batch(input_bb_sequence.to(device)).float(),
                                 tu.flatten_batch(colors.to(device)).float()), dim=-1)
            target_bb = tu.flatten_batch(target_bb.to(device)).float()

            optimizer.zero_grad()
            loss = mse(model(batch_x), target_bb)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_mse += loss.item() / len(trainset)
            update_bar(trainMSE=f'{len(trainset) * train_mse / (trange.n + 1):.7f}',
                       lr=f'{optimizer.param_groups[0]["lr"]:.3E}')

        print(f'\tDone: TRAIN_MSE = {train_mse:.7f}')
        valid_loss = evaluate(model, validset, pbar=pbar)
        print(f'\tDone: TEST_MSE = {valid_loss:.7f}')

        if best_valid_mse > valid_loss:
            print('>\tBest valid_loss found so far, saving model...')  # TODO: save model
            best_valid_mse, best_train_mse = valid_loss, train_mse
            best_run_epoch = epoch
            epochs_since_best_loss = 0
        else:
            epochs_since_best_loss += 1
            if early_stopping is not None and early_stopping > 0 and epochs_since_best_loss >= early_stopping:
                print(f'>\tModel not improving: Ran {epochs_since_best_loss} training epochs without improvement. Early stopping training loop...')
                break

    print(f'>\tBest training results obtained at {best_run_epoch}nth epoch (best_valid_mse={best_valid_mse:.7f}, best_train_mse={best_train_mse:.7f}).')
    return best_train_mse, best_valid_mse, best_run_epoch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import balldetect.datasets as datasets
import balldetect.torch_utils as tu
import balldetect.vis as vis
pickle = tu.import_pickle()

__all__ = ['BallDetector', 'init_training', 'train']
__author__ = 'Paul-Emmanuel SOTIR <*****@*****.**>'

# Torch device used for training: first CUDA device when available, CPU otherwise
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Results visualization images directory (removed at the start of train())
VIS_DIR = tu.source_dir() / '../../visualization_imgs/detector2'

class BallDetector(nn.Module):
    """ Ball detector pytorch module.
    .. class:: BallDetector

    Image model with two outputs whose sizes are given by ``output_sizes`` (``train()`` in this module
    consumes them as ball colors logits and bounding boxes).
    """

    __constants__ = ['_input_shape', '_p_output_size', '_bb_output_size', '_xavier_gain', '_conv_features_shapes', '_conv_out_features']

    def __init__(self, input_shape: torch.Size, output_sizes: tuple, layers_param: list, act_fn: type = nn.ReLU, dropout_prob: float = 0., batch_norm: Optional[dict] = None):
        """
        Args:
            input_shape: Shape of one input image sample.
            output_sizes: ``(p_output_size, bb_output_size)``, sizes of the two model outputs.
            layers_param: Sequence of ``(layer_type, kwargs)`` pairs describing the network, e.g.
                ``('conv2d', {...})``, ``('avg_pooling', {...})``, ``('flatten', {})``, ``('fully_connected', {...})``.
            act_fn: Activation function class; also used to compute the Xavier initialization gain.
            dropout_prob: Presumably the dropout probability of hidden layers.
            batch_norm: Presumably batch normalization keyword arguments (e.g. ``eps``, ``momentum``,
                ``affine``), or None.
        """
        super(BallDetector, self).__init__()
        self._input_shape = input_shape
        self._p_output_size, self._bb_output_size = output_sizes
        self._xavier_gain = nn.init.calculate_gain(tu.get_gain_name(act_fn))
        # Presumably filled during layer construction (convolution feature shapes / flattened conv output size)
        self._conv_features_shapes, self._conv_out_features = [], None
        # NOTE(review): as written, this __init__ builds no layer from `layers_param`, ignores `dropout_prob`
        # and `batch_norm`, and the class defines no forward(): the module has no parameters. Layer
        # construction appears to be missing and must be restored before training.
def show_bboxes(rgb_array, np_bbox, list_colors, out_fn='./bboxes_on_rgb.png'):
    """ Draw bounding boxes on a RGB image and save the result to `out_fn`.

    rgb_array: a np.array of shape (H,W,3) - it represents the rgb frame in uint8 type
    np_bbox: np.array of shape (N,4) and a bbox is of type [x1,y1,x2,y2]
    list_colors: list of N outline colors (one per bbox, e.g. color name strings)
    out_fn: path of the written image file
    Returns the annotated PIL image.
    """
    assert np_bbox.shape[0] == len(list_colors), 'Expected exactly one color per bounding box'

    img_rgb = Image.fromarray(rgb_array, 'RGB')
    draw = ImageDraw.Draw(img_rgb)

    for color, (x_1, y_1, x_2, y_2) in zip(list_colors, np_bbox):
        draw.rectangle(((x_1, y_1), (x_2, y_2)), outline=color, fill=None)

    # Save instead of displaying: showing the image would require a running graphical server
    img_rgb.save(out_fn)
    return img_rgb

if __name__ == "__main__":
    dataset = BallsCFDetection(tu.source_dir() / r'../datasets/mini_balls/')

    # Get a single sample from the dataset and draw its bounding boxes.
    # BallsCFDetection samples are unpacked as (image, colors, bounding_boxes) by the detector training code;
    # the bounding boxes (third element) are what show_bboxes expects, one (x1, y1, x2, y2) row per color.
    img, colors, bboxes = dataset[2]

    show_bboxes(img, bboxes, COLORS, out_fn='_x.png')