Example #1
        optimizer=optimizer,
        n_epochs=ARGS.max_train_epochs,
        val_interval=ARGS.epochs_per_eval,
        patience_early_stopping=ARGS.early_stopping_patience_epochs,
        device=ARGS.device,
        metrics=[],
        val_metric="loss",
        val_metric_mode="min",
    )

    #path = os.path.join(ARGS.model_dir, "{}_model.pk".format(ARGS.model))
    path = os.path.join(ARGS.model_dir,
                        "{}_model.h5".format(ARGS.model))  # save the DL model in HDF5 format

    if ARGS.model == "plain_ae":
        log.error("Saving a plain_ae model is not supported at the moment.")
    elif ARGS.model == "conv_ae":
        encoder = model.encoder
        decoder = model.decoder
        save_model(encoder, encoderOpts, decoder, decoderOpts, dataOpts, path)
    else:
        log.error(
            "The model type you would like to save is not supported at the moment. Please implement it."
        )

    log.close()
    """Leftover from previous trials. Could be removed when finalizing the script"""

    #save_model(encoder, encoderOpts)

    #train_loss = []
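
The `save_model` helper used above is not shown in this excerpt. A minimal sketch of what it could look like, assuming the encoder and decoder are PyTorch modules and the opts arguments are plain dictionaries (the checkpoint keys below are illustrative, not the project's actual format):

import torch

def save_model(encoder, encoderOpts, decoder, decoderOpts, dataOpts, path):
    # Bundle state dicts and option dicts into a single checkpoint so the
    # autoencoder can be rebuilt later without the training script.
    torch.save(
        {
            "encoderState": encoder.state_dict(),  # illustrative key names
            "encoderOpts": encoderOpts,
            "decoderState": decoder.state_dict(),
            "decoderOpts": decoderOpts,
            "dataOpts": dataOpts,
        },
        path,
    )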
Example #2
class CsvSplit(object):
    def __init__(
        self,
        split_fracs: Dict[str, float],
        working_dir: str = None,
        seed: int = None,
        split_per_dir: bool = False,
    ):
        if not np.isclose(np.sum([p for _, p in split_fracs.items()]), 1.):
            raise ValueError("Split probabilities have to sum up to 1.")
        self.split_fracs = split_fracs
        self.working_dir = working_dir
        self.seed = seed
        self.split_per_dir = split_per_dir
        self.splits = defaultdict(list)
        self._logger = Logger("CSVSPLIT")

    """
    Return split for given partition. If there is already an existing CSV split return this split if it is valid or
    in case there exist not a split yet generate a new CSV split
    """

    def load(self, split: str, files: List[Any] = None):

        if split not in self.split_fracs:
            raise ValueError(
                "Provided split '{}' is not in `self.split_fracs`.".format(
                    split))

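        # Reuse a cached split if this partition was already materialized.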
        if self.splits[split]:
            return self.splits[split]
        if self.working_dir is None:
            self.splits = self._split_with_seed(files)
            return self.splits[split]
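        # Otherwise, try to restore an existing split from CSV files on disk.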
        if self.can_load_from_csv():
            if not self.split_per_dir:
                csv_split_files = {
                    split_: (os.path.join(self.working_dir, split_ + ".csv"), )
                    for split_ in self.split_fracs.keys()
                }
            else:
                csv_split_files = {}
                for split_ in self.split_fracs.keys():
                    split_file = os.path.join(self.working_dir, split_)
                    csv_split_files[split_] = []
                    with open(split_file, "r") as f:
                        for line in f.readlines():
                            csv_split_files[split_].append(line.strip())

            for split_ in self.split_fracs.keys():
                for csv_file in csv_split_files[split_]:
                    if not csv_file or csv_file.startswith(r"#"):
                        continue
                    csv_file_path = os.path.join(self.working_dir, csv_file)
                    with open(csv_file_path, "r") as f:
                        reader = csv.reader(f)
                        for item in reader:
                            file_ = os.path.basename(item[0])
                            file_ = os.path.join(os.path.dirname(csv_file),
                                                 file_)
                            self.splits[split_].append(file_)
            return self.splits[split]

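        # No valid CSV split exists on disk: generate one and persist it.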
        if not self.split_per_dir:
            working_dirs = (self.working_dir, )
        else:
            f_d_map = self._get_f_d_map(files)
            working_dirs = [
                os.path.join(self.working_dir, p) for p in f_d_map.keys()
            ]
        for working_dir in working_dirs:
            splits = self._split_with_seed(
                files if not self.split_per_dir else f_d_map[working_dir])
            for split_ in splits.keys():
                csv_file = os.path.join(working_dir, split_ + ".csv")
                self._logger.debug("Generating {}".format(csv_file))
                if self.split_per_dir:
                    with open(os.path.join(self.working_dir, split_),
                              "a") as f:
                        p = pathlib.Path(csv_file).relative_to(
                            self.working_dir)
                        f.write(str(p) + "\n")
                if len(splits[split_]) == 0:
                    raise ValueError(
                        "Error splitting dataset. Split '{}' has 0 entries".
                        format(split_))
                with open(csv_file, "w", newline="") as fh:
                    writer = csv.writer(fh)
                    for item in splits[split_]:
                        writer.writerow([item])
                self.splits[split_].extend(splits[split_])
        return self.splits[split]

    """
    Check whether it is possible to correctly load information from existing csv files
    """

    def can_load_from_csv(self):
        if not self.working_dir:
            return False
        if self.split_per_dir:
            for split in self.split_fracs.keys():
                split_file = os.path.join(self.working_dir, split)
                if not os.path.isfile(split_file):
                    return False
                self._logger.debug(
                    "Found dataset split file {}".format(split_file))
                with open(split_file, "r") as f:
                    for line in f.readlines():
                        csv_file = line.strip()
                        if not csv_file or csv_file.startswith(r"#"):
                            continue
                        if not os.path.isfile(
                                os.path.join(self.working_dir, csv_file)):
                            self._logger.error(
                                "File not found: {}".format(csv_file))
                            raise ValueError(
                                "Split file found, but csv files are missing. "
                                "Aborting...")
        else:
            for split in self.split_fracs.keys():
                csv_file = os.path.join(self.working_dir, split + ".csv")
                if not os.path.isfile(csv_file):
                    return False
                self._logger.debug("Found csv file {}".format(csv_file))
        return True

    """
    Create a mapping from directory to containing files.
    """

    def _get_f_d_map(self, files: List[Any]):

        f_d_map = defaultdict(list)
        if self.working_dir is not None:
            for f in files:
                f_d_map[str(pathlib.Path(
                    self.working_dir).joinpath(f).parent)].append(f)
        else:
            for f in files:
                f_d_map[str(
                    pathlib.Path(".").resolve().joinpath(f).parent)].append(f)
        return f_d_map

    """
    Randomly splits the dataset using given seed
    """

    def _split_with_seed(self, files: List[Any]):
        if not files:
            raise ValueError("Provided list `files` is `None` or empty.")
        if self.seed:
            random.seed(self.seed)
        return self.split_fn(files)

    """
    A generator function that returns all values for the given `split`.
    """

    def split_fn(self, files: List[Any]):
        shuffled = random.sample(files, len(files))
        # `np.split` expects cumulative boundary indices, not per-split sizes.
        boundaries = np.cumsum(
            [int(p * len(files)) for _, p in self.split_fracs.items()])[:-1]
        _splits = np.split(ary=shuffled, indices_or_sections=boundaries)
        splits = dict()
        # Iterate the fractions dict so keys line up with the generated parts.
        for i, key in enumerate(self.split_fracs.keys()):
            splits[key] = list(_splits[i])
        return splits
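
A minimal usage sketch for `CsvSplit` (the file list and working directory below are hypothetical):

all_files = ["calls/a.wav", "calls/b.wav", "noise/c.wav", "noise/d.wav"]  # hypothetical
split = CsvSplit(
    split_fracs={"train": 0.5, "val": 0.25, "test": 0.25},
    working_dir="data/splits",  # hypothetical; the split CSVs are written here
    seed=42,
)
train_files = split.load("train", files=all_files)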
                    format(file_names[i], pred_gm[i]))
                # `DataFrame.append` was removed in pandas 2.x; assign a new
                # row by positional index instead (df has a RangeIndex here).
                df.loc[len(df)] = [file_names[i], pred_gm[i]]

            if ARGS.clustering_dir is not None:
                if not Path(ARGS.clustering_dir).exists():
                    os.mkdir(ARGS.clustering_dir)

                df.to_csv(os.path.join(ARGS.clustering_dir, "gmm_clusters.csv"))
                log.info("gmm_clusters.csv saved under directory {}".format(
                    ARGS.clustering_dir))

        else:
            log.error(
                "Please choose a clustering algorithm: kmeans or gmm (case-sensitive)."
            )

    #summary_dir = ARGS.clustering_dir
    #if summary_dir is not None:
    #df.to_csv(summary_dir + "/Kmeans_clusters")

    #print("km.cluster_centers_ length :", len(km.cluster_centers_))
"""
    if ARGS.decod_dir is not None:
        bottleneck_output = 0
        with torch.no_grad():
            for i in range(len(bottleneck_outputs)):
                bottleneck_output = torch.tensor(bottleneck_outputs[i]).to(ARGS.device)
                #print("Krupal : bottleneck_output shape :", bottleneck_output.shape)
                bottleneck_output = torch.reshape(bottleneck_output, (-1, 4, 4, 8)) #512