コード例 #1
0
ファイル: trainer.py プロジェクト: Trousersfield/ma
def train_all(data_dir: str, output_dir: str, debug: bool = False) -> None:
    """
    Train a model for every port known to the PortManager.

    Every port is trained with the same fixed configuration: 100 epochs, learning rate 4e-4 and no checkpoint
    to resume from.
    :param data_dir: Directory containing the training data
    :param output_dir: Directory receiving the training output
    :param debug: Forwarded unchanged to ``train``
    :raises ValueError: If the PortManager holds no ports
    """
    port_manager = PortManager()
    port_manager.load()
    if not port_manager.ports:
        raise ValueError("No port data available")

    for port in port_manager.ports.values():
        train(port_name=port.name,
              data_dir=data_dir,
              output_dir=output_dir,
              num_epochs=100,
              learning_rate=.0004,
              pm=port_manager,
              resume_checkpoint=None,
              debug=debug)
コード例 #2
0
def generate(input_dir: str, output_dir: str, data_source: str) -> None:
    """
    Generate the dataset from the raw files found in ``input_dir``.

    Only files whose name starts with ``aisdk_`` are processed (the extension is not checked), in sorted
    filename order.
    :param input_dir: Directory holding the raw input files
    :param output_dir: Directory that is initialized and receives the generated dataset
    :param data_source: Forwarded unchanged to ``generate_dataset``
    :raises ValueError: If the PortManager holds no ports
    """
    print(f"Generating dataset from directory '{input_dir}'")
    initialize(output_dir)

    port_manager = PortManager()
    port_manager.load()
    if not port_manager.ports:
        raise ValueError("No port data available")

    # sorting keeps the processing order independent of the file system's listing order
    raw_files = [name for name in sorted(os.listdir(input_dir)) if name.startswith("aisdk_")]
    for raw_file in raw_files:
        generate_dataset(os.path.join(input_dir, raw_file), output_dir, data_source, port_manager)
    print("Data generation complete!")
コード例 #3
0
def train(output_dir: str, port_name: str = None) -> None:
    """
    Train a single port, or every port listed in ``transfer-config.json`` when no port name is given.

    :param output_dir: Output directory; the Evaluator is loaded from its ``eval`` subdirectory
    :param port_name: Name of the port to train. If it is not a string, all ports listed under the "ports" key
        of ``transfer-config.json`` (located in ``script_dir``) are trained
    :raises ValueError: If no port data is available or a port name cannot be resolved
    """
    pm = PortManager()
    pm.load()
    if len(pm.ports.keys()) < 1:
        raise ValueError("No port data available")
    e = Evaluator.load(os.path.join(output_dir, "eval"))

    if isinstance(port_name, str):
        port = pm.find_port(port_name)
        if port is None:
            raise ValueError(
                f"Unable to associate port with port name '{port_name}'")
        train_port(port, e)
        return

    # train ports required for transfer
    config = read_json(os.path.join(script_dir, "transfer-config.json"))
    config_port_names = config["ports"]
    ports = [pm.find_port(name) for name in config_port_names]
    # name the configured entries that could not be resolved instead of printing a list containing None
    unresolved = [name for name, port in zip(config_port_names, ports) if port is None]
    if unresolved:
        raise ValueError(
            f"Unable to associate port with port name(s) {unresolved} from transfer config")
    # all ports are validated before any training starts
    for port in ports:
        train_port(port, e)
コード例 #4
0
ファイル: dataset.py プロジェクト: Trousersfield/ma
def _split_bounds(length: int) -> Tuple[int, int]:
    """
    Compute the boundaries of the train/validate/eval split (roughly 80/10/10) for ``length`` items.

    The training part is extended by one item when needed, so that the remainder divides evenly between the
    validation and evaluation parts.
    :param length: Number of items in the dataset
    :return: (end_train, end_validate), where train = [0, end_train), validate = [end_train, end_validate)
        and eval = [end_validate, length)
    """
    end_train = int(.8 * length)
    if (length - end_train) % 2 != 0 and end_train < length:
        end_train += 1
    # the remainder is even at this point, so integer division matches the former float computation exactly
    end_validate = length - (length - end_train) // 2
    return end_train, end_validate


def main(args) -> None:
    """
    Command line entry point for creating and inspecting RoutesDirectoryDatasets.

    Supported ``args.command`` values:
    - "generate": create a new base dataset config for ``args.port_name``, split it and print sample shapes
    - "test": load the latest base dataset config of ``args.port_name`` and print the item at ``args.data_idx``
    - "test_range": load the latest base dataset config and run ``test_dataset`` on each split
    - "change_data_dir": rewrite the data directory of the config ``args.config_file_name``
    :param args: Parsed command line arguments
    :raises ValueError: On an unknown command, a missing or unknown port name or a missing config file
    :raises LookupError: If the PortManager holds no ports (all commands except "generate")
    """
    if args.command == "generate":
        print("Generating Directory Dataset")
        pm = PortManager()
        pm.load()
        if len(pm.ports) == 0:
            raise ValueError("Port Manager has no ports. Is it initialized?")
        port = pm.find_port(args.port_name)
        # previously an unknown port surfaced as an AttributeError on port.name
        if port is None:
            raise ValueError(f"Unable to associate '{args.port_name}' with any port")
        data_dir = os.path.join(args.data_dir, "routes", port.name)
        batch_size = int(args.batch_size)
        start_datetime = datetime.now()
        start_time = as_str(start_datetime)
        dataset = RoutesDirectoryDataset(data_dir, batch_size=batch_size, start=0, training_type="base",
                                         start_time=start_time)
        dataset.save_config()
        end_train, end_validate = _split_bounds(len(dataset))

        train_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=0,
                                                                end=end_train)
        validate_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_train,
                                                                   end=end_validate)
        eval_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_validate)

        print("- - - - - Generated Datasets - - - - - -")
        print(f"Dataset: {len(dataset)}")
        print(f"Train: {len(train_dataset)}")
        print(f"Validate: {len(validate_dataset)}")
        print(f"Eval: {len(eval_dataset)}")

        data, target = dataset[args.data_idx]
        print(f"Dataset at pos {args.data_idx} has shape {data.shape}. Target shape: {target.shape}")
        data, target = train_dataset[args.data_idx]
        print(f"Train at pos {args.data_idx} has shape {data.shape}. Target shape: {target.shape}")
        data, target = validate_dataset[args.data_idx]
        print(f"Validate at pos {args.data_idx} has shape {data.shape}. Target shape: {target.shape}")
        data, target = eval_dataset[args.data_idx]
        print(f"Eval at pos {args.data_idx} has shape {data.shape}. Target shape: {target.shape}")
    elif args.command == "test":
        print(f"Testing Directory Dataset for port {args.port_name} at index {args.data_idx}")
        if args.port_name is None:
            raise ValueError("No port name found in 'args.port_name'. Specify a port name for testing.")
        pm = PortManager()
        pm.load()
        if len(pm.ports) == 0:
            raise LookupError("Unable to load ports! Make sure port manager is fit")
        port = pm.find_port(args.port_name)
        if port is None:
            raise ValueError(f"Unable to associate '{args.port_name}' with any port")
        dataset_dir = os.path.join(args.data_dir, "routes", port.name)
        dataset = RoutesDirectoryDataset.load_from_config(find_latest_dataset_config_path(dataset_dir,
                                                                                          training_type="base"))
        print(f"Loaded dataset config: {dataset.config_path}")
        print(f"Dataset length: {len(dataset)}")
        data, target = dataset[args.data_idx]
        print(f"Data at pos {args.data_idx} of shape {data.shape}:\n{data}")
        print(f"Target at pos {args.data_idx} of shape {target.shape}:\n{target}")
    elif args.command == "test_range":
        print("Testing Directory Dataset in directory 'routes'")
        pm = PortManager()
        pm.load()
        if len(pm.ports) == 0:
            raise LookupError("Unable to load ports! Make sure port manager is fit")
        port = pm.find_port(args.port_name)
        if port is None:
            raise ValueError(f"Unable to associate '{args.port_name}' with any port")

        dataset_dir = os.path.join(args.data_dir, "routes", port.name)
        dataset = RoutesDirectoryDataset.load_from_config(find_latest_dataset_config_path(dataset_dir,
                                                                                          training_type="base"))
        end_train, end_validate = _split_bounds(len(dataset))

        train_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=0, end=end_train)
        validate_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_train,
                                                                   end=end_validate)
        eval_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_validate)

        test_dataset(train_dataset)
        test_dataset(validate_dataset)
        test_dataset(eval_dataset)
    elif args.command == "change_data_dir":
        print("Changing Directory Dataset Config's data directory")
        pm = PortManager()
        pm.load()
        if len(pm.ports) == 0:
            raise LookupError("Unable to load ports! Make sure port manager is fit")
        port = pm.find_port(args.port_name)
        if port is None:
            raise ValueError(f"Unable to associate '{args.port_name}' with any port")
        routes_dir = os.path.join(args.data_dir, "routes", port.name)
        config_path = os.path.join(routes_dir, args.config_file_name)
        if not os.path.exists(config_path):
            raise ValueError(f"No config file found at '{config_path}'")
        dataset = RoutesDirectoryDataset.load_from_config(config_path, new_data_dir=routes_dir)
        dataset.save_config()
    else:
        raise ValueError(f"Unknown command: {args.command}")
