def check_train_load_DA(config,
                        config_kwargs,
                        small_debug=True,
                        all_data=False,
                        activation=None,
                        params=None):
    expdir = EXPDIR
    try:
        if not config_kwargs:
            config_kwargs = {}
        assert isinstance(config_kwargs, dict)
        if params is None:
            # avoid sharing a mutable default dict between calls
            params = {"var": VAR, "tol": TOL}

        settings = config(**config_kwargs)
        settings.DEBUG = False
        if activation:
            settings.ACTIVATION = activation

        calc_DA_MAE = RUN_DA_IN_TRAINING
        num_epochs_cv = 0
        print_every = 1
        test_every = 1
        lr = 0.0002

        print(settings.__class__.__name__)
        if config_kwargs:
            print(list(config_kwargs.items()))
        trainer = TrainAE(settings, expdir, calc_DA_MAE)
        expdir = trainer.expdir  #get full path

        model = trainer.train(EPOCHS,
                              learning_rate=lr,
                              test_every=test_every,
                              num_epochs_cv=num_epochs_cv,
                              print_every=print_every,
                              small_debug=small_debug)

        if PRINT_MODEL:
            print(model.layers_encode)
        #test loading
        model, settings = ML_utils.load_model_and_settings_from_dir(expdir)

        model.to(ML_utils.get_device())  #TODO

        x_fp = settings.get_X_fp(True)  #force init X_FP

        res_AE = run_DA_batch(settings, model, all_data, expdir, params)

        print(res_AE.head(10))
        shutil.rmtree(expdir, ignore_errors=False, onerror=None)
    except Exception:
        # clean up the experiment directory, then re-raise the original error
        shutil.rmtree(expdir, ignore_errors=True)
        raise
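
# Hedged usage sketch (not from the source): CAEConfig is one of the settings
# classes used later in this file; VAR, TOL and EXPDIR are module globals.
if __name__ == "__main__":
    # Quick smoke test: train briefly, reload from disk, run a small DA batch.
    check_train_load_DA(CAEConfig, None, small_debug=True, activation="lrelu")
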
    def __init__(self,
                 AE_settings,
                 expdir,
                 batch_sz=BATCH,
                 model=None,
                 start_epoch=None):
        """Initilaizes the AE training class.

        ::AE_settings - a settings.config.Config class with the DA settings
        ::expdir - a directory of form `experiments/<possible_path>` to keep logs
        ::calc_DA_MAE - boolean. If True, training will evaluate DA Mean Absolute Error
            during the training cycle. Note: this is *MUCH* slower
        """

        self.settings = AE_settings

        err_msg = """AE_settings must be an AE configuration class"""
        assert self.settings.COMPRESSION_METHOD == "AE", err_msg

        if model is not None:  #for retraining
            assert start_epoch is not None, "If you are RE-training model you must pass start_epoch"
            assert start_epoch >= 0
            self.start_epoch = start_epoch
            self.model = model
            print("Loaded model, ", end="")
        else:
            self.start_epoch = 0
            self.model = ML_utils.load_model_from_settings(AE_settings)
            print("Initialized model, ", end="")

        print("Number of parameters:",
              sum(p.numel() for p in self.model.parameters()))

        self.batch_sz = batch_sz
        self.settings.batch_sz = batch_sz

        self.expdir = init_expdir(expdir)
        self.settings_fp = self.expdir + "settings.txt"

        if self.settings.SAVE:
            with open(self.settings_fp, "wb") as f:
                pickle.dump(self.settings, f)

        ML_utils.set_seeds()  #set seeds for reproducibility

        self.device = ML_utils.get_device()
        self.columns = [
            "epoch", "reconstruction_err", "DA_MAE", "DA_ratio_improve_MAE",
            "time_DA(s)", "time_epoch(s)"
        ]
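
    # Hedged sketch (illustration, not source code) of the two construction
    # paths __init__ supports, where `settings` is an AE Config instance:
    #
    #   trainer = TrainAE(settings, "experiments/demo")          # fresh model
    #   trainer = TrainAE(settings, "experiments/demo",
    #                     model=old_model, start_epoch=50)       # retraining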
Example 3
def act_constr(activation_fn):
    if activation_fn == "relu":
        activation_constructor = lambda x, y: nn.ReLU()
    elif activation_fn == "lrelu":
        activation_constructor = lambda x, y: nn.LeakyReLU(0.05)
    elif activation_fn == "GDN":
        activation_constructor = lambda x, y: GDN(x, ML_utils.get_device(), y)
    elif callable(activation_fn):
        activation_constructor = lambda x, y: activation_fn
    elif activation_fn == "prelu":  # must be initialized in situ
        activation_constructor = lambda x, y: nn.PReLU(x)
    else:
        raise NotImplementedError("Activation function not implemented")
    return activation_constructor
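
# Usage sketch (an assumed illustration): the returned constructor always
# takes two arguments; only the GDN and PReLU branches actually use them
# (PReLU treats the first as the channel count).
lrelu = act_constr("lrelu")(64, False)  # nn.LeakyReLU(0.05); args ignored
prelu = act_constr("prelu")(64, False)  # nn.PReLU(64): one weight per channel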
    def test_CAE_forward_nobatch(self):
        settings = CAEConfig()
        Cin = settings.get_channels()[0]
        size = (Cin,) + settings.get_n()
        device = ML.get_device()
        x = torch.rand(size, requires_grad=True, device=device)

        model = CAE_3D(**settings.get_kwargs())
        model.to(device)
        try:
            y = model(x)
        except Exception:
            pytest.fail("Unable to do forward pass")
    def test_CAE_linear_latent_nonbatched(self):
        settings = CAEConfig()
        Cin = settings.get_channels()[0]
        size = (Cin,) + settings.get_n()
        device = ML.get_device()
        x = torch.rand(size, requires_grad=True, device=device)

        model = CAE_3D(**settings.get_kwargs())
        model.to(device)
        encode = model.encode
        try:
            w = encode(x)
        except Exception:
            pytest.fail("Unable to do forward pass")

        assert len(w.shape) == 1, "There should only be one dimension"
        assert w.shape[0] == settings.get_number_modes()
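
    # Hedged sketch of a batched counterpart to the two tests above (an
    # assumption, not part of the original suite): prepend a batch dimension
    # and expect (batch, num_modes) from the encoder.
    def test_CAE_encode_batched_sketch(self):
        settings = CAEConfig()
        size = (4, settings.get_channels()[0]) + settings.get_n()
        model = CAE_3D(**settings.get_kwargs())
        w = model.encode(torch.rand(size))
        assert w.shape == (4, settings.get_number_modes())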
    def training_loop_AE(self,
                         device=None,
                         print_every=2,
                         test_every=5,
                         save_every=5,
                         model_dir=None):
        """Runs a torch AE model training loop.
        NOTE: ensure that the loss_fn uses reduction="sum"
        """
        model = self.model
        self.model_dir = model_dir

        if device is None:
            device = ML_utils.get_device()
        self.device = device

        ML_utils.set_seeds()
        train_losses = []
        test_losses = []

        self.start = self.num_epochs_cv + self.start_epoch
        self.end = self.start_epoch + self.num_epoch
        epoch = self.end - 1  #for case where no training occurs

        for epoch in range(self.start, self.end):

            self.epoch = epoch

            train_loss, test_loss = self.train_one_epoch(
                epoch, print_every, test_every)
            train_losses.append(train_loss)
            if test_loss:
                test_losses.append(test_loss)

        if epoch % save_every != 0 and self.model_dir is not None:
            #Save model (if new model hasn't just been saved)
            model_fp_new = "{}{}.pth".format(self.model_dir, epoch)
            torch.save(model.state_dict(), model_fp_new)

        return train_losses, test_losses
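
    # Sketch of reloading a checkpoint written by the loop above; the file
    # name mirrors model_fp_new ("{model_dir}{epoch}.pth"). CAE_3D/CAEConfig
    # are assumptions borrowed from the tests elsewhere in this file.
    @staticmethod
    def load_checkpoint_sketch(model_dir, epoch):
        model = CAE_3D(**CAEConfig().get_kwargs())
        state = torch.load("{}{}.pth".format(model_dir, epoch))
        model.load_state_dict(state)
        model.eval()  # inference mode: freezes dropout/batch-norm behaviour
        return model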
Example 7
    def run(self):
        """Generates matrices for VarDA. All returned matrices are in the
        (M X n) or (M x nx x ny x nz) format """

        loader = self.settings.get_loader()
        splitter = SplitData()
        settings = self.settings

        X = loader.get_X(settings)

        train_X, test_X, u_c_std, X, mean, std = splitter.train_test_DA_split_maybe_normalize(
            X, settings)

        if self.u_c is None:
            self.u_c = u_c_std

        #self.u_c = train_X[62] #good
        #self.u_c = train_X[-1] #bad

        # We will take initial condition u_0, as mean of historical data
        if settings.NORMALIZE:
            u_0 = np.zeros_like(mean)  #since the data is mean centred
        else:
            u_0 = mean

        encoder = None
        decoder = None

        device = ML_utils.get_device()
        model = self.AEmodel
        if model:
            model.to(device)

        if self.settings.COMPRESSION_METHOD == "AE":
            #get encoder
            if model is None:
                model = ML_utils.load_model_from_settings(settings)

            def __create_encoderOrDecoder(fn):
                """This returns a function that deals with encoder/decoder
                input dimensions (e.g. adds channel dim for 3D case)"""
                def ret_fn(vec):
                    vec = torch.Tensor(vec).to(device)

                    #for 3D case, unsqueeze for channel
                    if self.settings.THREE_DIM:
                        dims = len(vec.shape)
                        if dims == 3:
                            vec = vec.unsqueeze(0)
                        elif dims == 4:
                            #batched input
                            vec = vec.unsqueeze(1)
                    with torch.no_grad():
                        res = fn(vec).detach().cpu()
                    #for 3D case, squeeze for channel
                    dims = len(res.shape)
                    if self.settings.THREE_DIM and dims > 2:
                        if dims == 4:
                            res = res.squeeze(0)
                        elif dims == 5:  #batched input
                            res = res.squeeze(1)
                    return res.numpy()

                return ret_fn

            encoder = __create_encoderOrDecoder(model.encode)
            decoder = __create_encoderOrDecoder(model.decode)

        H_0, obs_idx = None, None

        if self.settings.REDUCED_SPACE:
            if self.settings.COMPRESSION_METHOD == "SVD":
                raise NotImplementedError(
                    "SVD in reduced space not implemented")

            self.settings.OBS_MODE = "all"

            observations, H_0, w_0, d = self.__get_obs_and_d_reduced_space(
                self.settings, self.u_c, u_0, encoder)

        else:
            observations, w_0, d, obs_idx = self.__get_obs_and_d_not_reduced(
                self.settings, self.u_c, u_0, encoder)

        #TODO - **maybe** get rid of this monstrosity...:
        #i.e. you could return a class that has these attributes:

        data = {
            "d": d,
            "G": H_0,
            "observations": observations,
            "model": model,
            "obs_idx": obs_idx,
            "encoder": encoder,
            "decoder": decoder,
            "u_c": self.u_c,
            "u_0": u_0,
            "X": X,
            "train_X": train_X,
            "test_X": test_X,
            "std": std,
            "mean": mean,
            "device": device
        }

        if w_0 is not None:
            data["w_0"] = w_0

        return data
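
    # Hedged sketch (illustration only) of consuming the dict returned by
    # run(): round-trip the background state u_0 through the autoencoder.
    def roundtrip_u_0_sketch(self):
        data = self.run()
        if data["encoder"] is None:  # e.g. COMPRESSION_METHOD == "SVD"
            return None
        w = data["encoder"](data["u_0"])  # latent representation
        return data["decoder"](w)         # reconstruction in full space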
Example 8
    def run(self, print_every=10, print_small=True):

        shuffle = self.settings.SHUFFLE_DATA  #save value
        self.settings.SHUFFLE_DATA = False

        if self.settings.COMPRESSION_METHOD == "SVD":
            if self.settings.REDUCED_SPACE:
                raise NotImplementedError("Cannot have reduced space SVD")

            fp_base = self.settings.get_X_fp().split("/")[-1][1:]

            U = np.load(self.settings.INTERMEDIATE_FP + "U" + fp_base)
            s = np.load(self.settings.INTERMEDIATE_FP + "s" + fp_base)
            W = np.load(self.settings.INTERMEDIATE_FP + "W" + fp_base)

            num_modes = self.settings.get_number_modes()

            V_trunc = SVD.SVD_V_trunc(U, s, W, modes=num_modes)
            V_trunc_plus = SVD.SVD_V_trunc_plus(U, s, W, modes=num_modes)

            self.DA_pipeline = DAPipeline(self.settings)
            DA_data = self.DA_pipeline.data
            DA_data["V_trunc"] = V_trunc
            DA_data["V"] = None
            DA_data["w_0"] = V_trunc_plus @ DA_data.get("u_0").flatten()
            DA_data["V_grad"] = None

        elif self.settings.COMPRESSION_METHOD == "AE":
            if self.model is None:
                raise ValueError(
                    "Must provide an AE torch.nn model if settings.COMPRESSION_METHOD == 'AE'"
                )

            self.DA_pipeline = DAPipeline(self.settings, self.model)
            DA_data = self.DA_pipeline.data

            if self.reconstruction:
                encoder = DA_data.get("encoder")
                decoder = DA_data.get("decoder")

        else:
            raise ValueError(
                "settings.COMPRESSION_METHOD must be in ['AE', 'SVD']")

        self.settings.SHUFFLE_DATA = shuffle

        if self.reconstruction:
            L1 = torch.nn.L1Loss(reduction="sum")
            L2 = torch.nn.MSELoss(reduction="sum")

        totals = {
            "percent_improvement": 0,
            "ref_MAE_mean": 0,
            "da_MAE_mean": 0,
            "mse_DA": 0,
            "mse_ref": 0,
            "counts": 0,
            "l1_loss": 0,
            "l2_loss": 0,
            "time": 0,
            "time_online": 0
        }

        tot_DA_MAE = np.zeros_like(self.control_states[0]).flatten()
        tot_ref_MAE = np.zeros_like(self.control_states[0]).flatten()
        results = []

        if len(self.control_states.shape) in [1, 3]:
            raise ValueError("control_states must be batched along the first dimension")
        num_states = self.control_states.shape[0]

        for idx in range(num_states):
            u_c = self.control_states[idx]
            if self.settings.REDUCED_SPACE:
                self.DA_pipeline.data = VDAInit.provide_u_c_update_data_reduced_AE(
                    DA_data, self.settings, u_c)
            else:
                self.DA_pipeline.data = VDAInit.provide_u_c_update_data_full_space(
                    DA_data, self.settings, u_c)
            t1 = time.time()
            if self.settings.COMPRESSION_METHOD == "AE":
                DA_results = self.DA_pipeline.DA_AE(save_vtu=self.save_vtu)
            elif self.settings.COMPRESSION_METHOD == "SVD":
                DA_results = self.DA_pipeline.DA_SVD(save_vtu=self.save_vtu)
            t2 = time.time()
            t_tot = t2 - t1
            #print("time_online {:.4f}s".format(DA_results["time_online"]))

            if self.reconstruction:
                data_tensor = torch.Tensor(u_c)
                if self.settings.COMPRESSION_METHOD == "AE":
                    device = ML_utils.get_device()
                    #device = ML_utils.get_device(True, 1)

                    data_tensor = data_tensor.to(device)

                    data_hat = decoder(encoder(u_c))
                    data_hat = torch.Tensor(data_hat)
                    data_hat = data_hat.to(device)

                elif self.settings.COMPRESSION_METHOD == "SVD":

                    data_hat = SVD.SVD_reconstruction_trunc(
                        u_c, U, s, W, num_modes)

                    data_hat = torch.Tensor(data_hat)
                with torch.no_grad():
                    l1 = L1(data_hat, data_tensor)
                    l2 = L2(data_hat, data_tensor)
            else:
                l1, l2 = None, None

            result = {}
            result["percent_improvement"] = DA_results["percent_improvement"]
            result["ref_MAE_mean"] = DA_results["ref_MAE_mean"]
            result["da_MAE_mean"] = DA_results["da_MAE_mean"]
            result["counts"] = DA_results["counts"]
            result["mse_ref"] = DA_results["mse_ref"]
            result["mse_DA"] = DA_results["mse_DA"]
            if self.reconstruction:
                result["l1_loss"] = l1.detach().cpu().numpy()
                result["l2_loss"] = l2.detach().cpu().numpy()
            result["time"] = t2 - t1
            result["time_online"] = DA_results["time_online"]
            if self.save_vtu:
                tot_DA_MAE += DA_results.get("da_MAE")
                tot_ref_MAE += DA_results.get("ref_MAE")
            #add to results list (that will become a .csv)
            results.append(result)

            #add to aggregated dict results
            totals = self.__add_result_to_totals(result, totals)

            if idx % print_every == 0 and idx > 0:
                if not print_small:
                    print("idx:", idx)
                self.__print_totals(totals, idx + 1, print_small)
        if not print_small:
            print("------------")
        self.__print_totals(totals, num_states, print_small)
        if not print_small:
            print("------------")

        results_df = pd.DataFrame(results)
        if self.save_vtu:
            tot_DA_MAE /= num_states
            tot_ref_MAE /= num_states
            out_fp_ref = self.save_vtu_fp + "av_ref_MAE.vtu"
            out_fp_DA = self.save_vtu_fp + "av_da_MAE.vtu"
            fluidity.utils.save_vtu(self.settings, out_fp_ref, tot_ref_MAE)
            fluidity.utils.save_vtu(self.settings, out_fp_DA, tot_DA_MAE)

        #save to csv
        if self.csv_fp:
            results_df.to_csv(self.csv_fp)

        if self.plot:
            raise NotImplementedError(
                "plotting functionality not implemented yet")
        return results_df
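
    # Hedged usage sketch (assumes this run() lives on a batch-DA wrapper
    # class, here called BatchDA, with the attributes read above: settings,
    # model, control_states, reconstruction, csv_fp, save_vtu, plot):
    #
    #   batcher = BatchDA(settings, control_states, model=model,
    #                     reconstruction=True, csv_fp="results.csv")
    #   df = batcher.run(print_every=20, print_small=True)
    #   print(df[["percent_improvement", "mse_DA", "l1_loss"]].describe())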