Example #1
def get_index(index_path):
    """
    Load a label index
    :param index_path:
    :return:the index
    """
    global previous_index_path
    global previous_indexed_labels
    if 'previous_index_path' in globals() and index_path == previous_index_path:
        print_debug('Labels index in cache')
        return previous_indexed_labels

    # check if labels have been indexed
    if os.path.isfile(index_path):
        print_debug('Loading labels index ' + index_path)
        with open(index_path) as f:
            indexed_labels = json.load(f)
        indexed_labels = {int(k): int(v) for k, v in indexed_labels.items()}
        previous_index_path = index_path
        previous_indexed_labels = indexed_labels

    else:
        print_errors('index ' + index_path + ' does not exist...')
        indexed_labels = None
    return indexed_labels
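A note on the int conversion above: JSON serializes object keys as strings, so json.load returns string keys even when the index was written from ints. A minimal stdlib sketch of that round trip (the file name and contents are hypothetical):

import json
import os
import tempfile

# a tiny label index (hypothetical content): class id -> label id
index = {0: 3, 1: 7}

path = os.path.join(tempfile.gettempdir(), 'labels_index.json')
with open(path, 'w') as f:
    json.dump(index, f)

with open(path) as f:
    loaded = json.load(f)

# JSON object keys are always strings, so the cast back to int is required,
# exactly as get_index does after json.load
loaded = {int(k): int(v) for k, v in loaded.items()}
assert loaded == index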
Example #2
def _to_lgb_dataset(dataset):
    if not hasattr(dataset, 'numpy'):
        print_errors(str(type(dataset)) +
                     ' must implement the numpy method...',
                     do_exit=True)
    data, label = dataset.numpy()
    return lgb.Dataset(data, label=label)
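_to_lgb_dataset only requires the dataset object to expose a numpy() method returning a (data, label) pair. A sketch of a conforming dataset, assuming lightgbm is installed (ToyDataset is hypothetical):

import numpy as np
import lightgbm as lgb

class ToyDataset:
    """Hypothetical dataset exposing the numpy() protocol that _to_lgb_dataset expects."""
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def numpy(self):
        # return (features, labels) as numpy arrays
        return self.data, self.label

dataset = ToyDataset(np.random.rand(100, 4), np.random.randint(0, 2, size=100))
data, label = dataset.numpy()
lgb_train = lgb.Dataset(data, label=label)  # the same call as in _to_lgb_dataset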
Example #3
def fit(model,
        train,
        test,
        num_boost_round=360,
        verbose_eval=1,
        export=False,
        training_params=None,
        export_params=None,
        **kwargs):
    if not use_gpu():
        print_errors('XGBoost can only be executed on a GPU for the moment',
                     do_exit=True)

    training_params = {} if training_params is None else training_params
    export_params = {} if export_params is None else export_params

    d_test = xgb.DMatrix(np.asarray(test.get_vectors()),
                         label=np.asarray(test.labels))

    # NB: validation_only is not defined in this snippet; it is expected to come
    # from the surrounding module's configuration (e.g. a special_parameters flag)
    if not validation_only:
        print_h1('Training: ' + special_parameters.setup_name)
        print_info("get vectors...")

        X = np.asarray(train.get_vectors())
        y = np.asarray(train.labels)

        d_train = xgb.DMatrix(X, label=y)

        gpu_id = first_device().index

        kwargs['verbosity'] = verbose_level()
        kwargs['gpu_id'] = gpu_id

        eval_list = [(d_test, 'eval'), (d_train, 'train')]

        print_info("fit model...")

        bst = xgb.train(kwargs,
                        d_train,
                        num_boost_round=num_boost_round,
                        verbose_eval=verbose_eval,
                        evals=eval_list,
                        xgb_model=model)

        print_info("Save model...")
        save_model(bst)

    else:
        bst = load_model()

    print_h1('Validation/Export: ' + special_parameters.setup_name)
    predictions = bst.predict(d_test, ntree_limit=bst.best_ntree_limit)
    res = validate(predictions,
                   np.array(test.labels),
                   training_params.get('metrics', tuple()),
                   final=True)
    print_notification(res, end='')
    if export:
        export_results(test, predictions, **export_params)
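Stripped of the framework helpers (print_h1, save_model, validate, export_results), the core of this fit is standard xgboost: wrap the arrays in DMatrix objects, pass an eval list, and train. A minimal CPU sketch on synthetic data (the parameter values are hypothetical; the example above additionally enforces GPU execution):

import numpy as np
import xgboost as xgb

X_train, y_train = np.random.rand(200, 10), np.random.randint(0, 2, 200)
X_test, y_test = np.random.rand(50, 10), np.random.randint(0, 2, 50)

d_train = xgb.DMatrix(X_train, label=y_train)
d_test = xgb.DMatrix(X_test, label=y_test)

params = {'objective': 'binary:logistic', 'verbosity': 1}  # hypothetical
bst = xgb.train(params,
                d_train,
                num_boost_round=20,
                evals=[(d_test, 'eval'), (d_train, 'train')],
                verbose_eval=5)
predictions = bst.predict(d_test)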
Example #4
    def __init__(self, source, nb_try_max=1000, islands_sup=0, close_target=False, auto_restart=True):
        """
        :param root_dir: the root dir of the grib files
        :param polar: the polar path to the file
        :param nb_try_max: the number of allowed tries
        """
        super().__init__()

        self.autorestart = auto_restart

        r = check_source(source)
        if 'path' not in r:
            print_errors('The source ' + source + ' does not contain path', do_exit=True)
        if 'polar' not in r:
            print_errors('The source ' + source + ' does not contain polar', do_exit=True)

        self.root_dir = r['path']

        self.game = None

        self.numpy_grib = None
        self.polar = Polar(path_polar_file=r['polar'])

        self.target = None
        self.position = None
        self.start_position = None

        self.grib_list = [file for file in os.listdir(self.root_dir) if file.endswith('.npz')]

        self.start_timestamp = None
        self.timedelta = None
        self.track = None

        self.score = 0
        self.score_ = 0

        self.nb_try = 0
        self.nb_try_max = nb_try_max

        self.dist = None
        self.old_dist = None
        self.dir = None
        self.sog = None
        self.cog = None

        self.twa = None
        self.tws = None

        self.twd = None

        self.close_target = close_target

        self.islands_sup = islands_sup

        self.bins = np.array([i * 45 for i in range(8)])
        self.start()
Example #5
def _load_checkpoint(model_name, path=None):
    if path is None:
        path = output_path(_checkpoint_path.format(model_name),
                           have_validation=True)

    global _checkpoint
    if not os.path.isfile(path):
        print_errors('{} does not exist'.format(path), do_exit=True)
    print_debug('Loading checkpoint from ' + path)
    _checkpoint[model_name] = torch.load(path)
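The function fills a module-level _checkpoint dict keyed by model name. A self-contained sketch of the same load-and-cache pattern (the path and model here are hypothetical):

import os
import torch
import torch.nn as nn

_checkpoint = {}  # module-level cache, as in the example above

def load_checkpoint(model_name, path):
    """Minimal sketch: load a checkpoint from disk once and cache it by model name."""
    if model_name not in _checkpoint:
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        _checkpoint[model_name] = torch.load(path)
    return _checkpoint[model_name]

# usage: save a tiny state dict, then load it through the cache
model = nn.Linear(4, 2)
torch.save(model.state_dict(), '/tmp/toy_model.pt')
state = load_checkpoint('toy_model', '/tmp/toy_model.pt')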
Example #6
def merge_dict_set(*args):
    """
    args contains a series of dictionaries each followed by the default value. merge_dict_set, will set
    the default values if not already set
    :param args:
    :return:
    """
    if len(args) % 2 != 0:
        print_errors('args must contain a multiple of 2 elements',
                     exception=MergeDictException('multiple of 2'))

    results = []

    for i in range(0, len(args), 2):
        # the currently set parameters
        dictionary = args[i] if args[i] is not None else {}
        if type(dictionary) is not dict:
            print_errors('arguments should be either None or a dict',
                         exception=MergeDictException('dict or None'))

        # default values of the parameters
        default = args[i + 1]
        if type(default) is not dict:
            print_errors('default values should be of type dict',
                         exception=MergeDictException('dict'))

        merge_smooth(dictionary, default)
        results.append(dictionary)
    return results
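merge_smooth itself is not shown in this snippet; assuming it fills missing keys of the dictionary in place from the defaults, the calling convention is alternating (parameters, defaults) pairs. An equivalent stdlib sketch (the default dicts are hypothetical):

def merge_smooth(dictionary, default):
    # assumed semantics: set defaults without overwriting user-provided values
    for key, value in default.items():
        dictionary.setdefault(key, value)

TRAINING_PARAMS = {'iterations': (100,), 'gamma': 0.1}  # hypothetical defaults
OPTIM_PARAMS = {'lr': 0.01}

training_params = {'gamma': 0.5}  # the user overrides gamma only
optim_params = None               # None is treated as an empty dict

# the same alternating (dict, defaults) convention as merge_dict_set(...)
merged = []
for params, defaults in ((training_params, TRAINING_PARAMS),
                         (optim_params, OPTIM_PARAMS)):
    params = params if params is not None else {}
    merge_smooth(params, defaults)
    merged.append(params)

assert merged[0] == {'gamma': 0.5, 'iterations': (100,)}
assert merged[1] == {'lr': 0.01}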
Example #7
def check_source(source_name):
    root_path = os.path.join(special_parameters.root_path,
                             special_parameters.source_path,
                             source_name + '.json')
    if not os.path.isfile(root_path):
        print_errors('the source ' + source_name + ' does not exist...',
                     do_exit=True)

    with open(root_path) as f:
        d = json.load(f)

    if special_parameters.machine not in d:
        print_errors('The source ' + source_name + ' is not available for ' +
                     str(special_parameters.machine),
                     do_exit=True)

    results = {}

    for k, v in d[special_parameters.machine].items():
        if not k.startswith('_'):
            results[k] = v
    results['source_name'] = source_name

    return results
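Reconstructed from the reads above, a source file is one JSON object per machine name, and every key not starting with '_' is returned verbatim. A sketch that writes and parses a hypothetical source (all names are illustrative):

import json
import os
import tempfile

# hypothetical source file: one entry per machine name
source = {
    'my_machine': {
        'path': '/data/gribs',
        'polar': '/data/polar.pol',
        '_comment': "keys starting with '_' are skipped by check_source"
    }
}

root_path = os.path.join(tempfile.gettempdir(), 'my_source.json')
with open(root_path, 'w') as f:
    json.dump(source, f)

with open(root_path) as f:
    d = json.load(f)

machine = 'my_machine'
results = {k: v for k, v in d[machine].items() if not k.startswith('_')}
results['source_name'] = 'my_source'
assert results == {'path': '/data/gribs', 'polar': '/data/polar.pol',
                   'source_name': 'my_source'}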
Example #8
def compute_neural_directions(model,
                              X,
                              absolute_value,
                              threshold,
                              min_activations=10):
    # this method only works on fully connected models
    if type(model) is not fully_connected.Net:
        print_errors(str(type(model)) + ' must be of type ' +
                     str(fully_connected.Net) + '.',
                     do_exit=True)

    layers = [m for m in model.modules()
              if type(m) in (BatchNorm1d, Linear)][:-1]
    final_layers = []
    it = 0

    while it < len(layers):
        # linear layer
        M = layers[it]
        it += 1

        linear_app = M.weight.detach().cpu().numpy()

        if it < len(layers) and type(layers[it]) is BatchNorm1d:
            A = layers[it]
            var = np.diag(A.running_var.cpu().numpy())

            gamma = np.diag(A.weight.detach().cpu().numpy())
            bn = np.matmul(gamma, np.linalg.inv(var))

            linear_app = np.matmul(bn, linear_app)
            it += 1
        final_layers.append(linear_app)

    # get activations
    activations, _ = get_activations(model, X)
    activations, _ = np.unique(activations, return_inverse=True, axis=0)

    # neurons whose activation is constant over the whole domain carry no
    # partition information, so their columns are zeroed out

    for i, v in enumerate(np.all(activations == activations[0, :], axis=0)):
        if v:
            activations[:, i] = 0

    # unique after corrections
    activations, _ = np.unique(activations, return_inverse=True, axis=0)

    vmin = 0. if absolute_value else -1.
    vmax = 1.

    vectors = [[] for _ in range(len(final_layers))]
    n_act = min(min_activations, len(activations))
    print_info("n_act: %d" % n_act)
    for i in range(n_act):

        la = None
        for li, l in enumerate(final_layers):
            activated = activations[i][li * l.shape[0]:(li + 1) * l.shape[0]]

            if la is None:
                la = final_layers[li] * activated[:, np.newaxis]
            else:
                la = np.matmul(final_layers[li], la) * activated[:, np.newaxis]

            # each row of la is one neuron's direction in input space
            for n in la:
                vectors[li].append(n)

    return vectors, vmin, vmax
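The loop above computes, for one fixed activation pattern, the local linear map of the network by masking each (folded) layer matrix with the pattern and composing the results. A numpy sketch of that composition for a hypothetical two-layer case:

import numpy as np

# hypothetical folded layer weights (BatchNorm already merged in, as above)
W1 = np.random.randn(5, 3)  # layer 1: 3 inputs -> 5 units
W2 = np.random.randn(2, 5)  # layer 2: 5 -> 2

# one binary activation pattern per layer (which units are active)
a1 = np.array([1, 0, 1, 1, 0])
a2 = np.array([1, 1])

# the same composition as in the loop: mask the rows, then multiply through
la1 = W1 * a1[:, np.newaxis]                  # layer-1 directions on this region
la2 = np.matmul(W2, la1) * a2[:, np.newaxis]  # layer-2 directions

# each row of la1 / la2 is one neuron's direction in input space
print(la1.shape, la2.shape)  # (5, 3) (2, 3)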
Example #9
def plot_layer_output_corr(model,
                           X,
                           absolute_value,
                           threshold,
                           figure_name='ploc'):
    # this method only works on fully connected models
    if type(model) is not fully_connected.Net:
        print_errors(str(type(model)) + ' must be of type ' +
                     str(fully_connected.Net) + '.',
                     do_exit=True)

    layers = [m for m in model.modules()
              if type(m) in (BatchNorm1d, Linear)][:-1]
    final_layers = []
    it = 0

    while it < len(layers):
        # linear layer
        M = layers[it]
        it += 1

        linear_app = M.weight.detach().numpy()

        if it < len(layers) and type(layers[it]) is BatchNorm1d:
            A = layers[it]
            var = np.diag(A.running_var.numpy())

            gamma = np.diag(A.weight.detach().numpy())
            bn = np.matmul(gamma, np.linalg.inv(var))

            linear_app = np.matmul(bn, linear_app)
            it += 1
        final_layers.append(linear_app)

    # get activations
    activations, _ = get_activations(model, X)
    activations, _ = np.unique(activations, return_inverse=True, axis=0)

    # neurons whose activation is constant over the whole domain are zeroed out

    for i, v in enumerate(np.all(activations == activations[0, :], axis=0)):
        if v:
            activations[:, i] = 0
    # unique after corrections
    activations, _ = np.unique(activations, return_inverse=True, axis=0)

    min_activations = 10

    nb_c_activations = min(activations.shape[0], min_activations)

    plt(figure_name, figsize=(nb_c_activations * 6.4, len(final_layers) * 4.8))
    vmin = 0. if absolute_value else -1.
    vmax = 1.
    # iterate over the available patterns only; range(min_activations) could
    # index out of bounds when there are fewer unique activation patterns
    for i in range(nb_c_activations):
        la = None
        for li, l in enumerate(final_layers):
            cos = np.zeros((l.shape[0], l.shape[0]))
            activated = activations[i][li * l.shape[0]:(li + 1) * l.shape[0]]

            if la is None:
                la = final_layers[li] * activated[:, np.newaxis]
            else:
                la = np.matmul(final_layers[li], la) * activated[:, np.newaxis]
            for r in range(l.shape[0]):
                for c in range(l.shape[0]):
                    r_active = activations[i, li * l.shape[0] + r] != 0
                    c_active = activations[i, li * l.shape[0] + c] != 0
                    if r_active and c_active:
                        cos[r, c] = (np.dot(la[r, :], la[c, :])
                                     / np.linalg.norm(la[r, :])
                                     / np.linalg.norm(la[c, :]))
                    else:
                        cos[r, c] = None
            if absolute_value:
                cos = np.abs(cos)
            if type(threshold) is not bool:
                cos = (cos > threshold).astype(int)
            plt(figure_name).subplot(len(final_layers), nb_c_activations,
                                     1 + li * nb_c_activations + i)
            plt(figure_name).imshow(cos, vmin=vmin, vmax=vmax)
            plt(figure_name).title('Layer: ' + str(li + 1))
            plt(figure_name).colorbar()
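The nested r/c loops compute pairwise cosine similarity between the composed neuron directions; the same matrix can be built in one shot. A vectorized numpy sketch (the directions are hypothetical):

import numpy as np

la = np.random.randn(6, 3)  # hypothetical neuron directions, one per row

# vectorized pairwise cosine similarity, equivalent to the nested r/c loops
norms = np.linalg.norm(la, axis=1, keepdims=True)
cos = (la @ la.T) / (norms * norms.T)

# a zero row (inactive neuron) would yield NaN here, matching the
# cos[r, c] = None branch in the loops above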
Example #10
def fit(model_z,
        train,
        test,
        val=None,
        training_params=None,
        predict_params=None,
        validation_params=None,
        export_params=None,
        optim_params=None,
        model_selection_params=None):
    """
    This function is the core of an experiment. It performs the ml procedure as well as the call to validation.
    :param training_params: parameters for the training procedure
    :param val: validation set
    :param test: the test set
    :param train: The training set
    :param optim_params:
    :param export_params:
    :param validation_params:
    :param predict_params:
    :param model_z: the model that should be trained
    :param model_selection_params:
    """
    # configuration

    training_params, predict_params, validation_params, export_params, optim_params, \
        cv_params = merge_dict_set(
            training_params, TRAINING_PARAMS,
            predict_params, PREDICT_PARAMS,
            validation_params, VALIDATION_PARAMS,
            export_params, EXPORT_PARAMS,
            optim_params, OPTIM_PARAMS,
            model_selection_params, MODEL_SELECTION_PARAMS
        )

    train_loader, test_loader, val_loader = _dataset_setup(
        train, test, val, **training_params)

    statistics_path = output_path('metric_statistics.dump')

    metrics_stats = Statistics(
        model_z, statistics_path,
        **cv_params) if cv_params.pop('cross_validation') else None

    validation_path = output_path('validation.txt')

    # training parameters
    optim = optim_params.pop('optimizer')
    iterations = training_params.pop('iterations')
    gamma = training_params.pop('gamma')
    loss = training_params.pop('loss')
    log_modulo = training_params.pop('log_modulo')
    val_modulo = training_params.pop('val_modulo')
    first_epoch = training_params.pop('first_epoch')

    # callbacks for ml tests
    vcallback = validation_params.pop('vcallback', None)

    if iterations is None:
        print_errors(
            'Iterations must be set',
            exception=TrainingConfigurationException('Iterations is None'))

    # before ml callback
    if vcallback is not None and special_parameters.train and first_epoch < max(
            iterations):
        init_callbacks(vcallback, val_modulo,
                       max(iterations) // val_modulo, train_loader.dataset,
                       model_z)

    max_iterations = max(iterations)

    if special_parameters.train and first_epoch < max(iterations):
        print_h1('Training: ' + special_parameters.setup_name)

        loss_logs = [] if first_epoch < 1 else load_loss('loss_train')

        loss_val_logs = [] if first_epoch < 1 else load_loss('loss_validation')

        opt = create_optimizer(model_z.parameters(), optim, optim_params)

        scheduler = MultiStepLR(opt, milestones=list(iterations), gamma=gamma)

        # number of batches per epoch
        epoch_size = len(train_loader)

        # one log per epoch if value is -1
        log_modulo = epoch_size if log_modulo == -1 else log_modulo

        epoch = 0
        for epoch in range(max_iterations):

            if epoch < first_epoch:
                # opt.step()
                _skip_step(scheduler, epoch)
                continue
            # saving epoch to enable restart
            export_epoch(epoch)
            model_z.train()

            # printing new epoch
            print_h2('-' * 5 + ' Epoch ' + str(epoch + 1) + '/' +
                     str(max_iterations) + ' (lr: ' + str(scheduler.get_lr()) +
                     ') ' + '-' * 5)

            running_loss = 0.0

            for idx, data in enumerate(train_loader):

                # get the inputs
                inputs, labels = data

                # wrap labels in Variable as input is managed through a decorator
                # labels = model_z.p_label(labels)
                if use_gpu():
                    labels = labels.cuda()

                # zero the parameter gradients
                opt.zero_grad()
                outputs = model_z(inputs)
                loss_value = loss(outputs, labels)
                loss_value.backward()

                opt.step()

                # print statistics
                running_loss += loss_value.item()
                if idx % log_modulo == log_modulo - 1:  # print every log_modulo mini-batches
                    print('[%d, %5d] loss: %.5f' %
                          (epoch + 1, idx + 1, running_loss / log_modulo))

                    # tensorboard support
                    add_scalar('Loss/train', running_loss / log_modulo)
                    loss_logs.append(running_loss / log_modulo)
                    running_loss = 0.0

            # end of epoch update of learning rate scheduler
            scheduler.step(epoch + 1)

            # saving the model and the current loss after each epoch
            save_checkpoint(model_z, optimizer=opt)

            # validation of the model
            if epoch % val_modulo == val_modulo - 1:
                validation_id = str(int((epoch + 1) / val_modulo))

                # validation call
                predictions, labels, loss_val = predict(
                    model_z, val_loader, loss, **predict_params)
                loss_val_logs.append(loss_val)

                res = '\n[validation_id:' + validation_id + ']\n' + validate(
                    predictions,
                    labels,
                    validation_id=validation_id,
                    statistics=metrics_stats,
                    **validation_params)

                # save statistics for robust cross validation
                if metrics_stats:
                    metrics_stats.save()

                print_notification(res)

                if special_parameters.mail == 2:
                    send_email(
                        'Results for XP ' + special_parameters.setup_name +
                        ' (epoch: ' + str(epoch + 1) + ')', res)
                if special_parameters.file:
                    save_file(
                        validation_path,
                        'Results for XP ' + special_parameters.setup_name +
                        ' (epoch: ' + str(epoch + 1) + ')', res)

                # checkpoint
                save_checkpoint(model_z,
                                optimizer=opt,
                                validation_id=validation_id)

                # callback
                if vcallback is not None:
                    run_callbacks(vcallback, (epoch + 1) // val_modulo)

            # save loss
            save_loss(
                {  # epoch_size // log_modulo * log_modulo handles log_modulo not dividing epoch_size evenly
                    'train': (loss_logs, log_modulo),
                    'validation':
                    (loss_val_logs,
                     epoch_size // log_modulo * log_modulo * val_modulo)
                },
                ylabel=str(loss))

        # saving the last epoch; if --restart is set, the train will not be executed
        export_epoch(epoch + 1)

        # callback
        if vcallback is not None:
            finish_callbacks(vcallback)

    # final validation
    if special_parameters.evaluate or special_parameters.export:
        print_h1('Validation/Export: ' + special_parameters.setup_name)
        if metrics_stats is not None:
            # change the parameter states of the model to best model
            metrics_stats.switch_to_best_model()

        predictions, labels, val_loss = predict(model_z,
                                                test_loader,
                                                loss,
                                                validation_size=-1,
                                                **predict_params)

        if special_parameters.evaluate:

            res = validate(predictions,
                           labels,
                           statistics=metrics_stats,
                           **validation_params,
                           final=True)

            print_notification(res, end='')

            if special_parameters.mail >= 1:
                send_email(
                    'Final results for XP ' + special_parameters.setup_name,
                    res)
            if special_parameters.file:
                save_file(
                    validation_path,
                    'Final results for XP ' + special_parameters.setup_name,
                    res)

        if special_parameters.export:
            export_results(test_loader.dataset, predictions, **export_params)

    return metrics_stats
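In this fit, iterations does double duty: max(iterations) is the number of epochs, and the full list is the set of MultiStepLR milestones at which the learning rate is multiplied by gamma. A minimal torch sketch of that convention (all values hypothetical):

import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

iterations = (90, 130, 150)  # hypothetical, in the fit() convention
scheduler = MultiStepLR(opt, milestones=list(iterations), gamma=0.1)

for epoch in range(max(iterations)):
    # ... one training epoch would run here ...
    opt.step()
    scheduler.step()

# lr is 0.1 for epochs 0-89, 0.01 for 90-129, and 0.001 for 130-149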
Example #11
def initialize_model(model_name,
                     num_classes,
                     feature_extract,
                     use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        """ Resnet18
        """
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        model_ft.input_size = 224

    elif model_name == "alexnet":
        """ Alexnet
        """
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        model_ft.input_size = 224

    elif model_name == "vgg":
        """ VGG11_bn
        """
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        model_ft.input_size = 224

    elif model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512,
                                           num_classes,
                                           kernel_size=(1, 1),
                                           stride=(1, 1))
        model_ft.num_classes = num_classes
        model_ft.input_size = 224

    elif model_name == "densenet":
        """ Densenet
        """
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        model_ft.input_size = 224

    elif model_name == "inception":
        """ Inception v3
        Be careful, expects (299,299) sized images and has auxiliary output
        """
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        model_ft.input_size = 299

    else:
        print_errors("Invalid model name, exiting...", do_exit=True)

    return model_ft
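Every branch follows the same reshaping pattern: optionally freeze the backbone, then swap the final classifier for a fresh nn.Linear sized for num_classes. A self-contained torchvision sketch of the resnet branch (pretrained=False here only to avoid a download):

import torch.nn as nn
from torchvision import models

num_classes, feature_extract = 10, True

model_ft = models.resnet18(pretrained=False)
if feature_extract:
    # freeze every backbone parameter; only the new head will be trained
    for param in model_ft.parameters():
        param.requires_grad = False

# replace the final fully connected layer, as in the "resnet" branch above
model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)
model_ft.input_size = 224  # resnet18 expects 224x224 inputs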
Example #12
def check_extraction(source,
                     save_errors=True,
                     save_filtered=True,
                     id_name='X_key'):
    """
    check if all patches from an occurrences file have been extracted. Can save the list of errors and
    filtered the dataset keeping the correctly extracted data.

    :param id_name: the column that contains the patch id that will be used to construct its path
    :param save_filtered: save the dataframe filtered from the error
    :param save_errors: save the errors found in a file
    :param source: the source referring the occurrence file and the patches path
    """

    # retrieve details of the source
    r = check_source(source)
    if 'occurrences' not in r or 'patches' not in r:
        print_errors(
            'Only sources with occurrences and patches can be checked',
            do_exit=True)

    df = pd.read_csv(r['occurrences'],
                     header='infer',
                     sep=';',
                     low_memory=False)
    nb_errors = 0
    errors = []
    for idx, row in progressbar.progressbar(enumerate(df.iterrows())):
        patch_id = str(int(row[1][id_name]))

        # constructing the path of a patch given its id
        path = os.path.join(r['patches'], patch_id[-2:], patch_id[-4:-2],
                            patch_id + '.npy')

        # if the path does not correspond to a file, then it's an error
        if not os.path.isfile(path):
            errors.append(row[1][id_name])
            nb_errors += 1

    if nb_errors > 0:
        # summary of the error
        print_info(str(nb_errors) + ' errors found during the check...')

        if save_errors:
            # filter the dataframe using the errors
            df_errors = df[df[id_name].isin(errors)]

            error_path = output_path('_errors.csv')
            print_info('Saving error file at: ' + error_path)

            # save dataframe to the error file
            df_errors.to_csv(error_path, header=True, index=False, sep=';')
        if save_filtered:
            # filter the dataframe keeping the non errors
            df_filtered = df[~df[id_name].isin(errors)]
            filtered_path = r['occurrences'] + '.tmp'
            print_info('Saving filtered dataset at: ' + filtered_path)
            df_filtered.to_csv(filtered_path,
                               header=True,
                               index=False,
                               sep=';')
    else:
        print_info('No error has been found!')
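The patch path uses the last digits of the id as two directory levels, sharding many files into small folders. A worked example of that construction (the id and root are hypothetical):

import os

patch_id = str(int('12345678'))  # hypothetical patch id

# last two digits, then the two digits before them, then the file itself
path = os.path.join('/data/patches', patch_id[-2:], patch_id[-4:-2],
                    patch_id + '.npy')
print(path)  # /data/patches/78/56/12345678.npy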
Example #13
def fit(model_z, game_class, game_params=None, training_params=None, predict_params=None, validation_params=None,
        export_params=None, optim_params=None):
    """
    This function is the core of an experiment. It performs the ml procedure as well as the call to validation.
    :param game_params:
    :param game_class:
    :param training_params: parameters for the training procedure
    :param optim_params:
    :param export_params:
    :param validation_params:
    :param predict_params:
    :param model_z: the model that should be trained
    """
    # configuration
    game_params, training_params, predict_params, validation_params, export_params, optim_params = merge_dict_set(
        game_params, GAME_PARAMS,
        training_params, TRAINING_PARAMS,
        predict_params, PREDICT_PARAMS,
        validation_params, VALIDATION_PARAMS,
        export_params, EXPORT_PARAMS,
        optim_params, OPTIM_PARAMS
    )

    validation_path = output_path('validation.txt')

    output_size = model_z.output_size if hasattr(model_z, 'output_size') else model_z.module.output_size

    # training parameters
    optim = optim_params.pop('optimizer')
    iterations = training_params.pop('iterations')
    gamma = training_params.pop('gamma')
    batch_size = training_params.pop('batch_size')
    loss = training_params.pop('loss')
    log_modulo = training_params.pop('log_modulo')
    val_modulo = training_params.pop('val_modulo')
    first_epoch = training_params.pop('first_epoch')
    rm_size = training_params.pop('rm_size')
    epsilon_start = training_params.pop('epsilon_start')
    epsilon_end = training_params.pop('epsilon_end')

    evaluate = special_parameters.evaluate
    # export = special_parameters.export
    do_train = special_parameters.train
    max_iterations = max(iterations)

    game = game_class(**game_params)

    replay_memory = ReplayMemory(rm_size)

    if do_train and first_epoch < max(iterations):
        print_h1('Training: ' + special_parameters.setup_name)

        state = unsqueeze(init_game(game, replay_memory, output_size, len(replay_memory)))
        memory_loader = torch.utils.data.DataLoader(
            replay_memory, shuffle=True, batch_size=batch_size,
            num_workers=16, drop_last=True
        )

        if batch_size > len(replay_memory):
            print_errors('Batch size is bigger than available memory...', do_exit=True)

        loss_logs = [] if first_epoch < 1 else load_loss('loss_train')

        loss_val_logs = [] if first_epoch < 1 else load_loss('loss_validation')

        rewards_logs = [] if first_epoch < 1 else load_loss('train_rewards')
        rewards_val_logs = [] if first_epoch < 1 else load_loss('val_rewards')

        epsilon_decrements = np.linspace(epsilon_start, epsilon_end, iterations[-1])

        opt = create_optimizer(model_z.parameters(), optim, optim_params)

        scheduler = MultiStepLR(opt, milestones=list(iterations), gamma=gamma)

        # number of transitions in the replay memory, used as the per-epoch logging reference
        epoch_size = len(replay_memory)

        # one log per epoch if value is -1
        log_modulo = epoch_size if log_modulo == -1 else log_modulo

        epoch = 0

        running_loss = 0.0
        running_reward = 0.0
        norm_opt = 0
        norm_exp = 0

        for epoch in range(max_iterations):

            if epoch < first_epoch:
                # opt.step()
                _skip_step(scheduler, epoch)
                continue
            # saving epoch to enable restart
            export_epoch(epoch)

            epsilon = epsilon_decrements[epoch]

            model_z.train()

            # printing new epoch
            print_h2('-' * 5 + ' Epoch ' + str(epoch + 1) + '/' + str(max_iterations) +
                     ' (lr: ' + str(scheduler.get_lr()) + ') ' + '-' * 5)

            for idx, data in enumerate(memory_loader):

                # the two Q-learning steps
                state, _, finish = _exploration(model_z, state, epsilon, game, replay_memory, output_size)

                if finish:
                    # if the game is finished, we save the score
                    running_reward += game.score_
                    norm_exp += 1
                # zero the parameter gradients

                running_loss += _optimization(model_z, data, gamma, opt, loss)
                norm_opt += 1

            if epoch % log_modulo == log_modulo - 1:
                print('[%d, %5d]\tloss: %.5f' % (epoch + 1, idx + 1, running_loss / log_modulo))
                print('\t\t reward: %.5f' % (running_reward / log_modulo))
                loss_logs.append(running_loss / log_modulo)
                rewards_logs.append(running_reward / log_modulo)
                running_loss = 0.0
                running_reward = 0.0
                norm_opt = 0
                norm_exp = 0

            # end of epoch update of learning rate scheduler
            scheduler.step(epoch + 1)

            # saving the model and the current loss after each epoch
            save_checkpoint(model_z, optimizer=opt)

            # validation of the model
            if epoch % val_modulo == val_modulo - 1:
                validation_id = str(int((epoch + 1) / val_modulo))

                # validation call
                loss_val = play(model_z, output_size, game_class, game_params, 1)

                loss_val_logs.append(loss_val)

                res = '\n[validation_id:' + validation_id + ']\n' + str(loss_val)

                print_notification(res)

                if special_parameters.mail == 2:
                    send_email('Results for XP ' + special_parameters.setup_name + ' (epoch: ' + str(epoch + 1) + ')',
                               res)
                if special_parameters.file:
                    save_file(validation_path, 'Results for XP ' + special_parameters.setup_name +
                              ' (epoch: ' + str(epoch + 1) + ')', res)

                # checkpoint
                save_checkpoint(model_z, optimizer=opt, validation_id=validation_id)

            # save loss
            save_loss(
                {  # epoch_size // log_modulo * log_modulo handles log_modulo not dividing epoch_size evenly
                    'train': (loss_logs, log_modulo),
                    # 'validation': (loss_val_logs, val_modulo)
                },
                ylabel=str(loss)
            )

        # saving last epoch
        export_epoch(epoch + 1)  # if --restart is set, the train will not be executed

    # final validation
    print_h1('Validation/Export: ' + special_parameters.setup_name)
    if evaluate:
        loss_val = play(model_z, output_size, game_class, game_params, 500)

        res = str(loss_val)

        print_notification(res, end='')

        if special_parameters.mail >= 1:
            send_email('Final results for XP ' + special_parameters.setup_name, res)
        if special_parameters.file:
            save_file(validation_path, 'Final results for XP ' + special_parameters.setup_name, res)
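ReplayMemory is not shown in this snippet; for torch.utils.data.DataLoader to iterate over it as above, it only needs __len__ and __getitem__. A minimal sketch of such a buffer (the eviction policy is an assumption):

import random
import torch
from torch.utils.data import DataLoader

class ReplayMemory:
    """Minimal sketch: a fixed-capacity buffer usable as a map-style dataset."""
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        # overwrite a random slot once full (assumed eviction policy)
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[random.randrange(self.capacity)] = transition

    def __len__(self):
        return len(self.memory)

    def __getitem__(self, idx):
        return self.memory[idx]

memory = ReplayMemory(64)
for _ in range(64):
    memory.push((torch.rand(4), 0, 0.0, torch.rand(4)))

loader = DataLoader(memory, shuffle=True, batch_size=16, drop_last=True)
for state, action, reward, next_state in loader:
    pass  # state: (16, 4) tensor; action / reward collated into tensors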