コード例 #5
0
class TransferManager:
    def __init__(self, config_path: str, routes_dir: str, output_dir: str, transfers: Dict[str, List[str]] = None):
        self.path = os.path.join(script_dir, "TransferManager.tar")
        self.config_path = config_path
        self.routes_dir = routes_dir
        self.output_dir = output_dir
        self.pm = PortManager()
        self.pm.load()
        if len(self.pm.ports.keys()) < 1:
            raise ValueError("No port data available")
        self.transfer_defs = self._generate_transfers()
        self.transfers = {} if transfers is None else transfers

    def save(self) -> None:
        torch.save({
            "config_path": self.config_path,
            "routes_dir": self.routes_dir,
            "output_dir": self.output_dir,
            "transfers": self.transfers if self.transfers else None
        }, self.path)

    @staticmethod
    def load(path: str) -> 'TransferManager':
        if not os.path.exists(path):
            raise ValueError(f"No TransferManager.tar found at '{path}'")
        state_dict = torch.load(path)
        tm = TransferManager(
            config_path=state_dict["config_path"],
            routes_dir=state_dict["routes_dir"],
            output_dir=state_dict["output_dir"],
            transfers=state_dict["transfers"]
        )
        return tm

    def _is_transferred(self, base_port_name: str, target_port_name: str) -> bool:
        return base_port_name in self.transfers and target_port_name in self.transfers[base_port_name]

    def reset(self, base_port: Union[str, Port] = None, target_port: Union[str, Port] = None) -> None:
        if base_port is not None:
            if isinstance(base_port, str):
                orig_name = base_port
                base_port = self.pm.find_port(base_port)
                if base_port is None:
                    raise ValueError(f"Unable to associate port with port name '{orig_name}'")
            if target_port is not None:
                if isinstance(target_port, str):
                    orig_name = target_port
                    target_port = self.pm.find_port(target_port)
                    if target_port is None:
                        raise ValueError(f"Unable to associate port with port name '{orig_name}'")
                self.transfers[base_port.name].remove(target_port.name)
            else:
                del self.transfers[base_port.name]
        else:
            self.transfers = {}
        self.save()

    def transfer(self, source_port_name: str, skip_transferred: bool = True) -> None:
        source_port = self.pm.find_port(source_port_name)
        if source_port is None:
            print(f"No port found for port name '{source_port_name}'")
            return
        if source_port.name in self.transfer_defs:
            transfer_defs = self.transfer_defs[source_port.name]
        else:
            raise ValueError(f"No transfer definition found for port '{source_port.name}'. Make sure config contains "
                             f"transfer definition for '{source_port.name}' and has a base-training model")

        # transfer base model to each port specified in transfer definition
        for transfer_def in transfer_defs:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            training_type = "transfer"
            logger = Logger(training_type, transfer_def.target_log_dir)
            batch_size = 64
            window_width = 128
            # load start_time according to base model for association of models
            _, _, start_time, _, _ = decode_model_file(os.path.split(transfer_def.base_model_path)[1])
            port = self.pm.find_port(transfer_def.target_port_name)
            if port is None:
                raise ValueError(f"Unable to associate port with port name '{transfer_def.target_port_name}'")
            if not os.path.exists(transfer_def.target_routes_dir):
                print(f"Skipping transfer {transfer_def.base_port_name} -> {transfer_def.target_port_name}: No routes")
                continue
            if skip_transferred and self._is_transferred(transfer_def.base_port_name, transfer_def.target_port_name):
                print(f"Skipping transfer {transfer_def.base_port_name} -> {transfer_def.target_port_name}: "
                      f"Already transferred")
                continue

            dataset = RoutesDirectoryDataset(data_dir=transfer_def.target_routes_dir, start_time=start_time,
                                             training_type=training_type, batch_size=batch_size, start=0,
                                             window_width=window_width)
            dataset_file_name = encode_dataset_config_file(start_time, file_type="transfer")
            dataset_config_path = os.path.join(transfer_def.target_routes_dir, dataset_file_name)
            if not os.path.exists(dataset_config_path):
                dataset.save_config()
            else:
                if not os.path.exists(dataset_config_path):
                    raise FileNotFoundError(f"Unable to transfer: No dataset config found at {dataset_config_path}")
                dataset = RoutesDirectoryDataset.load_from_config(dataset_config_path)
            end_train = int(.8 * len(dataset))
            if not (len(dataset) - end_train) % 2 == 0 and end_train < len(dataset):
                end_train += 1
            end_validate = int(len(dataset) - ((len(dataset) - end_train) / 2))

            # use initialized dataset's config for consistent split
            train_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=0, end=end_train)
            validate_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_train,
                                                                       end=end_validate)

            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, drop_last=False, pin_memory=True,
                                                       num_workers=1)
            validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=None, drop_last=False,
                                                          pin_memory=True, num_workers=1)

            model = InceptionTimeModel.load(transfer_def.base_model_path, device=device)
            model.freeze_inception()
            # TODO: optimizer
            # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
            #                                                lr=transfer_def.learning_rate)
            optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                          lr=transfer_def.learning_rate)

            num_epochs = 10
            loss_history = ([], [])
            elapsed_time_history = []
            criterion: torch.nn.MSELoss = torch.nn.MSELoss()
            min_val_idx = 0

            print(f".:'`!`':. TRANSFERRING PORT {transfer_def.base_port_name} TO {transfer_def.target_port_name} .:'`!`"
                  f"':.")
            print(f"- - Epochs {num_epochs} </> Training examples {len(train_loader)} </> "
                  f"Learning rate {transfer_def.learning_rate} - -")
            print(f"- - Weight decay {0} Window width {window_width} </> Batch size {batch_size} - -")
            print(f"- - Number of model's parameters {num_total_trainable_parameters(model)} device {device} - -")
            logger.write(f"{port.name}-model\n"
                         f"Number of epochs: {num_epochs}\n"
                         f"Learning rate: {transfer_def.learning_rate}\n"
                         f"Total number of parameters: {num_total_parameters(model)}\n"
                         f"Total number of trainable parameters: {num_total_trainable_parameters(model)}")
            # transfer loop
            for epoch in range(num_epochs):
                # re-train model-parameters with requires_grad == True
                print(f"->->->->-> Epoch ({epoch + 1}/{num_epochs}) <-<-<-<-<-<-")
                avg_train_loss, elapsed_time = train_loop(criterion=criterion, model=model, device=device,
                                                          optimizer=optimizer, loader=train_loader)
                loss_history[0].append(avg_train_loss)
                elapsed_time_history.append(elapsed_time)

                # validate model
                avg_validation_loss = validate_loop(criterion=criterion, device=device, model=model,
                                                    optimizer=optimizer, loader=validate_loader)
                loss_history[1].append(avg_validation_loss)

                # check if current model has lowest validation loss (= is current optimal model)
                if loss_history[1][epoch] < loss_history[1][min_val_idx]:
                    min_val_idx = epoch

                logger.write(f"Epoch {epoch + 1}/{num_epochs}:\n"
                             f"\tAvg train loss {avg_train_loss}\n"
                             f"\tAvg val   loss {avg_validation_loss}")

                make_training_checkpoint(model=model, model_dir=transfer_def.target_model_dir, port=port,
                                         start_time=start_time, num_epochs=num_epochs,
                                         learning_rate=transfer_def.learning_rate, weight_decay=.0,
                                         num_train_examples=len(train_loader), loss_history=loss_history,
                                         elapsed_time_history=elapsed_time_history, optimizer=optimizer,
                                         is_optimum=min_val_idx == epoch, base_port_name=transfer_def.base_port_name)
                print(f">>>> Avg losses (MSE) - Train: {avg_train_loss} Validation: {avg_validation_loss} <<<<\n")

            # conclude transfer
            conclude_training(loss_history=loss_history, data_dir=transfer_def.target_output_data_dir,
                              plot_dir=transfer_def.target_plot_dir, port=port, start_time=start_time,
                              elapsed_time_history=elapsed_time_history, plot_title="Transfer loss",
                              training_type=training_type)

            if transfer_def.base_port_name in self.transfers:
                self.transfers[transfer_def.base_port_name].append(transfer_def.target_port_name)
            else:
                self.transfers[transfer_def.base_port_name] = [transfer_def.target_port_name]
            self.save()

    def _generate_transfers(self) -> Dict[str, List[TransferDefinition]]:
        """
        Generate TransferDefinitions based on config.json, containing those ports that have a base training for
        transferring to another port
        :return: Dict of key = port_name, val = List of TransferDefinition
        """
        config = read_json(self.config_path)
        transfer_defs = {}

        for transfer_def in config:
            base_port = self.pm.find_port(transfer_def["base_port"])
            base_port_trainings = self.pm.load_trainings(base_port, self.output_dir, self.routes_dir,
                                                         training_type="base")

            if len(base_port_trainings) == 0:
                print(f"No base-training found for port '{base_port.name}'")
                continue
            print(f"Port {base_port.name} has {len(base_port_trainings)} base-trainings. Using latest")
            base_train = base_port_trainings[-1]
            for target_port_name in transfer_def["target_ports"]:
                target_port = self.pm.find_port(target_port_name)
                if target_port is None:
                    raise ValueError(f"Unable to transfer from port '{base_port.name}'. "
                                     f"No port for '{target_port_name}' found")
                verify_output_dir(self.output_dir, target_port.name)
                td = TransferDefinition(base_port_name=base_port.name,
                                        base_model_path=base_train.model_path,
                                        target_port_name=target_port.name,
                                        target_routes_dir=os.path.join(self.routes_dir, target_port.name),
                                        target_model_dir=os.path.join(self.output_dir, "model", target_port.name),
                                        target_output_data_dir=os.path.join(self.output_dir, "data", target_port.name),
                                        target_plot_dir=os.path.join(self.output_dir, "plot", target_port.name),
                                        target_log_dir=os.path.join(self.output_dir, "log", target_port.name))
                if base_port.name in transfer_defs:
                    transfer_defs[base_port.name].append(td)
                else:
                    transfer_defs[base_port.name] = [td]
        return transfer_defs
