Example #1
    def resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        tools.print_log("Loading checkpoint: {} ...".format(resume_path),
                        file=self.log_file)
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint["epoch"] + 1
        self.config = checkpoint[
            "config"] if "config" in checkpoint else self.config
        # load accuracy of checkpoint
        self.best_acc = checkpoint["best_acc"]
        self.best_index_sum = checkpoint["best_index_sum"]

        # load architecture params from checkpoint.
        self.model, self.optimizer = self._get_model_opt()
        self.model.load_state_dict(checkpoint["model"])

        # load optimizer state from the checkpoint; fall back to the main
        # optimizer when no separate source optimizer was saved.
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if "source_optimizer" in checkpoint:
            self.source_optimizer.load_state_dict(
                checkpoint["source_optimizer"])
        else:
            self.source_optimizer = self.optimizer

        tools.print_log(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch),
            file=self.log_file)
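
For context, a resume run normally goes through train() (Example #6), which calls resume_checkpoint when a path is supplied. A minimal sketch, assuming placeholder paths for the data root and the checkpoint file (only the config path appears elsewhere in these examples):

# sketch only: "data/" and the checkpoint path below are placeholders
config = tools.load_config("../configs/sds_jc.yaml")
conf = ModelConfiguration(**config)
dataset = AncientDataset(conf=conf, root_dir="data/")

trainer = PairedDecoderTrainer(conf, dataset)
# train() calls resume_checkpoint() internally when resume_path is given (see Example #6)
trainer.train(resume_path="saved/checkpoint-epoch10.pth")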
Example #2
def experiment(config_file, root_dir):
    """
    Normal experiment process

    Args:
        config_file: the path to configuration file
        root_dir: the root directory of dataset

    """
    # set configuration
    config = tools.load_config(config_file)
    conf = ModelConfiguration(**config)
    shutil.copy(config_file, os.path.join(conf.log_path, "config.yaml"))

    # set dataset
    dataset = AncientDataset(conf=conf, root_dir=root_dir)
    log_file = open(conf.log_file, "a+", encoding="utf-8")
    tools.print_log("Split Dataset", file=log_file)
    dataset.split_dataset(batch_size=32)

    # set trainer
    tools.print_log("Start Training", file=log_file)
    is_single = conf.strategy == "single"
    trainer = SingleDecoderTrainer(conf, dataset) if is_single else PairedDecoderTrainer(conf, dataset)
    trainer.train()
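
A minimal driver for experiment() could look like the sketch below; the argparse flags and default paths are assumptions rather than options taken from the repository.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a single experiment")
    parser.add_argument("--config", "-c", default="../configs/sds_jc.yaml",
                        help="path to the YAML configuration file")
    parser.add_argument("--root", "-r", default="data/",
                        help="dataset root directory")
    args = parser.parse_args()
    experiment(args.config, args.root)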
Example #3
    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to "model_best.pth"
        """
        state = {
            "best_acc": self.best_acc,
            "best_index_sum": self.best_index_sum,
            "config": self.config,
            "epoch": epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        filename = os.path.join(self.config.saved_path,
                                "checkpoint-epoch{}.pth".format(epoch))
        torch.save(state, filename)
        tools.print_log("Saving checkpoint: {} ...".format(filename),
                        file=self.log_file)
        if save_best:
            best_path = os.path.join(self.config.saved_path, "model_best.pth")
            torch.save(state, best_path)
            tools.print_log("Saving current best: model_best.pth ...",
                            file=self.log_file)
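
The state dictionary written here mirrors what resume_checkpoint in Example #1 reads back. A standalone sketch of that round trip with plain PyTorch (the model is a placeholder and the "config" entry is left out):

import torch
from torch import nn, optim

model = nn.Linear(10, 2)                      # placeholder model
optimizer = optim.Adam(model.parameters())

# write the same keys that _save_checkpoint writes (minus the config entry)
state = {"epoch": 5, "best_acc": 0.9, "best_index_sum": 3,
         "model": model.state_dict(), "optimizer": optimizer.state_dict()}
torch.save(state, "checkpoint-epoch5.pth")

# read the keys back the way resume_checkpoint does
checkpoint = torch.load("checkpoint-epoch5.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1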
Example #4
    def _get_source_data(self):
        # the full source set is used for prediction
        self.source_data_full, self.source_labels_full, self.source_paths = get_dataset_by_path(
            os.path.join(self.ancient_ori_dir, self.source_name),
            self.transform, self.char_list)
        # the full expansion set is also used for prediction
        if self.exp_chars:
            self.source_data_exp, self.source_labels_exp, self.source_paths_exp = get_dataset_by_path(
                self.source_dir, self.transform, self.exp_chars)
        else:
            self.source_data_exp, self.source_labels_exp, self.source_paths_exp = None, None, None
            tools.print_log("Failed to load expansion data!!!")
Example #5
    def __init__(self, config: ModelConfiguration, dataset: AncientDataset):
        self.config = config
        # setup GPU device if available, move models into configured device
        self.device = self.config.device
        self.log_file = open(self.config.log_file, "a+", encoding="utf-8")
        tools.print_log("Using %s!!!" % self.device, file=self.log_file)

        # init models
        self.model, self.optimizer = self._get_model_opt()
        self.core, self.criterion = self.config.core, self.config.criterion
        self.start_epoch, self.best_acc, self.best_index_sum = 1, 0.0, 0
        # init dataset
        self.dataset = dataset
        self.target_data, self.source_data, self.labels = dataset.target_data, dataset.source_data, dataset.labels
        self.prediction_result = []
        self.add_cons = self.config.loss_type == "A"  # whether to add the MMD loss term
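
The configuration's device field decides where the model and batches live. How ModelConfiguration fills it is not shown in these examples; a common way to pick it, given only as an illustration, is:

import torch

# illustrative device selection; the project's ModelConfiguration may decide this differently
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")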
Example #6
    def train(self, resume_path=None):
        """
        Full training logic
        """
        if resume_path is not None:
            self.resume_checkpoint(resume_path)
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.config.epochs + 1):
            log = {"epoch": "%s/%s" % (epoch, self.config.epochs)}
            self.model.train()
            # save logged information into log dict
            log.update(self._train_epoch())
            best = False
            # self.model.eval()
            self.model.train(True)
            val_result = self.predict()
            log.update(val_result)

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            cur_acc, cur_index_sum = val_result["val_accuracy"], val_result[
                "val_index_sum"]
            # check whether models performance improved or not
            improved = (cur_acc > self.best_acc) or \
                       (cur_acc == self.best_acc and cur_index_sum < self.best_index_sum)

            if improved:
                not_improved_count, self.best_acc, self.best_index_sum, best = 0, cur_acc, cur_index_sum, True
                self._save_checkpoint(epoch, save_best=best)
            else:
                not_improved_count += 1

            # print logged information to the screen
            for key, value in log.items():
                tools.print_log("{:30s}: {}".format(str(key), value),
                                file=self.log_file)

            if not_improved_count > self.config.early_stop:
                tools.print_log(
                    "Validation performance did not improve for %s epochs.So Stop"
                    % self.config.early_stop,
                    file=self.log_file)
                break

            if epoch % self.config.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
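
Stripped of the trainer state, the improvement test and early-stopping bookkeeping above reduce to the pattern below; the function and parameter names are illustrative, and validate/save_checkpoint stand in for the real validation and checkpoint steps.

def train_with_early_stopping(validate, save_checkpoint, epochs=100, patience=10):
    # illustrative refactoring of the loop above; validate and save_checkpoint are caller-supplied callables
    best_acc, best_index_sum, not_improved = 0.0, 0, 0
    for epoch in range(1, epochs + 1):
        cur_acc, cur_index_sum = validate(epoch)
        improved = (cur_acc > best_acc) or (cur_acc == best_acc and cur_index_sum < best_index_sum)
        if improved:
            best_acc, best_index_sum, not_improved = cur_acc, cur_index_sum, 0
            save_checkpoint(epoch)
        else:
            not_improved += 1
        if not_improved > patience:
            break
    return best_acc, best_index_sum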
Example #7
def get_cluster_output_df(input_df=None, add_center=True):
    """
    get cluster result and store the result
    Args:
        input_df: A pandas data frame with columns “label", "type", "feature"
        add_center: Whether add center point to final results

    Returns: A pandas data frame with columns “label", "type", "feature", "center", "size"

    """
    if input_df is None:
        tools.print_log("No input")
        return
    columns = input_df.columns
    output_df = pd.DataFrame(columns=columns)
    for (label, char_type), group_df in input_df.groupby(["label", "type"]):
        # DataFrame.append was removed in pandas 2.x, so concatenate instead
        output_df = pd.concat(
            [output_df, run_cluster(label, char_type, group_df, add_center=add_center)],
            ignore_index=True, sort=False)
    tools.print_log("shape of output after cluster: %s" % str(output_df.shape))
    return output_df
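
A tiny synthetic call shows the expected input: one row per sample with a label, a type, and a feature vector. The random vectors are placeholders, and run_cluster still has to come from the project code.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
input_df = pd.DataFrame({
    "label":   ["a", "a", "b", "b"],
    "type":    ["source", "source", "target", "target"],
    "feature": [rng.normal(size=128) for _ in range(4)],
})
clustered = get_cluster_output_df(input_df, add_center=True)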
Example #8
def get_reduction_result(input_df=None):
    """
        get reduction result from input DataFrame object

    Args:
        input_df: A pandas data frame with columns “label", "type", "feature"

    Returns:  pandas data frame with columns “label", "type", "feature" and "feature" column is 2-D.

    """
    if input_df is None:
        tools.print_log("No input")
        return
    columns = input_df.columns
    output_df = pd.DataFrame(columns=columns)
    # take the feature columns
    feature = input_df["feature"].values
    feature = np.stack(feature)
    fea_dim = feature[0].shape[0]
    if fea_dim > 512:
        # use PCA to reduce the dimensions
        pca = PCA(n_components=256)
        feature = pca.fit_transform(feature)
        tools.print_log("Variance of pca: %s" % str(np.sum(pca.explained_variance_ratio_)))
    # reduce dimension to 2-D
    feature_reduction = TSNE(n_components=2, n_iter=1000, random_state=42).fit_transform(feature)
    input_df = input_df.drop(columns=["feature"])
    input_df["feature"] = list(feature_reduction)
    feature_reduction = np.array(feature_reduction)
    tools.print_log("Shape of features after reduce dimension: %s" % str(feature_reduction.shape))
    # DataFrame.append was removed in pandas 2.x, so concatenate instead
    output_df = pd.concat([output_df, input_df], ignore_index=True)
    return output_df
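
The reduction step itself (PCA down to 256 dimensions when the features are wide, then t-SNE to 2-D) can be exercised on its own. A self-contained sketch with synthetic features; the sample count and feature width are arbitrary.

import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

features = np.random.default_rng(42).normal(size=(300, 1024))   # synthetic high-dimensional features
if features.shape[1] > 512:
    features = PCA(n_components=256).fit_transform(features)    # same pre-reduction threshold as above
embedded = TSNE(n_components=2, random_state=42).fit_transform(features)
print(embedded.shape)                                           # (300, 2)
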
    parser.add_argument('--root',
                        '-r',
                        dest="root",
                        metavar='TEXT',
                        help='dataset root directory',
                        default=root)
    args = parser.parse_args()
    conf_paths = ["../configs/sds_jc.yaml"]
    for conf_path in conf_paths:
        # load configuration
        config = tools.load_config(conf_path)
        saved_path = os.path.join("_".join(config["paired_chars"]),
                                  config["core"] + "_" + config["level"])
        conf = ModelConfiguration(**config, saved_path=saved_path)

        tools.print_log("Task: " + "->".join(conf.paired_chars) + ", Core: " +
                        conf.core)

        # load dataset, split train and validation dataset
        dataset = AncientDataset(conf=conf, root_dir=args.root)
        test_path = "cross_dataset/chars_%s_test.csv" % ("_".join(
            conf.paired_chars))
        val_path = "cross_dataset/chars_%s_val.csv" % ("_".join(
            conf.paired_chars))
        if not os.path.exists(val_path):
            tools.make_dir("cross_dataset")
            create_cross_dataset(test_num, dataset.char_list, test_path,
                                 val_path)

        char_lists_combination = pd.read_csv(val_path)
        i = 0
        for val_chars, train_chars in zip(char_lists_combination["val"],