else:
    if is_uncertainty_net:
        criterion = Loss_with_uncertainty(core = loss_core)
    else:
        criterion = loss_fun_core
early_stopping = Early_Stopping(patience = patience)

# Setting up recordings:
all_keys = list(tasks_train.keys()) + list(tasks_test.keys())
data_record = {"loss": {key: [] for key in all_keys}, "loss_sampled": {key: [] for key in all_keys}, "mse": {key: [] for key in all_keys},
               "reg": {key: [] for key in all_keys}, "KLD": {key: [] for key in all_keys}}
info_dict = {"array_id": array_id}
info_dict["data_record"] = data_record
info_dict["model_dict"] = []
record_data(data_record, [exp_id, tasks_train, tasks_test, task_id_list, task_settings, reg_dict, is_uncertainty_net, lr, pre_pooling_neurons, num_backwards, batch_size_task, 
                          struct_param_gen_base, struct_param_pre, struct_param_post, statistics_pooling, activation_gen, activation_model], 
            ["exp_id", "tasks_train", "tasks_test", "task_id_list", "task_settings", "reg_dict", "is_uncertainty_net", "lr", "pre_pooling_neurons", "num_backwards", "batch_size_task",
             "struct_param_gen_base", "struct_param_pre", "struct_param_post", "statistics_pooling", "activation_gen", "activation_model"])

# Training:
for i in range(num_iter + 1):
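    # Each meta-iteration samples batch_size_task training tasks (without replacement) to update on.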
    chosen_task_keys = np.random.choice(list(tasks_train.keys()), batch_size_task, replace = False).tolist()
    if optim_mode == "indi":
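        # "indi": update on each sampled task individually, with num_backwards backward passes per task.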
        if is_VAE:
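            # Running total for the KL-divergence term, accumulated over the sampled tasks.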
            KLD_total = Variable(torch.FloatTensor([0]), requires_grad = False)
            if is_cuda:
                KLD_total = KLD_total.cuda()
        for task_key, task in tasks_train.items():
            if task_key not in chosen_task_keys:
                continue
            ((X_train, y_train), (X_test, y_test)), _ = task
            for k in range(num_backwards):
Example No. 2
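The example below assumes the usual PyTorch/NumPy imports plus the project's own helpers (record_data, Early_Stopping, to_Variable, get_optimizer). The sketch below lists them; the helpers' module path is a placeholder, since it is not shown in the listing.

# Imports assumed by the train() example below (a sketch; the helper utilities
# record_data, Early_Stopping, to_Variable and get_optimizer are project-specific,
# so the module path in the last line is a placeholder, not a confirmed package name).
import numpy as np
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR
# from <project_util_module> import record_data, Early_Stopping, to_Variable, get_optimizer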
def train(model,
          X,
          y,
          validation_data=None,
          criterion=nn.MSELoss(),
          inspect_interval=10,
          isplot=False,
          **kwargs):
    """minimal version of training. "model" can be a single model or a ordered list of models"""
    def get_regularization(model, **kwargs):
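        """Sum the model's L1 regularization over every source listed in
        kwargs["reg_dict"] (a dict mapping regularization source -> coefficient)."""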
        reg_dict = kwargs.get("reg_dict", {})
        reg = to_Variable(np.array([0]), is_cuda=X.is_cuda)
        for reg_type, reg_coeff in reg_dict.items():
            reg = reg + model.get_regularization(
                source=[reg_type], mode="L1", **kwargs) * reg_coeff
        return reg

    epochs = kwargs.get("epochs", 10000)
    lr = kwargs.get("lr", 5e-3)
    optim_type = kwargs.get("optim_type", "adam")
    optim_kwargs = kwargs.get("optim_kwargs", {})
    early_stopping_epsilon = kwargs.get("early_stopping_epsilon", 0)
    patience = kwargs.get("patience", 20)
    record_keys = kwargs.get("record_keys", ["loss"])
    scheduler_type = kwargs.get("scheduler_type", "ReduceLROnPlateau")
    data_record = {key: [] for key in record_keys}
    if patience is not None:
        early_stopping = Early_Stopping(patience=patience,
                                        epsilon=early_stopping_epsilon)

    if validation_data is not None:
        X_valid, y_valid = validation_data
    else:
        X_valid, y_valid = X, y

    # Get original loss:
    loss_original = model.get_loss(X_valid, y_valid, criterion).data[0]
    if "loss" in record_keys:
        record_data(data_record, [-1, loss_original], ["iter", "loss"])
    if "param" in record_keys:
        record_data(data_record,
                    [model.get_weights_bias(W_source="core", b_source="core")],
                    ["param"])
    if "param_grad" in record_keys:
        record_data(data_record, [
            model.get_weights_bias(
                W_source="core", b_source="core", is_grad=True)
        ], ["param_grad"])

    # Setting up optimizer:
    parameters = list(model.parameters())
    num_params = len(parameters)
    if num_params == 0:
        print("No parameters to optimize!")
        loss_value = model.get_loss(X_valid, y_valid, criterion).data[0]
        if "loss" in record_keys:
            record_data(data_record, [0, loss_value], ["iter", "loss"])
        if "param" in record_keys:
            record_data(
                data_record,
                [model.get_weights_bias(W_source="core", b_source="core")],
                ["param"])
        if "param_grad" in record_keys:
            record_data(data_record, [
                model.get_weights_bias(
                    W_source="core", b_source="core", is_grad=True)
            ], ["param_grad"])
        return loss_original, loss_value, data_record
    optimizer = get_optimizer(optim_type, lr, parameters, **optim_kwargs)

    # Set up learning rate scheduler:
    if scheduler_type is not None:
        if scheduler_type == "ReduceLROnPlateau":
            scheduler_patience = kwargs.get("scheduler_patience", 10)
            scheduler_factor = kwargs.get("scheduler_factor", 0.1)
            scheduler_verbose = kwargs.get("scheduler_verbose", False)
            scheduler = ReduceLROnPlateau(optimizer,
                                          factor=scheduler_factor,
                                          patience=scheduler_patience,
                                          verbose=scheduler_verbose)
        elif scheduler_type == "LambdaLR":
            scheduler_lr_lambda = kwargs.get(
                "scheduler_lr_lambda", lambda epoch: 1 / (1 + 0.01 * epoch))
            scheduler = LambdaLR(optimizer, lr_lambda=scheduler_lr_lambda)
        else:
            raise ValueError("Unrecognized scheduler_type: {}".format(scheduler_type))

    # Training:
    to_stop = False
    for i in range(epochs + 1):
        if optim_type != "LBFGS":
            optimizer.zero_grad()
            reg = get_regularization(model, **kwargs)
            loss = model.get_loss(X, y, criterion, **kwargs) + reg
            loss.backward()
            optimizer.step()
        else:
            # "LBFGS" is a second-order optimization algorithm that requires a slightly different procedure:
            def closure():
                optimizer.zero_grad()
                reg = get_regularization(model, **kwargs)
                loss = model.get_loss(X, y, criterion, **kwargs) + reg
                loss.backward()
                return loss

            optimizer.step(closure)
        if i % inspect_interval == 0:
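            # Every inspect_interval epochs: evaluate the validation loss, step the LR scheduler,
            # record the requested quantities, and update early stopping.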
            loss_value = model.get_loss(X_valid, y_valid, criterion).data[0]
            if scheduler_type is not None:
                if scheduler_type == "ReduceLROnPlateau":
                    scheduler.step(loss_value)
                else:
                    scheduler.step()
            if "loss" in record_keys:
                record_data(
                    data_record,
                    [i, model.get_loss(X_valid, y_valid, criterion).data[0]],
                    ["iter", "loss"])
            if "param" in record_keys:
                record_data(
                    data_record,
                    [model.get_weights_bias(W_source="core", b_source="core")],
                    ["param"])
            if "param_grad" in record_keys:
                record_data(data_record, [
                    model.get_weights_bias(
                        W_source="core", b_source="core", is_grad=True)
                ], ["param_grad"])
            if patience is not None:
                to_stop = early_stopping.monitor(loss_value)
        if to_stop:
            break

    loss_value = model.get_loss(X_valid, y_valid, criterion).data[0]
    if isplot:
        import matplotlib.pyplot as plt
        plt.semilogy(data_record["iter"], data_record["loss"])
        plt.show()
    return loss_original, loss_value, data_record
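
As a usage sketch (not part of the original listing): assuming a model object that implements get_loss(X, y, criterion), get_weights_bias(...), and parameters() as train() expects, and assuming to_Variable accepts a NumPy array as shown earlier, a call might look like the following. The toy data, the "weight" regularization source, and all hyperparameter values are illustrative assumptions.

# Hypothetical usage of train(); the model interface and the "weight"
# regularization source name are assumptions, not taken from the listing.
X = to_Variable(np.random.randn(200, 1), is_cuda=False)   # toy inputs
y = X ** 2                                                 # toy targets
loss_original, loss_final, data_record = train(
    model,                        # any model exposing get_loss / get_weights_bias / parameters
    X, y,
    validation_data=None,         # validation falls back to (X, y) when None
    criterion=nn.MSELoss(),
    inspect_interval=50,
    epochs=2000,
    lr=1e-3,
    patience=20,
    record_keys=["loss", "param"],
    reg_dict={"weight": 1e-4},    # per-source L1 coefficients (source name assumed)
    isplot=True,
)
print(loss_original, loss_final)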