コード例 #6
0
class Evaluator:
    def __init__(self, output_dir: str, routes_dir: str, mae_base: Dict[str, float] = None,
                 mae_transfer: Dict[str, float] = None,
                 mae_base_groups: Dict[str, List[Tuple[int, int, int, float, str]]] = None,
                 mae_transfer_groups: Dict[str, List[Tuple[int, int, int, float, str]]] = None) -> None:
        """
        Create an Evaluator rooted at ``output_dir`` and load the port data.

        ``<output_dir>/eval`` is created if it does not exist; omitted result dicts start out empty.
        :param output_dir: Root output directory (data, eval and model paths are derived from it)
        :param routes_dir: Directory with the per-port routes
        :param mae_base: Overall MAE per base key (``<port>_<start_time>``)
        :param mae_transfer: Overall MAE per transfer key (``<source>_<target>_<start_time>``)
        :param mae_base_groups: Grouped MAE results per base key
        :param mae_transfer_groups: Grouped MAE results per transfer key
        :raises ValueError: If the PortManager holds no ports
        """
        self.output_dir, self.routes_dir = output_dir, routes_dir
        self.data_dir = os.path.join(output_dir, "data")
        self.eval_dir = os.path.join(output_dir, "eval")
        self.model_dir = os.path.join(output_dir, "model")
        self.path = os.path.join(self.eval_dir, "evaluator.tar")
        if not os.path.exists(self.eval_dir):
            os.makedirs(self.eval_dir)

        self.pm = PortManager()
        self.pm.load()
        if not self.pm.ports:
            raise ValueError("No port data available")

        self.mae_base = {} if mae_base is None else mae_base
        self.mae_transfer = {} if mae_transfer is None else mae_transfer
        self.mae_base_groups = {} if mae_base_groups is None else mae_base_groups
        self.mae_transfer_groups = {} if mae_transfer_groups is None else mae_transfer_groups

    def save(self):
        """
        Persist the evaluator state to ``self.path`` with ``torch.save``.

        Empty result dicts are stored as None; constructing an Evaluator from them turns None back into {}.
        """
        state = {
            "path": self.path,
            "output_dir": self.output_dir,
            "routes_dir": self.routes_dir,
        }
        for result_name in ("mae_base", "mae_transfer", "mae_base_groups", "mae_transfer_groups"):
            results = getattr(self, result_name)
            state[result_name] = results if results else None
        torch.save(state, self.path)

    def reset(self):
        """Clear all stored MAE results in memory; the on-disk state is untouched until ``save`` is called."""
        for result_name in ("mae_base", "mae_base_groups", "mae_transfer", "mae_transfer_groups"):
            setattr(self, result_name, {})

    @staticmethod
    def load(eval_dir_or_path: str, output_dir: str = None, routes_dir: str = None) -> 'Evaluator':
        """
        Restore an Evaluator from a file written by ``save``.

        :param eval_dir_or_path: Path of a ``.tar`` file, or a directory containing ``evaluator.tar``
        :param output_dir: Overrides the stored output directory if given
        :param routes_dir: Overrides the stored routes directory if given
        :return: Evaluator initialized with the stored results
        """
        path = eval_dir_or_path
        parent_dir, file = os.path.split(eval_dir_or_path)
        # a bare file name has no parent; os.makedirs("") would raise FileNotFoundError
        if parent_dir and not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
        if not file.endswith(".tar"):
            path = os.path.join(path, "evaluator.tar")
        state_dict = torch.load(path)
        evaluator = Evaluator(
            output_dir=state_dict["output_dir"] if output_dir is None else output_dir,
            routes_dir=state_dict["routes_dir"] if routes_dir is None else routes_dir,
            mae_base=state_dict["mae_base"],
            mae_transfer=state_dict["mae_transfer"],
            mae_base_groups=state_dict["mae_base_groups"],
            mae_transfer_groups=state_dict["mae_transfer_groups"]
        )
        return evaluator

    @staticmethod
    def _encode_base_key(port_name: str, start_time: str) -> str:
        """Build the result key of a base training: ``<port_name>_<start_time>``."""
        return "{}_{}".format(port_name, start_time)

    @staticmethod
    def _decode_base_key(key: str) -> Tuple[str, str]:
        """
        Split a base key into (port_name, start_time).

        Only the first two '_'-separated fields are returned, so neither part may contain '_' itself.
        """
        fields = key.split("_")
        port_name = fields[0]
        start_time = fields[1]
        return port_name, start_time

    @staticmethod
    def _encode_transfer_key(source_port: str, target_port: str, start_time: str) -> str:
        """Build the result key of a transfer: ``<source_port>_<target_port>_<start_time>``."""
        return "{}_{}_{}".format(source_port, target_port, start_time)

    @staticmethod
    def _decode_transfer_key(key: str) -> Tuple[str, str, str]:
        """
        Split a transfer key into (source_port, target_port, start_time).

        Only the first three '_'-separated fields are returned, so no part may contain '_' itself.
        """
        fields = key.split("_")
        source_port = fields[0]
        target_port = fields[1]
        start_time = fields[2]
        return source_port, target_port, start_time

    def _get_mae_base(self, transfer_key: str, group: bool) -> float:
        """
        Look up the base result that belongs to a transfer result.

        :param transfer_key: Transfer key (source port, target port, start time)
        :param group: Return the grouped base result instead of the overall MAE
        :return: Base result of the transfer's source port for the same start time
        :raises KeyError: If no such base result is stored
        """
        source_port, _target_port, start_time = self._decode_transfer_key(transfer_key)
        results = self.mae_base_groups if group else self.mae_base
        return results[self._encode_base_key(source_port, start_time)]

    def export(self) -> None:
        """
        Write base and transfer MAE side by side to ``evaluation_results.csv`` in the current working directory.

        One row is written per transfer result whose source port and start time match a stored base result.
        :raises ValueError: If a matching transfer result cannot be retrieved
        """
        decoded_transfer_keys = [self._decode_transfer_key(k) for k in sorted(self.mae_transfer.keys())]
        with open("evaluation_results.csv", "w", newline="") as file:
            writer = csv.writer(file, delimiter=",")
            writer.writerow(["Base Port", "Target Port", "Start Time", "Base Port MAE", "Transfer Port MAE"])
            for base_key in sorted(self.mae_base.keys()):
                base_port, start_time = self._decode_base_key(base_key)
                # match on source port AND start time: matching on the port alone mixes in transfers of other
                # base trainings, which previously raised or produced duplicate rows
                for source_port, target_port, transfer_start_time in decoded_transfer_keys:
                    if source_port != base_port or transfer_start_time != start_time:
                        continue
                    transfer_key = self._encode_transfer_key(base_port, target_port, start_time)
                    if transfer_key not in self.mae_transfer:
                        raise ValueError(f"Unable to retrieve transfer result base port '{base_port}' to "
                                         f"'{target_port}'. No such transfer key '{transfer_key}' "
                                         f"(base key: '{base_key}')")
                    writer.writerow([base_port, target_port, start_time, self.mae_base[base_key],
                                     self.mae_transfer[transfer_key]])

    def set_mae(self, port: Port, start_time: str, mae: Union[float, List[Tuple[int, int, int, float, str]]],
                source_port: Port = None, grouped: bool = False) -> None:
        """
        Store an evaluation result in memory (call ``save`` to persist it).

        :param port: Evaluated port (the target port for transfer results)
        :param start_time: Start time of the evaluated training
        :param mae: Overall MAE, or the grouped result when ``grouped`` is True
        :param source_port: Source port of a transfer result; None stores a base result
        :param grouped: Store into the grouped result dict instead of the overall MAE dict
        """
        if source_port is None:
            key = self._encode_base_key(port.name, start_time)
            results = self.mae_base_groups if grouped else self.mae_base
        else:
            key = self._encode_transfer_key(source_port.name, port.name, start_time)
            results = self.mae_transfer_groups if grouped else self.mae_transfer
        results[key] = mae

    def remove_mae(self, port: Port, start_time: str, source_port: Port = None, grouped: bool = False) -> None:
        """
        Delete a stored evaluation result from memory; a missing result is reported on stdout instead of raising.

        :param port: Port the result belongs to (the target port for transfer results)
        :param start_time: Start time of the training
        :param source_port: Source port of a transfer result; None selects a base result
        :param grouped: Delete the grouped result instead of the overall MAE
        """
        if source_port is None:
            key = self._encode_base_key(port.name, start_time)
            if grouped:
                results = self.mae_base_groups
                not_found = f"No grouped base result found for port '{port.name}' and start time '{start_time}'"
            else:
                results = self.mae_base
                not_found = f"No base result found for port '{port.name}' and start time '{start_time}'"
        else:
            key = self._encode_transfer_key(source_port.name, port.name, start_time)
            if grouped:
                results = self.mae_transfer_groups
                not_found = (f"No grouped transfer result found for port '{port.name}', "
                             f"source_port '{source_port.name}' and start time '{start_time}'")
            else:
                results = self.mae_transfer
                not_found = (f"No transfer result found for port '{port.name}', "
                             f"source_port '{source_port.name}' and start time '{start_time}'")

        if key not in results:
            print(not_found)
            return
        del results[key]

    def eval_port(self, port: Union[str, Port], training_type: str, plot: bool = True) -> None:
        """
        Evaluate the latest ``training_type`` training of ``port`` on its evaluation split and store the results.

        The evaluation split is the last part of the dataset described by the training's dataset config. The overall
        MAE (``nn.L1Loss``) and the result of ``group_mae`` are stored via ``set_mae`` (keyed by source port for
        transfer trainings) and persisted with ``save``.
        :param port: Port or port name
        :param training_type: Training type to evaluate, e.g. "base" or "transfer"
        :param plot: Plot the grouped MAE of the evaluated training afterwards
        :raises ValueError: If the port or, for transfer trainings, the source port cannot be resolved
        """
        if isinstance(port, str):
            orig_port = port
            port = self.pm.find_port(port)
            if port is None:
                raise ValueError(f"Unable to associate port with port name '{orig_port}'")

        trainings = self.pm.load_trainings(port, self.output_dir, self.routes_dir, training_type=training_type)
        if len(trainings) < 1:
            print(f"Skipping evaluation for port '{port.name}': No {training_type}-training found")
            return

        training = trainings[-1]
        dataset = RoutesDirectoryDataset.load_from_config(training.dataset_config_path)
        end_train = int(.8 * len(dataset))
        if not (len(dataset) - end_train) % 2 == 0 and end_train < len(dataset):
            end_train += 1
        end_validate = int(len(dataset) - ((len(dataset) - end_train) / 2))

        # use initialized dataset's config for consistent split
        eval_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_validate)

        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=None, drop_last=False, pin_memory=True,
                                                  num_workers=1)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = InceptionTimeModel.load(training.model_path, device).to(device)
        model.eval()

        print(f">->->->->->-> STARTED EVALUATION OF PORT {port.name} | TRAINING TYPE {training_type} <-<-<-<-<-<-<-<-<")
        criterion = nn.L1Loss(reduction="mean")
        batch_outputs = []
        batch_targets = []
        with torch.no_grad():
            for data, target in tqdm(eval_loader, desc="Evaluation progress"):
                batch_outputs.append(model(data.to(device)))
                batch_targets.append(target.to(device))

        if not batch_outputs:
            # torch.cat([]) raises; an empty evaluation split yields no result to store
            print(f"Skipping evaluation for port '{port.name}': No evaluation data")
            return
        outputs = torch.cat(batch_outputs, dim=0)
        targets = torch.cat(batch_targets, dim=0)
        mae = criterion(outputs, targets).item()
        print(f"Mae loss: {mae} || {as_duration(mae)}")

        mae_groups = self.group_mae(outputs, targets)
        print(f"Mae by duration:\n{mae_groups}")

        if training_type == "transfer":
            model_file = os.path.split(training.model_path)[1]
            _, _, _, _, source_port_name = decode_model_file(model_file)
            source_port = self.pm.find_port(source_port_name)
            if source_port is None:
                raise ValueError(f"Unable to associate port with port name '{source_port_name}'")
            self.set_mae(port, training.start_time, mae, source_port=source_port, grouped=False)
            self.set_mae(port, training.start_time, mae_groups, source_port=source_port, grouped=True)
        else:
            self.set_mae(port, training.start_time, mae, grouped=False)
            self.set_mae(port, training.start_time, mae_groups, grouped=True)
        self.save()

        # NOTE: IntegratedGradients attributions over the whole evaluation input used to be computed here, but the
        # result was never used; they were dropped to avoid the extra passes and the memory of holding all inputs.
        if plot:
            self.plot_grouped_mae(port, training_type=training_type, training=training)

    def eval_all(self, plot: bool = True) -> None:
        """
        Evaluate every port known to the port manager, once per training type ("base", then "transfer").
        :param plot: Forwarded to eval_port; if True, the cross-port MAE overview plots for both training types
            are generated after all evaluations
        :return: None
        """
        training_types = ("base", "transfer")
        for current_port in self.pm.ports.values():
            for training_type in training_types:
                self.eval_port(current_port, training_type=training_type, plot=plot)

        if not plot:
            return
        for training_type in training_types:
            self.plot_ports_by_mae(training_type=training_type)

    def plot(self, port_name: str = None) -> None:
        """
        Produce the per-port plots for one port (if `port_name` is given) or for every known port, followed by
        the cross-port transfer-effects plot.
        :param port_name: Name of the port to plot. If None, every port of the port manager is plotted
        :return: None
        """
        ports_to_plot = self.pm.ports.values() if port_name is None else (port_name,)
        for entry in ports_to_plot:
            self.plot_port(entry)
        # the transfer-effects overview covers all ports, independent of `port_name`
        self.plot_transfer_effects()

    def plot_port(self, port: Union[str, Port]):
        """
        Generate all plots related to a single port: the grouped MAE per training type, the cross-port MAE
        ranking per training type, and the transfer effect on this port.
        :param port: Port instance or port name
        """
        for kind in ("base", "transfer"):
            self.plot_grouped_mae(port, training_type=kind)
            # NOTE: plot_ports_by_mae only takes the training type, so this overview is not specific to `port`
            # and is regenerated on every call
            self.plot_ports_by_mae(training_type=kind)
        self.plot_transfer_effect(port)

    def plot_grouped_mae(self, port: Union[str, Port], training_type: str, training: TrainingIteration = None) -> None:
        """
        Plot the grouped MAE (one MAE per target-duration group) of one training of a port.
        :param port: Port instance or port name
        :param training_type: "base" or "transfer"; any value other than "base" is handled like "transfer"
        :param training: Training to plot. If None, the last training returned by the port manager for
            `training_type` is used
        :return: None. Prints a message and skips the plot if no training or no stored grouped MAE is found
        :raises ValueError: If `port` is a name that cannot be associated with a known port
        """
        if isinstance(port, str):
            orig_port = port
            port = self.pm.find_port(port)
            if port is None:
                raise ValueError(f"Unable to associate port with port name '{orig_port}'")

        if training is None:
            trainings = self.pm.load_trainings(port, output_dir=self.output_dir, routes_dir=self.routes_dir,
                                               training_type=training_type)
            if len(trainings) > 0:
                training = trainings[-1]
            else:
                print(f"No training of type '{training_type}' found for port '{port.name}'. Skipping plot_grouped_mae")
                return

        source_port_name = None
        if training_type == "base":
            key = self._encode_base_key(port.name, training.start_time)
            stored_groups = self.mae_base_groups
        else:
            # the source port of a transferred model is encoded in its model file name
            model_file_name = os.path.split(training.model_path)[1]
            _, _, _, _, source_port_name = decode_model_file(model_file_name)
            key = self._encode_transfer_key(source_port_name, port.name, training.start_time)
            stored_groups = self.mae_transfer_groups

        # A training can exist without having been evaluated yet (e.g. when plotting all ports). Skip it like the
        # "no training" case above instead of aborting the whole plotting run with a KeyError.
        if key not in stored_groups:
            print(f"No grouped MAE of type '{training_type}' found for port '{port.name}' "
                  f"(training start time {training.start_time}). Skipping plot_grouped_mae")
            return
        mae_groups = stored_groups[key]

        plot_path = os.path.join(self.eval_dir, port.name, encode_grouped_mae_plot(training.start_time,
                                                                                   file_type=training_type))
        title = f"Grouped MAE {training_type}-training: Port {port.name}"
        if training_type == "transfer":
            title = f"{title} (Source port: {source_port_name})"
        plot_grouped_maes(mae_groups, title=title, path=plot_path)

    def plot_ports_by_mae(self, training_type: str) -> None:
        """
        Plot all ports ordered by ascending MAE for one training type.
        For "base", every stored base MAE becomes one entry (keys encode port name and training start time, so a
        port can appear once per training). For "transfer", the MAEs of all transfers into the same target port
        are averaged into one entry.
        :param training_type: "base" or "transfer"
        :return: None. Prints a message and skips the plot if no MAE is stored for the training type
        :raises ValueError: If training_type is unknown
        """
        if training_type == "base":
            result = [(mae, self._decode_base_key(key)[0]) for key, mae in self.mae_base.items()]
        elif training_type == "transfer":
            maes_by_target = {}
            for key, mae in self.mae_transfer.items():
                _, target_port_name, _ = self._decode_transfer_key(key)
                maes_by_target.setdefault(target_port_name, []).append(mae)
            result = [(sum(maes) / len(maes), name) for name, maes in maes_by_target.items()]
        else:
            raise ValueError(f"Unknown training-type '{training_type}'")

        # without data, transposing the result below would yield no columns and indexing them would fail
        if not result:
            print(f"No MAE of type '{training_type}' available. Skipping plot_ports_by_mae")
            return

        result.sort(key=lambda r: r[0])  # sort by mae
        maes = [r[0] for r in result]
        port_names = [r[1] for r in result]

        title = f"MAE from {training_type}-training by port"
        if training_type == "transfer":
            title = f"Average {title}"
        plot_ports_by_mae(maes, port_names, title=title,
                          path=os.path.join(self.output_dir, "eval", f"ports-mae_{training_type}-training.png"))

    def plot_transfer_effect(self, port: Union[str, Port]) -> None:
        """
        What does it cost to transfer another port's model to this port?
        Compares the grouped MAE of the latest transferred model of `port` against the grouped MAE of the base
        model of its source port. The source port and start time are decoded from the transferred model's file
        name; the source port's base training with that start time is the reference.
        :param port: Port (instance or name) that has a model transferred from another port
        :return: None. Prints a message and skips the plot if the port has no transfer training
        :raises ValueError: If the port name is unknown, or if not exactly one base training matches the start time
        """
        if isinstance(port, str):
            requested_name = port
            port = self.pm.find_port(requested_name)
            if port is None:
                raise ValueError(f"Unable to associate port with port name '{requested_name}'")

        transfer_trainings = self.pm.load_trainings(port, output_dir=self.output_dir, routes_dir=self.routes_dir,
                                                    training_type="transfer")
        if len(transfer_trainings) < 1:
            print(f"No training of type 'transfer' found for port {port.name}. Skipping plot_transfer_effect")
            return

        latest_model_file = os.path.split(transfer_trainings[-1].model_path)[1]
        _, _, start_time, _, source_port_name = decode_model_file(latest_model_file)

        base_trainings = [
            candidate
            for candidate in self.pm.load_trainings(source_port_name, output_dir=self.output_dir,
                                                    routes_dir=self.routes_dir, training_type="base")
            if candidate.start_time == start_time
        ]
        if len(base_trainings) != 1:
            raise ValueError(f"Unable to identify base-training for start_time '{start_time}': "
                             f"Got {len(base_trainings)}, expected exactly 1")

        base_key = self._encode_base_key(source_port_name, base_trainings[0].start_time)
        transfer_key = self._encode_transfer_key(source_port_name, port.name, start_time)
        base_data = self.mae_base_groups[base_key]
        transfer_data = self.mae_transfer_groups[transfer_key]
        target_file = f"transfer-effect_{source_port_name}-{port.name}.png"
        plot_transfer_effect(base_data, transfer_data, source_port_name, port.name,
                             os.path.join(self.output_dir, "eval", target_file))

    def plot_transfer_effects(self, sort: str = "mae_base") -> None:
        """
        MAE of transferred- vs base-model for all target ports with stored transfer MAEs.
        All transfers into the same target port are aggregated into one row: the source ports with the highest and
        lowest base MAE, the source ports whose transferred model reached the highest and lowest MAE, and the
        average base and transfer MAEs.
        :param sort: How to sort the target ports. "mae_base" sorts by ascending average base MAE of the source
            ports. Any other value (e.g. "num_data", which is not implemented) keeps the insertion order
        :return: None. Prints a message and skips the plot if no transfer MAE is stored
        """
        # target_port_name -> ([source_port_name, ...], [source base MAE, ...], [transfer MAE, ...])
        per_target = {}
        for transfer_key, mae_transfer in self.mae_transfer.items():
            source_port_name, target_port_name, _ = self._decode_transfer_key(transfer_key)
            mae_source_base = self._get_mae_base(transfer_key, group=False)
            source_names, base_maes, transfer_maes = per_target.setdefault(target_port_name, ([], [], []))
            source_names.append(source_port_name)
            base_maes.append(mae_source_base)
            transfer_maes.append(mae_transfer)

        # without data, transposing the rows below would yield no columns and indexing them would fail
        if not per_target:
            print("No transfer MAE available. Skipping plot_transfer_effects")
            return

        def compute_metrics(key, val: Tuple[List[str], List[float], List[float]]) -> Tuple[str, str, float, str, float,
                                                                                           str, float, str, float,
                                                                                           float, float]:
            """
            :return: Tuple in form of
                transfer_port_name,
                max_mae_source_port_name, max_mae_source_base,
                min_mae_source_port_name, min_mae_source_base,
                max_mae_transfer_port_name, max_mae_transfer,
                min_mae_transfer_port_name, min_mae_transfer,
                avg_mae_base,
                avg_mae_transfer
            Port names next to a MAE are the source port that produced that MAE.
            """
            source_names, base_maes, transfer_maes = val
            max_mae_base = max(base_maes)
            min_mae_base = min(base_maes)
            max_mae_transfer = max(transfer_maes)
            min_mae_transfer = min(transfer_maes)
            return (key,
                    source_names[base_maes.index(max_mae_base)], max_mae_base,
                    source_names[base_maes.index(min_mae_base)], min_mae_base,
                    source_names[transfer_maes.index(max_mae_transfer)], max_mae_transfer,
                    source_names[transfer_maes.index(min_mae_transfer)], min_mae_transfer,
                    sum(base_maes) / len(base_maes), sum(transfer_maes) / len(transfer_maes))

        result = [compute_metrics(key, val) for key, val in per_target.items()]

        if sort == "mae_base":
            # index 9 is the average base MAE; index 0 would be the target port name
            result.sort(key=lambda r: r[9])
        columns = [list(column) for column in zip(*result)]

        path = os.path.join(self.output_dir, "eval", f"transfer-effects_{sort}.png")
        plot_transfer_effects(columns[0], columns[1], columns[2], columns[3], columns[4], columns[5], columns[6],
                              columns[7], columns[8], columns[9], columns[10], path)

    def plot_ig_attr_test(self, result: List[float]) -> None:
        """
        Placeholder for plotting Integrated Gradients feature attributions. Not implemented yet: the attributions
        are accepted and ignored.
        :param result: Attribution values (currently unused)
        :return: None
        """
        return None

    @staticmethod
    def group_mae(outputs: torch.Tensor, targets: torch.Tensor) -> List[Tuple[int, int, int, float, str]]:
        """
        Compute multiple maes for each target duration group
        :param outputs: Predicted values
        :param targets: Target values
        :return: List of tuples. Each tuple represents one group
            [(group_start, group_end, num_data, scaled_mae, group_description), ...]
        """
        # groups = [
        #     (-1, 1800, "0-0.5h"),
        #     (1800, 3600, "0.5-1h"),
        #     (3600, 7200, "1-2h"),
        #     (7200, 10800, "2-3h"),
        #     (10800, 14400, "3-4h"),
        #     (14400, 18000, "4-5h"),
        #     (18000, 21600, "5-6h"),
        #     (21600, 25200, "6-7h"),
        #     (25200, 28800, "7-8h"),
        #     (28800, 32400, "8-9h"),
        #     (32400, 36000, "9-10h"),
        #     (36000, 39600, "10-11h"),
        #     (39600, 43200, "11-12"),
        #     (43200, 86400, "12h - 1 day"),
        #     (86400, 172800, "1 day - 2 days"),
        #     (172800, 259200, "2 days - 3 days"),
        #     (259200, 345600, "3 days - 4 days"),
        #     (345600, 432000, "4 days - 5 days"),
        #     (432000, 518400, "5 days - 6 days"),
        #     (518400, 604800, "6 days - 1 week"),
        #     (604800, 155520000, "1 week - 1 month"),
        #     (155520000, int(data_ranges["label"]["max"]), "> 1 month")
        # ]
        groups = [
            (-1, 1800, "0-0.5h"),
            (1800, 3600, "0.5-1h"),
            (3600, 7200, "1-2h"),
            (7200, 10800, "2-3h"),
            (10800, 14400, "3-4h"),
            (14400, 21600, "4-6h"),
            (21600, 28800, "6-8h"),
            (28800, 36000, "8-10h"),
            (36000, 43200, "10-12h"),
            (43200, 50400, "12-16h"),
            (50400, 64800, "16-20h"),
            (64800, 86400, "20-24h"),
            (86400, 172800, "1-2d"),
            (172800, 259200, "2-3d"),
            (259200, 345600, "3-4d"),
            (345600, 432000, "4-5d"),
            (432000, 518400, "5-6d"),
            (518400, 604800, "6-7d"),
            (604800, 1209600, "1-2w"),
            (1209600, 2419200, "2-4w"),
            (2419200, int(data_ranges["label"]["max"]), "> 4w")
        ]

        def scale(seconds: int) -> float:
            # half_range = (data_ranges["label"]["max"] - data_ranges["label"]["min"]) / 2
            # result = seconds / half_range
            # return -1 + result if seconds < half_range else result
            label_range = data_ranges["label"]["max"]
            return seconds / label_range

        def process_group(x: torch.Tensor, y: torch.Tensor, group: Tuple[int, int, str]) -> Tuple[int, int, int, float,
                                                                                                  str]:
            criterion = nn.L1Loss(reduction="mean")
            mask = (y > scale(group[0])) & (y <= scale(group[1]))
            # mask = (y > group[0]) & (y <= group[1])
            x = x[mask]
            y = y[mask]
            mae = 0.
            num_data = x.shape[0]
            if num_data > 0:
                loss = criterion(x, y)
                mae = loss.item()
            return group[0], group[1], num_data, mae, group[2]

        mae_groups = [process_group(outputs, targets, group) for group in groups]
        return mae_groups
コード例 #7
0
ファイル: transfer.py プロジェクト: Trousersfield/ma
class TransferManager:
    """
    Plans and runs model transfers between ports.
    Transfer definitions (which base model is transferred to which target port) are built from all ordered pairs of
    the ports listed in the transfer config; transfer configurations (data subset, trainable layers, fine-tuning)
    are read from the same file. Completed transfers are recorded per target port as
    (source_port_name, config_uid) and persisted with torch.save.
    """

    def __init__(self,
                 config_path: str,
                 routes_dir: str,
                 output_dir: str,
                 transfers: Dict[str, List[Tuple[str, int]]] = None):
        """
        :param config_path: Path to the transfer config JSON (its "ports" and "configs" entries are read)
        :param routes_dir: Root directory of route data; joined with a port name to locate a port's routes
        :param output_dir: Output root used to look up base trainings and to derive target directories
        :param transfers: Completed transfers: target_port_name -> [(source_port_name, config_uid), ...]
        :raises ValueError: If the port manager has no port data
        """
        self.path = os.path.join(script_dir, "TransferManager.tar")
        self.config_path = config_path
        self.routes_dir = routes_dir
        self.output_dir = output_dir
        self.pm = PortManager()
        self.pm.load()
        if len(self.pm.ports.keys()) < 1:
            raise ValueError("No port data available")
        self.transfer_defs = self._generate_transfers()
        self.transfer_configs = self._generate_configs()
        self.transfers = {} if transfers is None else transfers

    def save(self) -> None:
        """Persist the constructor arguments and the recorded transfers to self.path."""
        torch.save(
            {
                "config_path": self.config_path,
                "routes_dir": self.routes_dir,
                "output_dir": self.output_dir,
                "transfers": self.transfers if self.transfers else None
            }, self.path)

    @staticmethod
    def load(path: str) -> 'TransferManager':
        """
        Re-create a TransferManager from a file written by save().
        :param path: Path to the saved state
        :return: New TransferManager (transfer definitions and configs are regenerated from the config file)
        :raises ValueError: If no file exists at `path`
        """
        if not os.path.exists(path):
            raise ValueError(f"No TransferManager.tar found at '{path}'")
        # NOTE: torch.load unpickles the file; only load files this tool wrote itself
        state_dict = torch.load(path)
        tm = TransferManager(config_path=state_dict["config_path"],
                             routes_dir=state_dict["routes_dir"],
                             output_dir=state_dict["output_dir"],
                             transfers=state_dict["transfers"])
        return tm

    def _is_transferred(self, target_port: str, source_port: str,
                        config_uid: int) -> bool:
        """Return True if a transfer source_port -> target_port with config_uid has been recorded."""
        # any() instead of counting exactly one match: a duplicated record still means "transferred"
        return any(t[0] == source_port and t[1] == config_uid
                   for t in self.transfers.get(target_port, []))

    def set_transfer(self, target_port: str, source_port: str,
                     config_uid: int) -> None:
        """Record a completed transfer and persist the state."""
        if target_port in self.transfers:
            self.transfers[target_port].append((source_port, config_uid))
        else:
            self.transfers[target_port] = [(source_port, config_uid)]
        self.save()
        print(f"transfers:\n{self.transfers}")

    def reset_transfer(self,
                       target_port: str = None,
                       source_port: str = None,
                       config_uid: int = None) -> None:
        """
        Remove recorded transfers and persist the state.
        :param target_port: Target port whose records to remove. If None, all records of all ports are removed
        :param source_port: If given (with target_port), only remove records from this source port
        :param config_uid: If given (with source_port), only remove records with this config uid
        :return: None
        """
        if target_port is None:
            self.transfers = {}
        elif source_port is None:
            self.transfers[target_port] = []
        elif target_port in self.transfers:
            # Filter the target port's records instead of popping by index: popping shifts the remaining
            # indices, so later pops would remove the wrong records.
            self.transfers[target_port] = [
                t for t in self.transfers[target_port]
                if not (t[0] == source_port and (config_uid is None or t[1] == config_uid))
            ]
        self.save()

    def transfer(self,
                 target_port: Port,
                 evaluator: Evaluator,
                 config_uids: List[int] = None) -> None:
        """
        Transfer models to target port
        For each transfer definition of the target port and each transfer config that has not been recorded yet,
        the base model is loaded, layers not listed in the config's train_layers are frozen, the model is trained
        on the target port's data (optionally fine-tuned afterwards), evaluated, and its history and plots are
        saved.
        :param target_port: port for which to train transfer-model
        :param evaluator: evaluator instance to store results
        :param config_uids: specify config_uids to transfer. If none, transfer all
        :return: None
        """
        if target_port.name not in self.transfer_defs:
            print(
                f"No transfer definition found for target port '{target_port.name}'"
            )
            return
        # transfer definitions for specified target port
        tds = self.transfer_defs[target_port.name]
        output_dir = os.path.join(script_dir, os.pardir, "output")
        training_type = "transfer"
        print(f"TRANSFERRING MODELS TO TARGET PORT '{target_port.name}'")
        if config_uids is not None:
            print(f"Transferring configs -> {config_uids} <-")
        window_width = 50
        num_epochs = 25
        train_lr = 0.01
        fine_num_epochs = 20
        fine_tune_lr = 1e-5
        batch_size = 1024

        # skip port if fully transferred
        num_not_transferred = 0
        for td in tds:
            for config in self.transfer_configs:
                if not self._is_transferred(target_port.name,
                                            td.base_port_name, config.uid):
                    num_not_transferred += 1
        num_transfers = len(tds) * len(self.transfer_configs)
        print(
            f"Transferred count {num_transfers - num_not_transferred}/{num_transfers}"
        )
        if num_not_transferred == 0:
            print(
                f"All transfers done for target port '{target_port.name}': Skipping"
            )
            return
        X_ts, y_ts = load_data(target_port, window_width)

        baseline = mean_absolute_error(y_ts, np.full_like(y_ts, np.mean(y_ts)))
        evaluator.set_naive_baseline(target_port, baseline)
        print(f"Naive baseline: {baseline}")

        def _compute_mae(_val_mae: List[float],
                         _tune_val_mae: List[float]) -> float:
            """
            Lowest validation MAE over the transfer run and, if present, the fine-tuning run. Both fits use the
            same val_mae checkpoint callback, so the reported MAE must consider both runs.
            """
            if _tune_val_mae is not None:
                _val_mae = _val_mae + _tune_val_mae
            return min(_val_mae)

        for td in tds:
            print(
                f".:'`!`':. TRANSFERRING PORT {td.base_port_name} TO {td.target_port_name} .:'`!`':."
            )
            print(
                f"- - Epochs {num_epochs} </>  </> Learning rate {train_lr} - -"
            )
            print(
                f"- - Window width {window_width} </> Batch size {batch_size} - -"
            )
            base_port = self.pm.find_port(td.base_port_name)
            if base_port is None:
                raise ValueError(
                    f"Unable to associate port with port name '{td.base_port_name}'"
                )

            # apply transfer config
            for config in self.transfer_configs:
                if config_uids is not None and config.uid not in config_uids:
                    continue
                if self._is_transferred(target_port.name, td.base_port_name,
                                        config.uid):
                    print(f"Skipping config {config.uid}")
                    continue
                print(f"\n.:'':. APPLYING CONFIG {config.uid} ::'':.")
                print(f"-> -> {config.desc} <- <-")
                print(f"-> -> nth_subset: {config.nth_subset} <- <-")
                print(f"-> -> trainable layers: {config.train_layers} <- <-")
                _, _, start_time, _, _ = decode_keras_model(
                    os.path.split(td.base_model_path)[1])
                model_file_name = encode_keras_model(td.target_port_name,
                                                     start_time,
                                                     td.base_port_name,
                                                     config.uid)
                file_path = os.path.join(output_dir, "model",
                                         td.target_port_name, model_file_name)

                X_train_orig, X_test_orig, y_train_orig, y_test_orig = train_test_split(
                    X_ts, y_ts, test_size=0.2, random_state=42, shuffle=False)
                train_optimizer = Adam(learning_rate=train_lr)
                fine_tune_optimizer = Adam(learning_rate=fine_tune_lr)

                checkpoint = ModelCheckpoint(file_path,
                                             monitor='val_mae',
                                             mode='min',
                                             verbose=2,
                                             save_best_only=True)
                early = EarlyStopping(monitor="val_mae",
                                      mode="min",
                                      patience=10,
                                      verbose=2)
                redonplat = ReduceLROnPlateau(monitor="val_mae",
                                              mode="min",
                                              patience=3,
                                              verbose=2)
                callbacks_list = [checkpoint, early, redonplat]

                # load base model
                model = load_model(td.base_model_path)

                X_train = X_train_orig
                X_test = X_test_orig
                y_train = y_train_orig
                y_test = y_test_orig

                # apply transfer configuration: keep only every nth sample
                if config.nth_subset > 1:
                    if X_train.shape[0] < config.nth_subset:
                        print(f"Unable to apply nth-subset. Not enough data")
                    X_train = X_train_orig[0::config.nth_subset]
                    X_test = X_test_orig[0::config.nth_subset]
                    y_train = y_train_orig[0::config.nth_subset]
                    y_test = y_test_orig[0::config.nth_subset]
                    print(
                        f"Orig shape: {X_train_orig.shape} {config.nth_subset} th-subset shape: {X_train.shape}"
                    )
                    print(
                        f"Orig shape: {X_test_orig.shape} {config.nth_subset} th-subset shape: {X_test.shape}"
                    )
                    print(
                        f"Orig shape: {y_train_orig.shape} {config.nth_subset} th-subset shape: {y_train.shape}"
                    )
                    print(
                        f"Orig shape: {y_test_orig.shape} {config.nth_subset} th-subset shape: {y_test.shape}"
                    )
                modified = False
                # freeze all layers not listed in the config's train_layers
                for layer in model.layers:
                    if layer.name not in config.train_layers:
                        modified = True
                        print(f"setting layer {layer.name} to False")
                        layer.trainable = False
                    else:
                        print(f"layer {layer.name} stays True")
                if modified:
                    print(f"modified. compiling")
                    # re-compile so that the changed trainable flags take effect
                    model.compile(optimizer=train_optimizer,
                                  loss="mse",
                                  metrics=["mae"])
                trainable_count = count_params(model.trainable_weights)
                non_trainable_count = count_params(model.non_trainable_weights)
                print(f"Total params: {trainable_count + non_trainable_count}")
                print(f"Trainable params: {trainable_count}")
                print(f"Non trainable params: {non_trainable_count}")

                # transfer model
                result = model.fit(X_train,
                                   y_train,
                                   epochs=num_epochs,
                                   batch_size=batch_size,
                                   verbose=2,
                                   validation_data=(X_test, y_test),
                                   callbacks=callbacks_list)
                train_mae = result.history["mae"]
                val_mae = result.history["val_mae"]
                gc.collect()
                tune_result = None
                tune_train_mae = None
                tune_val_mae = None

                if config.tune:
                    print(f"Fine-Tuning transferred model")
                    # apply fine-tuning: unfreeze all but batch-normalization layers!
                    for layer in model.layers:
                        if not layer.name.startswith("batch_normalization"):
                            layer.trainable = True
                    model.compile(optimizer=fine_tune_optimizer,
                                  loss="mse",
                                  metrics=["mae"])
                    tune_result = model.fit(X_train,
                                            y_train,
                                            epochs=fine_num_epochs,
                                            batch_size=batch_size,
                                            verbose=2,
                                            validation_data=(X_test, y_test),
                                            callbacks=callbacks_list)
                    tune_train_mae = tune_result.history["mae"]
                    tune_val_mae = tune_result.history["val_mae"]
                # restore the best weights written by the checkpoint callback
                model.load_weights(file_path)

                # set evaluation
                evaluator.set_mae(target_port, start_time,
                                  _compute_mae(val_mae, tune_val_mae),
                                  base_port, config.uid)
                y_pred = model.predict(X_test)
                # NOTE(review): the torch-based Evaluator.group_mae takes (outputs, targets) and groups by the
                # second argument; here targets are passed first. Verify the argument order against the
                # Evaluator version used with this module.
                grouped_mae = evaluator.group_mae(y_test, y_pred)
                evaluator.set_mae(target_port, start_time, grouped_mae,
                                  base_port, config.uid)

                # save history
                history_file_name = encode_history_file(
                    training_type, target_port.name, start_time,
                    td.base_port_name, config.uid)
                history_path = os.path.join(output_dir, "data",
                                            target_port.name,
                                            history_file_name)
                np.save(history_path, [
                    result.history,
                    tune_result.history if tune_result else None
                ])

                # plot history
                plot_dir = os.path.join(output_dir, "plot")
                plot_history(train_mae, val_mae, plot_dir, target_port.name,
                             start_time, training_type, td.base_port_name,
                             config.uid, tune_train_mae, tune_val_mae)
                plot_predictions(y_pred, y_test, plot_dir, target_port.name,
                                 start_time, training_type, td.base_port_name,
                                 config.uid)
                self.set_transfer(target_port.name, td.base_port_name,
                                  config.uid)
                # release per-config objects before the next config to limit memory growth
                del checkpoint, early, redonplat
                del X_train_orig, X_test_orig, y_train_orig, y_test_orig, model, X_train, y_train, X_test, y_test
                gc.collect()
                tf.keras.backend.clear_session()
            gc.collect()
        del X_ts, y_ts

    def _generate_transfers(self) -> Dict[str, List[TransferDefinition]]:
        """
        Generate TransferDefinitions based on transfer-config.json, containing those ports that have a base training for
        transferring to another port
        Every ordered pair (base, target) of the configured ports is considered; the latest base training of the
        base port is used.
        :return: Dict of key = target_port_name, val = List of TransferDefinition
        :raises ValueError: If a configured port name cannot be associated with a port
        """
        config = read_json(self.config_path)
        transfer_defs = {}
        ports = list(config["ports"])
        permutations = list(itertools.permutations(ports, r=2))

        for pair in permutations:
            base_port, target_port = self.pm.find_port(
                pair[0]), self.pm.find_port(pair[1])
            if base_port is None:
                raise ValueError(
                    f"No port found: Unable to transfer from base-port with name '{pair[0]}'"
                )
            if target_port is None:
                raise ValueError(
                    f"No port found: Unable to transfer to target-port with name '{pair[1]}'"
                )

            trainings = self.pm.load_trainings(base_port,
                                               self.output_dir,
                                               self.routes_dir,
                                               training_type="base")
            if len(trainings.keys()) < 1:
                print(
                    f"No base-training found for port '{base_port.name}'. Skipping"
                )
                continue

            training = list(trainings.values())[-1][0]
            verify_output_dir(self.output_dir, target_port.name)
            td = TransferDefinition(
                base_port_name=base_port.name,
                base_model_path=training.model_path,
                target_port_name=target_port.name,
                target_routes_dir=os.path.join(self.routes_dir,
                                               target_port.name),
                target_model_dir=os.path.join(self.output_dir, "model",
                                              target_port.name),
                target_output_data_dir=os.path.join(self.output_dir, "data",
                                                    target_port.name),
                target_plot_dir=os.path.join(self.output_dir, "plot",
                                             target_port.name),
                target_log_dir=os.path.join(self.output_dir, "log",
                                            target_port.name))
            transfer_defs.setdefault(target_port.name, []).append(td)
        return transfer_defs

    def _generate_configs(self) -> List[TransferConfig]:
        """
        Build the TransferConfigs listed under "configs" in the transfer config, except those in skip_uids.
        :return: List of TransferConfig
        """
        skip_uids = [0, 2, 3, 4]
        config = read_json(self.config_path)

        def _make_config(uid: str, desc: str, nth_subset: str,
                         train_layers: List[str],
                         tune: bool) -> TransferConfig:
            return TransferConfig(int(uid), desc, int(nth_subset), train_layers, tune)

        # uids may be stored as strings (they are converted with int() above), so compare as int
        configs = [
            _make_config(c["uid"], c["desc"], c["nth_subset"],
                         c["train_layers"], c["tune"])
            for c in config["configs"] if int(c["uid"]) not in skip_uids
        ]
        print(f"{len(configs)}")
        return configs
コード例 #8
0
ファイル: trainer.py プロジェクト: Trousersfield/ma
def _train_validate_split_bounds(num_examples: int) -> tuple:
    """Return ``(end_train, end_validate)`` indices for splitting ``num_examples`` examples.

    Roughly the first 80% go to training (``[0, end_train)``). If the remainder is odd, one
    more example is moved into the training part so that the remainder halves evenly. The
    first half of the remainder is the validation part (``[end_train, end_validate)``); the
    examples from ``end_validate`` on are not part of either split.

    >>> _train_validate_split_bounds(10)
    (8, 9)
    >>> _train_validate_split_bounds(11)
    (9, 10)
    >>> _train_validate_split_bounds(0)
    (0, 0)
    """
    end_train = int(.8 * num_examples)
    if (num_examples - end_train) % 2 != 0 and end_train < num_examples:
        end_train += 1
    # the remainder is even at this point, so integer halving is exact
    end_validate = end_train + (num_examples - end_train) // 2
    return end_train, end_validate


def _resolve_resume_dataset_config(resume_checkpoint: str, dataset_dir: str, training_type: str):
    """Locate the dataset config to resume a training from.

    ``"latest"`` selects the most recent config in ``dataset_dir``. Any other value is turned into a
    config file path via ``encode_dataset_config_file``. If that file does not exist, the user is
    asked whether to fall back to the latest config.

    Returns:
        The dataset config path, or ``None`` if the user declined the fallback (abort training).

    Raises:
        FileNotFoundError: If no usable dataset config exists.
    """
    if resume_checkpoint == "latest":
        config_path = find_latest_dataset_config_path(dataset_dir, training_type=training_type)
    else:
        config_path = encode_dataset_config_file(resume_checkpoint, training_type)
        if not os.path.exists(config_path):
            latest_config_path = find_latest_dataset_config_path(dataset_dir, training_type=training_type)
            use_latest = input(f"Unable to find dataset config for start time '{resume_checkpoint}'. "
                               f"Continue with latest config (Y) at '{latest_config_path}' or abort")
            if use_latest not in ["Y", "y", "YES", "yes"]:
                return None
            config_path = latest_config_path
    # the "latest" lookup may yield None; test for it before os.path.exists, which raises TypeError on None
    if config_path is None or not os.path.exists(config_path):
        raise FileNotFoundError(f"Unable to recover training: No dataset config found at {config_path}")
    return config_path


def train(port_name: str, data_dir: str, output_dir: str, num_epochs: int = 100, learning_rate: float = .00025,
          weight_decay: float = .0001, pm: PortManager = None, resume_checkpoint: str = None,
          debug: bool = False) -> None:
    """Train (or resume training) the base InceptionTime model for a single port.

    Routes are read from ``<data_dir>/routes/<port name>``. About 80% of the dataset is used for
    training and about 10% for validation (see ``_train_validate_split_bounds``). After every epoch,
    a checkpoint and the intermediate loss/time histories are written to the directories returned
    by ``verify_output_dir(output_dir, port.name)``.

    Args:
        port_name: Name of the port to train, resolved via ``pm.find_port``. If no port matches,
            training is skipped.
        data_dir: Root data directory containing the ``routes`` sub-directory.
        output_dir: Root output directory for model, data, plot and log files.
        num_epochs: Total number of epochs. Overridden by the checkpoint when resuming.
        learning_rate: AdamW learning rate. Overridden by the checkpoint when resuming.
        weight_decay: AdamW weight decay. Overridden by the checkpoint when resuming.
        pm: Port manager used to look up the port. If ``None``, a new one is created and loaded.
        resume_checkpoint: ``None`` to start fresh, ``"latest"`` to resume from the newest dataset
            config, or a start-time string identifying a specific dataset config.
        debug: Passed to the train/validate loops. If set, debug output also goes to a separate
            ``<log name>_debug`` logger.

    Raises:
        ValueError: If ``pm`` is ``None`` and the freshly loaded port manager has no ports.
        FileNotFoundError: If resuming and no dataset config can be found.
    """
    # TODO: Make sure dataset does not overwrite if (accidently) new training is started
    start_datetime = datetime.now()
    start_time = as_str(start_datetime)
    # set device: use gpu if available
    # more options: https://pytorch.org/docs/stable/notes/cuda.html
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.autograd.set_detect_anomaly(True)
    if pm is None:
        pm = PortManager()
        pm.load()
        if len(pm.ports.keys()) < 1:
            raise ValueError("No port data available")
    port = pm.find_port(port_name)
    # must be checked before anything touches port.name (output dirs and log names derive from it)
    if port is None:
        print(f"Training skipped: Unable to find port based on port_name {port_name}")
        return

    output_dirs = verify_output_dir(output_dir, port.name)

    log_file_name = f"train-log-base_{port.name}_{start_time}"
    train_logger = Logger(log_file_name, output_dirs["log"], save=False)
    debug_logger = Logger(f"{log_file_name}_debug", output_dirs["log"]) if debug else train_logger

    training_type = "base"
    batch_size = 64
    window_width = 128
    dataset_dir = os.path.join(data_dir, "routes", port.name)

    # init dataset on directory
    dataset = RoutesDirectoryDataset(dataset_dir, start_time=start_time, training_type=training_type,
                                     batch_size=batch_size, start=0, window_width=window_width)
    if resume_checkpoint is not None:
        dataset_config_path = _resolve_resume_dataset_config(resume_checkpoint, dataset_dir, training_type)
        if dataset_config_path is None:
            print("Training aborted")
            return
        dataset = RoutesDirectoryDataset.load_from_config(dataset_config_path)
    else:
        dataset.save_config()
    end_train, end_validate = _train_validate_split_bounds(len(dataset))

    # use initialized dataset's config for consistent split;
    # examples from end_validate on are not used here (presumably held out for evaluation)
    train_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=0, end=end_train)
    validate_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, start=end_train, end=end_validate)
    # eval_dataset = RoutesDirectoryDataset.load_from_config(dataset.config_path, kind="eval", start=end_validate)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, drop_last=False, pin_memory=True,
                                               num_workers=2)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=None, drop_last=False,
                                                  pin_memory=True, num_workers=2)
    train_logger.write(f"Dataset lengths:\n"
                       f"all: {len(dataset)}\ntrain: {len(train_dataset)}\nvalidate: {len(validate_dataset)}")

    # one sample is fetched only to determine the model's input dimension (size of the last axis)
    data, target = train_dataset[0]
    input_dim = data.size(-1)
    output_dim = 1
    start_epoch = 0
    loss_history = ([], [])  # (train losses, validation losses), one entry per epoch
    elapsed_time_history = []
    criterion: torch.nn.MSELoss = torch.nn.MSELoss()

    # resume from a checkpoint if training was aborted
    if resume_checkpoint is not None:
        tc, model = load_checkpoint(output_dirs["model"], device)
        start_epoch = len(tc.loss_history[1])
        start_time = tc.start_time
        num_epochs = tc.num_epochs
        learning_rate = tc.learning_rate
        weight_decay = tc.weight_decay
        loss_history = tc.loss_history
        elapsed_time_history = tc.elapsed_time_history
        # TODO: optimizer
        optimizer = tc.optimizer
    else:
        model = InceptionTimeModel(num_inception_blocks=3, in_channels=input_dim, out_channels=32,
                                   bottleneck_channels=16, use_residual=True, output_dim=output_dim).to(device)
        # TODO: optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    print(f".:'`!`':. TRAINING FOR PORT {port_name} STARTED .:'`!`':.")
    print(f"- - Epochs {num_epochs} </> Training examples {len(train_loader)} </> Learning rate {learning_rate} - -")
    print(f"- - Weight decay {weight_decay} Window width {window_width} </> Batch size {batch_size} - -")
    print(f"- - Number of model's parameters {num_total_trainable_parameters(model)} device {device} - -")
    train_logger.write(f"{port.name}-model\n"
                       f"Number of epochs: {num_epochs}\n"
                       f"Learning rate: {learning_rate}\n"
                       f"Total number of parameters: {num_total_parameters(model)}\n"
                       f"Total number of trainable parameters: {num_total_trainable_parameters(model)}")

    # index (== epoch) of the lowest validation loss seen so far
    min_val_idx = 0
    if resume_checkpoint is not None and len(loss_history[1]) > 0:
        # guard: min() of an empty history (checkpoint without finished epochs) would raise
        min_val_idx = loss_history[1].index(min(loss_history[1]))
    # training loop
    print(f"loss history:\n{loss_history}")
    print(f"min index:\n{min_val_idx}")
    for epoch in range(start_epoch, num_epochs):
        # train model
        print(f"->->->->-> Epoch ({epoch + 1}/{num_epochs}) <-<-<-<-<-<-")
        avg_train_loss, elapsed_time = train_loop(criterion=criterion, model=model, device=device, optimizer=optimizer,
                                                  loader=train_loader, debug=debug, debug_logger=debug_logger)
        loss_history[0].append(avg_train_loss)
        elapsed_time_history.append(elapsed_time)

        # validate model
        avg_validation_loss = validate_loop(criterion=criterion, device=device, model=model, optimizer=optimizer,
                                            loader=validate_loader, debug=debug, debug_logger=debug_logger)
        loss_history[1].append(avg_validation_loss)

        # check if current model has lowest validation loss (= is current optimal model)
        if avg_validation_loss < loss_history[1][min_val_idx]:
            min_val_idx = epoch

        train_logger.write(f"Epoch {epoch + 1}/{num_epochs}:\n"
                           f"\tAvg train loss {avg_train_loss}\n"
                           f"\tAvg val   loss {avg_validation_loss}")

        make_training_checkpoint(model=model, model_dir=output_dirs["model"], port=port, start_time=start_time,
                                 num_epochs=num_epochs, learning_rate=learning_rate,
                                 weight_decay=weight_decay, num_train_examples=len(train_loader),
                                 loss_history=loss_history, elapsed_time_history=elapsed_time_history,
                                 optimizer=optimizer, is_optimum=min_val_idx == epoch)
        save_intermediate(data_dir=output_dirs["data"], elapsed_time_history=elapsed_time_history,
                          loss_history=loss_history, port=port, start_time=start_time, training_type="base")
        print(f">>>> Avg losses (MSE) - Train: {avg_train_loss} Validation: {avg_validation_loss} <<<<\n")

    # conclude training
    conclude_training(loss_history=loss_history, data_dir=output_dirs["data"], plot_dir=output_dirs["plot"],
                      port=port, start_time=start_time, elapsed_time_history=elapsed_time_history,
                      plot_title="Training loss", training_type="base")