Exemple #1
0
def main(params):
    """Evaluate the saved checkpoints of one experiment, epoch by epoch.

    Args:
        params: run options; dict(). Keys read here are "dataset",
            "model_type", "experiment", "evaluate_on_gt",
            "evaluate_on_top1000", "proposal", "loader_config_path",
            "test_on_server", "num_workers", "debug_mode", "start_epoch",
            "end_epoch" and "epoch_stride". "config_path" is written into it
            before the experiment is prepared.
    """
    # the experiment directory holds both config.yml and the checkpoints
    exp_path = os.path.join("results", params["dataset"],
                            params["model_type"], params["experiment"])
    params["config_path"] = os.path.join(exp_path, "config.yml")

    # prepare model and dataset
    model_def, dataset, config = cmf.prepare_experiment(params)

    # evaluation switches: GT usage first, then the Top-1000 override
    eval_cfg = config["evaluation"]
    eval_cfg["use_gt"] = params["evaluate_on_gt"]
    if params["evaluate_on_top1000"]:
        eval_cfg["use_gt"] = False
        eval_cfg["apply_nms"] = False
    if len(params["proposal"]) > 0:
        eval_cfg["precomputed_proposal_sequence"] = params["proposal"]

    # create logger
    epoch_logger = cmf.create_logger(config, "EPOCH", "test.log")

    # Build data loader
    all_loader_configs = io_utils.load_yaml(params["loader_config_path"])
    if params["test_on_server"]:
        loader_key, test_on = "test_loader", "Test_Server"
    else:
        loader_key, test_on = "val_loader", "Test"
    loader_config = all_loader_configs[loader_key]
    dsets, loaders = cmf.get_loader(dataset,
                                    split=["test"],
                                    loader_configs=[loader_config],
                                    num_workers=params["num_workers"])
    test_set = dsets["test"]
    config = model_def.override_config_from_dataset(config, test_set,
                                                    mode="Test")
    config["model"]["resume"] = True

    # remember the configured tensorboard dir, then blank it in the config;
    # the summary writer is created here and handed to every network below
    misc_cfg = config["misc"]
    tensorboard_path = misc_cfg["tensorboard_dir"]
    misc_cfg["tensorboard_dir"] = ""
    misc_cfg["debug"] = params["debug_mode"]

    # Evaluating networks
    e0, e1, es = (params["start_epoch"], params["end_epoch"],
                  params["epoch_stride"])
    summary_dir = tensorboard_path + "_test_s{}_e{}".format(e0, e1)
    io_utils.check_and_create_dir(summary_dir)
    summary = PytorchSummary(summary_dir)

    for epoch in range(e0, e1 + 1, es):
        # Build network from the checkpoint of this epoch
        ckpt_name = "epoch_{:03d}.pkl".format(epoch)
        config["model"]["checkpoint_path"] = os.path.join(
            exp_path, "checkpoints", ckpt_name)
        net, _ = cmf.factory_model(config, model_def, test_set, None)
        net.set_tensorboard_summary(summary)

        cmf.test(config, loaders["test"], net, epoch, None, epoch_logger,
                 on=test_on)
 def create_tensorboard_summary(self, tensorboard_dir):
     """ Enable summary writing and create the writer
     Args:
         tensorboard_dir: value handed to PytorchSummary
     """
     # the flag is raised before the writer is constructed
     self.use_tf_summary = True
     writer = PytorchSummary(tensorboard_dir)
     self.summary = writer
class AbstractNetwork(nn.Module):
    """ Base network wrapper: forward/loss/update, status, checkpoints.

    This class reads ``self.criterion`` and ``self.model_list`` but never
    assigns them, and ``self.evaluator``, ``self.counters`` and
    ``self.status`` are only created by helpers (``_build_evaluator``,
    ``_create_counters``, ``reset_status``) that ``__init__`` does not call;
    presumably subclasses set these up -- verify against the subclasses.

    NOTE(review): nn.Module is not an abc.ABC, so the @abstractmethod
    decorators below appear to be declarative only (not enforced).
    """

    def __init__(self, config, logger=None, verbose=False):
        """ Store the configuration and set up device and logging
        Args:
            config: nested configuration; dict(). Reads config["model"],
                config["train_loader"]["dataset"] and, when verbose,
                config["misc"]["result_dir"]
            logger: optional logger whose info() is used for messages;
                print is used when None
            verbose: if True, the configuration is written to
                <result_dir>/config.yml and mode switches are logged
        """
        super(AbstractNetwork, self).__init__()  # Must call super __init__()

        # update configuration
        config = self.model_specific_config_update(config)

        # create internal variables
        self.optimizer = None
        self.scheduler = None  # created together with the optimizer
        self.models_to_update = None
        self.training_mode = True
        self.best_score = None
        self.use_tf_summary = False
        self.it = 0  # it: iteration

        # for gpu use: fall back to CPU when CUDA is not available.
        # is_available has to be *called*; the bare function is always truthy
        use_gpu = config["model"].get("use_gpu", True)
        if not torch.cuda.is_available():
            use_gpu = False
        self.device = torch.device("cuda" if use_gpu else "cpu")

        # save configuration for later network reproduction
        if verbose:
            save_config_path = os.path.join(config["misc"]["result_dir"],
                                            "config.yml")
            io_utils.write_yaml(save_config_path, config)
        config["model"]["dataset"] = config["train_loader"]["dataset"]
        self.config = config

        # prepare logging
        self.log = print if logger is None else logger.info
        self.log(json.dumps(config, indent=2))
        self.verbose = verbose

    # ---- methods for forward/backward ----

    @abstractmethod
    def forward(self, net_inps):
        """ Forward network
        Args:
            net_inps: inputs for network; dict()
        Returns:
            net_outs: dictionary including inputs for criterion, etc
        """
        pass

    def loss_fn(self, crit_inp, gts, count_loss=True):
        """ Compute loss
        Args:
            crit_inp: inputs for criterion which is outputs from forward(); dict()
            gts: ground truth
            count_loss: flag of accumulating loss or not (training or inference)
        Returns:
            loss: results of self.criterion; dict()
        """
        self.loss = self.criterion(crit_inp, gts)
        # keep a plain-number copy of every loss term as the current status
        for name, value in self.loss.items():
            self.status[name] = value.item()
        if count_loss:
            for name in self.loss:
                self.counters[name].add(self.status[name], 1)
        return self.loss

    def update(self, loss):
        """ Update the network
        The optimizer (and scheduler) are created lazily on the first call.
        The scheduler is stepped after the optimizer, as PyTorch requires;
        stepping it first skips the first value of the schedule.
        Args:
            loss: loss to train the network; dict() with key "total_loss"
        """
        self.it = self.it + 1
        # initialize optimizer
        if self.optimizer is None:
            self.create_optimizer()
            self.optimizer.zero_grad()  # set gradients as zero before update

        total_loss = loss["total_loss"]
        total_loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # set gradients as zero before the next backward pass
        self.optimizer.zero_grad()

    def forward_update(self, net_inps, gts):
        """ Forward and update the network at the same time
        Args:
            net_inps: inputs for network; dict()
            gts: ground truth; dict()
        Returns:
            {loss, net_output}: two items of dictionary
                - loss: results from self.criterion(); dict()
                - net_output: output from self.forward(); dict()
        """
        net_out = self.forward(net_inps)
        loss = self.loss_fn(net_out, gts, count_loss=True)
        self.update(loss)
        return {"loss": loss, "net_output": net_out}

    def compute_loss(self, net_inps, gts):
        """ Compute loss and network's output at once (no parameter update)
        Args:
            net_inps: inputs for network; dict()
            gts: ground truth; dict()
        Returns:
            {loss, net_output}: two items of dictionary
                - loss: results from self.criterion(); dict()
                - net_output: output from self.forward(); dict()
        """
        net_out = self.forward(net_inps)
        loss = self.loss_fn(net_out, gts, count_loss=True)
        return {"loss": loss, "net_output": net_out}

    def forward_only(self, net_inps):
        """ Compute only the network's output (no loss, no update)
        Args:
            net_inps: inputs for network; dict()
        Returns:
            {net_output}: one item of dictionary
                - net_output: output from self.forward()
        """
        net_out = self.forward(net_inps)
        return {"net_output": net_out}

    def get_lr(self):
        """ Return the learning rate of the first parameter group
        (None if the optimizer has no parameter group).
        """
        for param_group in self.optimizer.param_groups:
            return param_group["lr"]

    def create_optimizer(self):
        """ Create optimizer and learning-rate scheduler for training phase
        Everything is read from self.config["optimize"].
        Supported optimizer_type: [SGD, Adam, Adadelta, RMSprop]
        Supported scheduler_type: [step, multistep, exponential, lambda];
        any other value (default "") leaves the scheduler as None.
        Raises:
            NotImplementedError: unsupported optimizer_type, or
                scheduler_type "warmup"
        """
        opt_cfg = self.config["optimize"]

        # setting optimizer
        lr = opt_cfg["init_lr"]
        opt_type = opt_cfg["optimizer_type"]
        if opt_type == "SGD":
            self.optimizer = torch.optim.SGD(
                self.get_parameters(),
                lr=lr,
                momentum=opt_cfg["momentum"],
                weight_decay=opt_cfg["weight_decay"])
        elif opt_type == "Adam":
            betas = opt_cfg.get("betas", (0.9, 0.999))
            weight_decay = opt_cfg.get("weight_decay", 0.0)
            self.optimizer = torch.optim.Adam(self.get_parameters(),
                                              lr=lr,
                                              betas=betas,
                                              weight_decay=weight_decay)
        elif opt_type == "Adadelta":
            self.optimizer = torch.optim.Adadelta(self.get_parameters(), lr=lr)
        elif opt_type == "RMSprop":
            self.optimizer = torch.optim.RMSprop(self.get_parameters(), lr=lr)
        else:
            raise NotImplementedError(
                "Not supported optimizer [{}]".format(opt_type))

        # setting scheduler; decay options are only looked up by the
        # schedulers that use them, so they may be omitted otherwise
        self.scheduler = None
        scheduler_type = opt_cfg.get("scheduler_type", "")
        if scheduler_type == "step":
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, opt_cfg["decay_step"], opt_cfg["decay_factor"])
        elif scheduler_type == "multistep":
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, opt_cfg["milestones"], opt_cfg["decay_factor"])
        elif scheduler_type == "exponential":
            self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
                self.optimizer, opt_cfg["decay_factor"])
        elif scheduler_type == "lambda":
            decay_step = opt_cfg["decay_step"]
            decay_factor = opt_cfg["decay_factor"]
            # LambdaLR expects one function per parameter group and this
            # optimizer has a single group, so a list of two functions is
            # rejected. NOTE(review): the two original lambdas
            # (it // decay_step, decay_factor ** it) are composed into one
            # step-decay factor here -- confirm this is the intended schedule.
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lambda it: decay_factor**(it // decay_step))
        elif scheduler_type == "warmup":
            raise NotImplementedError()

    @abstractmethod
    def _build_network(self):
        pass

    def _build_evaluator(self):
        """ Create the evaluator for the dataset named in the train loader
        configuration ("charades" when not given).
        """
        self.dataset = self.config["train_loader"].get("dataset", "charades")
        self.evaluator = eval_utils.get_evaluator(self.dataset)

    @abstractmethod
    def prepare_batch(self, batch):
        """ Prepare batch to be used for network
        e.g., shipping batch to gpu
        Args:
            batch: batch data; dict()
        Returns:
            net_inps: network inputs; dict()
            gts: ground-truths; dict()
        """
        pass

    @abstractmethod
    def apply_curriculum_learning(self):
        pass

    def save_results(self, prefix, mode="Train"):
        """ Hook for saving results; does nothing in this base class. """
        pass

    # ---- methods for status (losses, metrics) ----

    def _get_score(self):
        """ Return the averaged counter of the evaluator's metric. """
        return self.counters[self.evaluator.get_metric()].get_average()

    def renew_best_score(self):
        """ Compare the current score with the best one and log the outcome
        Returns:
            True if the current score became the new best score (higher is
            better), False otherwise
        """
        cur_score = self._get_score()
        if (self.best_score is None) or (cur_score > self.best_score):
            self.best_score = cur_score
            self.log("Iteration {}: New best score {:4f}".format(
                self.it, self.best_score))
            return True
        self.log("Iteration {}: Current score {:4f}".format(
            self.it, cur_score))
        self.log("Iteration {}: Current best score {:4f}".format(
            self.it, self.best_score))
        return False

    def reset_status(self, init_reset=False):
        """ Reset (initialize) metric scores or losses (status).
        Args:
            init_reset: if True, self.status is rebuilt from the criterion
                items and the evaluator metrics; otherwise the existing
                entries are zeroed in place
        """
        if init_reset:
            self.status = OrderedDict()
            self.status["total_loss"] = 0
            for k, _ in self.criterion.get_items():
                self.status[k] = 0
            for k in self.evaluator.metrics:
                self.status[k] = 0
        else:
            for k in self.status.keys():
                self.status[k] = 0

    @abstractmethod
    def compute_status(self, net_outs, gts, mode="Train"):
        """ Compute metric scores or losses (status).
            You may need to implement this method.
        Args:
            net_outs: output of network.
            gts: ground-truth
        """
        pass

    def _get_print_list(self, mode):
        """ Return the names to report for the given mode
        Args:
            mode: "Train" selects the criterion names plus "total_loss";
                anything else selects the evaluator metrics
        """
        if mode == "Train":
            print_list = copy.deepcopy(self.criterion.get_names())
            print_list.append("total_loss")
        else:
            print_list = copy.deepcopy(self.evaluator.metrics)
        return print_list

    def print_status(self, enter_every=3):
        """ Print current metric scores or losses (status).
            You are encouraged to implement this method.
        Args:
            enter_every: number of entries logged per line
        """
        val_list = self._get_print_list("Train")
        # print status information, flushing a line every `enter_every` items
        txt = "Step {} ".format(self.it)
        for i, k in enumerate(val_list):
            txt += "{} = {:.4f}, ".format(k, float(self.status[k]))
            if (i + 1) % enter_every == 0:
                self.log(txt)
                txt = ""
        if txt:
            self.log(txt)

    # ---- methods for counters ----

    def _create_counters(self):
        """ Create one accumulator per loss term and per evaluator metric. """
        self.counters = OrderedDict()
        self.counters["total_loss"] = accumulator.Accumulator("total_loss")
        for k, _ in self.criterion.get_items():
            self.counters[k] = accumulator.Accumulator(k)
        for k in self.evaluator.metrics:
            self.counters[k] = accumulator.Accumulator(k)

    def reset_counters(self):
        """ Reset every accumulator. """
        for counter in self.counters.values():
            counter.reset()

    def print_counters_info(self, logger, epoch, mode="Train"):
        """ Log the averaged counters, write them to the summary if enabled,
        then reset the counters
        Args:
            logger: logger to use; self.log is used when it is falsy
            epoch: current epoch (also the global step of the summary)
            mode: selects the reported names, see _get_print_list()
        """
        val_list = self._get_print_list(mode)
        txt = "[{}] {} epoch {} iter".format(mode, epoch, self.it)
        for k in val_list:
            v = self.counters[k]
            txt += ", {} = {:.4f}".format(v.get_name(), v.get_average())
        if logger:
            logger.info(txt)
        else:
            self.log(txt)

        if self.use_tf_summary:
            self.write_counter_summary(epoch, mode)

        # reset counters
        self.reset_counters()

    # ---- methods for checkpoint ----

    def load_checkpoint(self, ckpt_path, load_crit=False):
        """ Load checkpoint of the network.
        Args:
            ckpt_path: checkpoint file path; str
            load_crit: if False, the "criterion" entry is skipped
        """
        self.log("Checkpoint is loaded from {}".format(ckpt_path))
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        model_state_dict = torch.load(
            ckpt_path, map_location=lambda storage, loc: storage)
        self.log("[{}] are in checkpoint".format("|".join(
            model_state_dict.keys())))
        for m, state in model_state_dict.items():
            if (not load_crit) and (m == "criterion"):
                continue
            if m in self.model_list:
                self.log("Initializing [{}] from checkpoint".format(m))
                self[m].load_state_dict(state)
            else:
                self.log("{} is not in {}".format(m,
                                                  "|".join(self.model_list)))

    def save_checkpoint(self, ckpt_path, save_crit=False):
        """ Save checkpoint of the network.
        Args:
            ckpt_path: checkpoint file path
            save_crit: if True, the state of self["criterion"] is saved too
        """
        model_state_dict = {
            m: self[m].state_dict()
            for m in self.model_list if m != "criterion"
        }
        if save_crit:
            model_state_dict["criterion"] = self["criterion"].state_dict()
        torch.save(model_state_dict, ckpt_path)

        self.log("Checkpoint [{}] is saved in {}".format(
            " | ".join(model_state_dict.keys()), ckpt_path))

    # ---- methods for tensorboard ----

    def create_tensorboard_summary(self, tensorboard_dir):
        """ Enable summary writing with a newly created writer. """
        self.use_tf_summary = True
        self.summary = PytorchSummary(tensorboard_dir)

    def set_tensorboard_summary(self, summary):
        """ Enable summary writing with an existing writer. """
        self.use_tf_summary = True
        self.summary = summary

    def write_counter_summary(self, epoch, mode):
        """ Write the average of every counter as a scalar at step `epoch`. """
        for counter in self.counters.values():
            self.summary.add_scalar(mode + '/counters/' + counter.get_name(),
                                    counter.get_average(),
                                    global_step=epoch)

    # ---- wrapper methods of nn.Modules ----

    def get_parameters(self):
        """ Yield the parameters to optimize
        All parameters of this module when self.models_to_update is None;
        otherwise only those of the listed attributes, each being either a
        module or a dict of modules.
        """
        if self.models_to_update is None:
            for _, param in self.named_parameters():
                yield param
        else:
            for m in self.models_to_update:
                if isinstance(self[m], dict):
                    for sub_model in self[m].values():
                        for _, param in sub_model.named_parameters():
                            yield param
                else:
                    for _, param in self[m].named_parameters():
                        yield param

    def cpu_mode(self):
        """ Move the module to CPU. """
        self.log("Setting cpu() for [{}]".format(" | ".join(self.model_list)))
        self.cpu()

    def gpu_mode(self):
        """ Move the module to GPU
        Raises:
            NotImplementedError: CUDA is not available
        """
        #cudnn.benchmark = False
        if torch.cuda.is_available():
            self.log("Setting gpu() for [{}]".format(" | ".join(
                self.model_list)))
            self.cuda()
        else:
            raise NotImplementedError("Available GPU not exists")

    def train_mode(self):
        """ Switch to training mode. """
        self.train()
        self.training_mode = True
        if self.verbose:
            self.log("Setting train() for [{}]".format(" | ".join(
                self.model_list)))

    def eval_mode(self):
        """ Switch to evaluation mode. """
        self.eval()
        self.training_mode = False
        if self.verbose:
            self.log("Setting eval() for [{}]".format(" | ".join(
                self.model_list)))

    # ---- related to configuration or dataset ----

    def bring_dataset_info(self, dset):
        """ Hook for pulling information from a dataset; only prints a hint. """
        print("You would need to implement 'bring_dataset_info'")
        pass

    def model_specific_config_update(self, config):
        """ Hook for model-specific configuration changes; returns config
        unchanged here. Called at the very start of __init__.
        """
        print("You would need to implement 'model_specific_config_update'")
        return config

    @staticmethod
    def dataset_specific_config_update(config, dset):
        """ Hook for dataset-specific configuration changes; returns config
        unchanged here.
        """
        print("You would need to implement 'dataset_specific_config_update'")
        return config

    # ---- basic methods ----

    def __getitem__(self, key):
        """ Attribute access by name: net["x"] is net.x. """
        return getattr(self, key)

    def __setitem__(self, key, value):
        """ Attribute assignment by name: net["x"] = v sets net.x. """
        return setattr(self, key, value)
Exemple #4
0
    def create_tensorboard_summary(self, tensorboard_dir):
        """ Enable summary writing and create the writer
        In debug mode the parameter summary is written once, at epoch 0.
        Args:
            tensorboard_dir: value handed to PytorchSummary
        """
        self.use_tf_summary = True
        self.summary = PytorchSummary(tensorboard_dir)

        if not self.debug_mode:
            return
        self.write_params_summary(epoch=0)
Exemple #5
0
class VirtualNetwork(nn.Module):
    """ Base network wrapper with loss counting, logging and checkpoints.

    The class reads ``self.criterion``, ``self.model_list`` and
    ``self.config`` but never assigns them; presumably subclasses do --
    verify against the subclasses.
    """

    def __init__(self):
        super(VirtualNetwork, self).__init__()  # Must call super __init__()

        self.models_to_update = None
        self.sample_data = None
        self.optimizer = None
        self.training_mode = True
        self.is_main_net = True

        self.counters = None
        self.status = None
        self.use_tf_summary = False
        self.it = 0  # it: iteration
        self.update_every = 1  # optimizer steps once every N update() calls
        self.debug_mode = False
        self.qsts = None

        self._create_counters()
        self._get_loggers()
        self.reset_status(init_reset=True)

        self.tm = timer.Timer()  # tm: timer

    # ---- methods for forward/backward ----

    def forward(self, data):
        """ Forward network
        Args:
            data: list of two components [inputs for network, image information]
                - inputs for network: should be variables
        Returns:
            criterion_inp: input list for criterion
        """
        raise NotImplementedError("Should override a method (forward)")

    def loss_fn(self, criterion_inp, gt, count_loss=True):
        """ Compute loss
        Args:
            criterion_inp: inputs for criterion which is outputs from forward(); list
            gt: ground truth
            count_loss: flag whether accumulating loss or not (training or inference)
        Returns:
            the value returned by self.criterion
        """
        self.loss = self.criterion(criterion_inp, gt)
        self.status["loss"] = net_utils.get_data(self.loss)[0]
        if count_loss:
            self.counters["loss"].add(self.status["loss"], 1)
        return self.loss

    def update(self, loss, lr):
        """ Update the network
        Gradients are accumulated over self.update_every calls; the loss is
        divided by that number and the optimizer steps on every
        update_every-th call. An Adam optimizer is created on first use.
        Args:
            loss: loss to train the network
            lr: learning rate, applied to every parameter group on each call
        """
        if self.optimizer is None:
            self.optimizer = torch.optim.Adam(self.get_parameters(), lr=lr)
            # set gradients as zero while initializing
            self.optimizer.zero_grad()

        self.it += 1
        loss = loss / self.update_every
        loss.backward()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        if self.it % self.update_every == 0:
            self.optimizer.step()
            # set gradients as zero before the next accumulation round
            self.optimizer.zero_grad()

    def forward_update(self, batch, lr):
        """ Forward and update the network at the same time
        Args:
            batch: list of two components [inputs for network, image information]
                - inputs for network: should be tensors
            lr: learning rate
        Returns:
            list [loss, *outputs of forward()]
        """
        # convert data (tensors) as Variables
        data = self.tensor2variable(batch)

        # the first output of forward() is fed to the criterion together
        # with the last element of data (used as ground truth); remaining
        # outputs are passed through to the caller
        outputs = self.forward(data)
        loss = self.loss_fn(outputs[0], data[-1], count_loss=True)
        self.update(loss, lr)
        return [loss, *outputs]

    def evaluate(self, batch):
        """ Compute loss and network's output at once (no parameter update)
        Args:
            batch: list of two components [inputs for network, image information]
                - inputs for network: should be tensors
        Returns:
            list [loss, *outputs of forward()]
        """
        # convert data (tensors) as Variables
        data = self.tensor2variable(batch)

        # same convention as forward_update(): outputs[0] goes to the
        # criterion, data[-1] is used as ground truth
        outputs = self.forward(data)
        loss = self.loss_fn(outputs[0], data[-1], count_loss=True)
        return [loss, *outputs]

    def predict(self, batch):
        """ Compute only network's output
        Args:
            batch: list of two components [inputs for network, image information]
                - inputs for network: should be tensors
        Returns:
            list of the outputs of forward()
        """
        # convert data (tensors) as Variables
        data = self.tensor2variable(batch)

        outputs = self.forward(data)
        return [*outputs]

    def tensor2variable(self, tensors):
        """ Convert tensors to variables
        Args:
            tensors: input tensors fetched from data loader
        """
        raise NotImplementedError(
            "Should override this function (tensor2variable)")

    # ---- methods for checkpoint ----

    def load_checkpoint(self, ckpt_path):
        """ Load checkpoint of the network.
        Entries whose name is not in self.model_list are reported and skipped.
        Args:
            ckpt_path: checkpoint file path
        """
        self.logger["train"].info(
            "Checkpoint is loaded from {}".format(ckpt_path))
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        model_state_dict = torch.load(
            ckpt_path, map_location=lambda storage, loc: storage)
        for m, state in model_state_dict.items():
            if m in self.model_list:
                self[m].load_state_dict(state)
            else:
                self.logger["train"].info("{} is not in {}".format(
                    m, " | ".join(self.model_list)))

        self.logger["train"].info(
            "[{}] are initialized from checkpoint".format(" | ".join(
                model_state_dict.keys())))

    def save_checkpoint(self, cid):
        """ Save checkpoint of the network.
        The file is <result_dir>/checkpoints/epoch_<cid:03d>.pkl; the
        checkpoints directory is created when missing.
        Args:
            cid: id of checkpoint; e.g. epoch
        """
        # format only the file name, so braces in result_dir are left alone
        ckpt_dir = os.path.join(self.config["misc"]["result_dir"],
                                "checkpoints")
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, "epoch_{:03d}.pkl".format(cid))
        model_state_dict = OrderedDict()
        for m in self.model_list:
            model_state_dict[m] = self[m].state_dict()
        torch.save(model_state_dict, ckpt_path)
        self.logger["train"].info(
            "Checkpoint is saved in {}".format(ckpt_path))

    def _get_loggers(self):
        """ Create logging variables.
        """
        self.logger = {}
        self.logger["train"] = io_utils.get_logger("Train")
        self.logger["epoch"] = io_utils.get_logger("Epoch")
        self.logger["eval"] = io_utils.get_logger("Evaluate")

    def _set_sample_data(self, data):
        """ Keep a deep copy of the first data passed in; later calls are
        ignored once sample data exists.
        """
        if self.sample_data is None:
            self.sample_data = copy.deepcopy(data)

    # ---- method for status (metrics) ----

    def reset_status(self, init_reset=False):
        """ Reset (initialize) metric scores or losses (status).
        Args:
            init_reset: not used here; the status is created when it is
                still None and zeroed in place otherwise
        """
        if self.status is None:
            self.status = OrderedDict()
            self.status["loss"] = 0
        else:
            for k in self.status.keys():
                self.status[k] = 0

    def compute_status(self, logits, gts):
        """ Compute metric scores or losses (status).
            You may need to implement this method.
        Args:
            logits: output logits of network.
            gts: ground-truth
        """
        self.logger["train"].warning(
            "You may need to implement method (compute_status).")
        return

    def print_status(self, epoch, iteration, prefix="", is_main_net=True):
        """ Print current metric scores or losses (status).
            You may need to implement this method.
        Args:
            epoch: current epoch
            iteration: current iteration
            prefix: identity to distinguish models; if is_main_net, this is not needed
            is_main_net: flag about whether this network is root (main);
                nothing is printed otherwise
        """
        if is_main_net:
            # prepare txt to print
            txt = "epoch {} step {}".format(epoch, iteration)
            for k, v in self.status.items():
                txt += ", {} = {:.3f}".format(k, v)

            # print learning information
            self.logger["train"].debug(txt)

            if self.use_tf_summary and self.training_mode:
                self.write_status_summary(iteration)

    # ---- methods for tensorboard ----

    def create_tensorboard_summary(self, tensorboard_dir):
        """ Enable summary writing and create the writer; in debug mode the
        parameter summary is written once, at epoch 0.
        """
        self.use_tf_summary = True
        self.summary = PytorchSummary(tensorboard_dir)

        if self.debug_mode:
            self.write_params_summary(epoch=0)

    def write_params_summary(self, epoch):
        """ Write a histogram of every parameter to update at step `epoch`. """
        if self.models_to_update is None:
            for name, param in self.named_parameters():
                self.summary.add_histogram("model/{}".format(name),
                                           net_utils.get_data(param).numpy(),
                                           global_step=epoch)
        else:
            for m in self.models_to_update:
                for name, param in self[m].named_parameters():
                    self.summary.add_histogram(
                        "model/{}/{}".format(m, name),
                        net_utils.get_data(param).numpy(),
                        global_step=epoch)

    def write_status_summary(self, iteration):
        """ Write every status entry as a scalar at step `iteration`. """
        for k, v in self.status.items():
            self.summary.add_scalar('status/' + k, v, global_step=iteration)

    def write_counter_summary(self, epoch, mode):
        """ Write the average of every counter as a scalar at step `epoch`. """
        for counter in self.counters.values():
            self.summary.add_scalar(mode + '/counters/' + counter.get_name(),
                                    counter.get_average(),
                                    global_step=epoch)

    # ---- methods for counters ----

    def _create_counters(self):
        """ Create the counters; only "loss" in this base class. """
        self.counters = OrderedDict()
        self.counters["loss"] = accumulator.Accumulator("loss")

    def reset_counters(self):
        """ Reset every accumulator. """
        for counter in self.counters.values():
            counter.reset()

    def print_counters_info(self, epoch, logger_name="epoch", mode="train"):
        """ Log the averaged counters, write them to the summary if enabled,
        then reset the counters
        Args:
            epoch: current epoch (also the global step of the summary)
            logger_name: key into self.logger
            mode: label used in the log line and the summary tag
        """
        # prepare txt to print
        txt = "[{}] {} epoch".format(mode, epoch)
        for counter in self.counters.values():
            txt += ", {} = {:.5f}".format(counter.get_name(),
                                          counter.get_average())

        # print learning information at this epoch
        assert logger_name in self.logger.keys(), \
                "{} does not belong to loggers".format(logger_name)
        self.logger[logger_name].info(txt)

        if self.use_tf_summary:
            self.write_counter_summary(epoch, mode)

        # reset counters
        self.reset_counters()

    def bring_loader_info(self, loader):
        """ Hook for pulling information from a loader; only logs a hint. """
        self.logger["train"].warning(
            "You may need to implement method (bring_loader_info)")
        return

    # ---- wrapper methods of nn.Modules ----

    def get_parameters(self):
        """ Yield the parameters to optimize: all parameters of this module
        when self.models_to_update is None, otherwise only those of the
        listed attributes.
        """
        if self.models_to_update is None:
            for _, param in self.named_parameters():
                yield param
        else:
            for m in self.models_to_update:
                for _, param in self[m].named_parameters():
                    yield param

    def cpu_mode(self):
        """ Move the module to CPU. """
        self.logger["train"].info("Setting cpu() for [{}]".format(" | ".join(
            self.model_list)))
        self.cpu()

    def gpu_mode(self):
        """ Move the module to GPU and enable cudnn.benchmark
        Raises:
            NotImplementedError: CUDA is not available
        """
        if torch.cuda.is_available():
            self.logger["train"].info("Setting gpu() for [{}]".format(
                " | ".join(self.model_list)))
            self.cuda()
        else:
            raise NotImplementedError("Available GPU not exists")
        cudnn.benchmark = True

    def train_mode(self):
        """ Switch to training mode. """
        self.train()
        self.training_mode = True
        self.logger["train"].info("Setting train() for [{}]".format(" | ".join(
            self.model_list)))

    def eval_mode(self):
        """ Switch to evaluation mode. """
        self.eval()
        self.training_mode = False
        self.logger["train"].info("Setting eval() for [{}]".format(" | ".join(
            self.model_list)))

    def __getitem__(self, key):
        """ Attribute access by name: net["x"] is net.x. """
        return getattr(self, key)

    def __setitem__(self, key, value):
        """ Attribute assignment by name: net["x"] = v sets net.x. """
        return setattr(self, key, value)

    @staticmethod
    def override_config_from_params(config, params):
        """ Copy run options into config["misc"] and derive output paths
        Args:
            config: configuration; dict() with a "misc" sub-dictionary
            params: run options; reads "debug_mode", "dataset",
                "num_workers", "config_path" and "model_type"
        Returns:
            config: the same dictionary, updated in place
        """
        misc = config["misc"]
        misc["debug"] = params["debug_mode"]
        misc["dataset"] = params["dataset"]
        misc["num_workers"] = params["num_workers"]

        config_path = params["config_path"]
        if "options" in config_path:
            exp_prefix = utils.get_filename_from_path(config_path,
                                                      delimiter="options/")
        else:
            # NOTE(review): [:-7] presumably strips a trailing "/config"
            # from a results/<experiment>/config.yml path -- confirm
            # against utils.get_filename_from_path
            exp_prefix = utils.get_filename_from_path(
                config_path, delimiter="results/")[:-7]
        misc["exp_prefix"] = exp_prefix
        misc["result_dir"] = os.path.join("results", exp_prefix)
        misc["tensorboard_dir"] = os.path.join("tensorboard", exp_prefix)
        misc["model_type"] = params["model_type"]

        return config
Exemple #6
0
class AbstractNetwork(nn.Module):
    """Base network wrapper with training, checkpoint and logging helpers.

    Subclasses are expected to implement ``forward``, ``build_network``,
    ``build_evaluator``, ``prepare_batch``, ``check_apply_curriculum``,
    ``compute_status`` and ``bring_loader_info``.

    Several attributes read by the methods below are NOT created in
    ``__init__`` and must be provided by the subclass or caller before use:
    ``criterion``, ``model_list``, ``it_per_epoch``, ``status`` (see
    ``reset_status``) and ``counters`` (see ``_create_counters``).

    NOTE(review): ``nn.Module`` does not appear to use ``ABCMeta``, so the
    ``@abstractmethod`` markers are presumably documentation only and do not
    block instantiation -- confirm.
    """

    def __init__(self, config, logger=None, verbose=False):
        """Read run-time options from ``config`` and set up logging.

        Args:
            config: nested configuration; dict(). The sections read here are
                "evaluation", "optimize", "model" and "misc".
            logger: object whose ``info`` method receives log messages;
                ``print`` is used when None.
            verbose: not used by this method; kept for interface compatibility.
        """
        super(AbstractNetwork, self).__init__()  # Must call super __init__()

        # update configuration
        config = self.model_specific_config_update(config)

        self.optimizer = None
        self.sample_data = None
        self.models_to_update = None
        self.training_mode = True
        self.evaluate_after = config["evaluation"].get("evaluate_after", 1)

        self.it = 0  # it: iteration
        self.tm = timer.Timer()  # tm: timer
        self.grad_clip = config["optimize"].get("gradient_clip", 10)
        self.update_every = config["optimize"].get("update_every", 1)
        # is_available has to be *called*: the bare function object is always
        # truthy, which made the default True even on CPU-only machines.
        self.use_gpu = config["model"].get("use_gpu",
                                           torch.cuda.is_available())
        self.device = torch.device("cuda" if self.use_gpu else "cpu")

        # Define the flag up front so print_counters_info() can read it even
        # when no tensorboard directory is configured.
        self.use_tf_summary = False
        tensorboard_dir = config["misc"]["tensorboard_dir"]
        if tensorboard_dir:
            self.create_tensorboard_summary(tensorboard_dir)

        # save configuration for later network reproduction
        if not config["model"].get("resume", False):
            save_config_path = os.path.join(config["misc"]["result_dir"],
                                            "config.yml")
            io_utils.write_yaml(save_config_path, config)
        self.config = config

        # prepare logging
        self.log = logger.info if logger is not None else print
        self.log(json.dumps(config, indent=2))

    # ---- methods for forward/backward ----

    @abstractmethod
    def forward(self, net_inps):
        """ Forward network
        Args:
            net_inps: inputs for network; dict()
        Returns:
            [crit_inp, misc]: two items of list
                - inputs for criterion; dict()
                - intermediate values for visualization, etc; dict()
        """
        pass

    def loss_fn(self, crit_inp, gts, count_loss=True):
        """ Compute loss
        Args:
            crit_inp: inputs for criterion which is outputs from forward(); dict()
            gts: ground truth
            count_loss: flag of accumulating loss or not (training or inference)
        Returns:
            loss: results of self.criterion; dict()
        """
        self.loss = self.criterion(crit_inp, gts)
        for name, value in self.loss.items():
            # keep a detached numpy copy for logging / counters
            self.status[name] = value.detach().cpu().numpy()
            if count_loss:
                self.counters[name].add(self.status[name], 1)
        return self.loss

    def update(self, loss):
        """ Update the network
        The optimizer is created lazily on the first call. Gradients are
        accumulated over ``self.update_every`` calls before a step is taken.
        Args:
            loss: loss to train the network; dict() with key "total_loss"
        """
        self.it += 1
        lr = net_utils.adjust_lr(self.it, self.it_per_epoch,
                                 self.config["optimize"])

        # initialize optimizer
        if self.optimizer is None:
            self.create_optimizer(lr)
            self.optimizer.zero_grad()  # set gradients as zero before update

        # scale so that the accumulated gradient matches a single full step
        total_loss = loss["total_loss"] / self.update_every
        total_loss.backward()
        # NOTE(review): clipping runs on every call, i.e. also on partially
        # accumulated gradients when update_every > 1 -- confirm intended.
        torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.grad_clip)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        if self.it % self.update_every == 0:
            self.optimizer.step()
            # set gradients as zero before the next accumulation round
            self.optimizer.zero_grad()

    def forward_update(self, net_inps, gts):
        """ Forward and update the network at the same time
        Args:
            net_inps: inputs for network; dict()
            gts: ground truth; dict()
        Returns:
            {loss, net_output}: two items of dictionary
                - loss: results from self.criterion(); dict()
                - net_output: full return value of self.forward()
        """
        outputs = self.forward(net_inps)
        loss = self.loss_fn(outputs[0], gts, count_loss=True)
        self.update(loss)
        return {"loss": loss, "net_output": outputs}

    def compute_loss(self, net_inps, gts):
        """ Compute loss and network's output at once
        Args:
            net_inps: inputs for network; dict()
            gts: ground truth; dict()
        Returns:
            {loss, net_output}: two items of dictionary
                - loss: results from self.criterion(); dict()
                - net_output: full return value of self.forward()
        """
        outputs = self.forward(net_inps)
        loss = self.loss_fn(outputs[0], gts, count_loss=True)
        return {"loss": loss, "net_output": outputs}

    def create_optimizer(self, lr):
        """ Create optimizer for training phase
        Currently supported optimizer list: [SGD, Adam, Adadelta, RMSprop]
        Args:
            lr: learning rate; number
        Raises:
            NotImplementedError: if config["optimize"]["optimizer_type"] is
                not one of the supported names
        """
        opt_type = self.config["optimize"]["optimizer_type"]
        if opt_type == "SGD":
            self.optimizer = torch.optim.SGD(
                self.get_parameters(),
                lr=lr,
                momentum=self.config["optimize"]["momentum"],
                weight_decay=self.config["optimize"]["weight_decay"])
        elif opt_type == "Adam":
            self.optimizer = torch.optim.Adam(self.get_parameters(), lr=lr)
        elif opt_type == "Adadelta":
            self.optimizer = torch.optim.Adadelta(self.get_parameters(), lr=lr)
        elif opt_type == "RMSprop":
            self.optimizer = torch.optim.RMSprop(self.get_parameters(), lr=lr)
        else:
            raise NotImplementedError(
                "Not supported optimizer [{}]".format(opt_type))

    @abstractmethod
    def build_network(self, config):
        """ Build the network from config; to be implemented by subclasses """
        pass

    @abstractmethod
    def build_evaluator(self, config):
        """ Build the evaluator from config; to be implemented by subclasses """
        pass

    @abstractmethod
    def prepare_batch(self, batch):
        """ Prepare batch to be used for network
        e.g., shipping batch to gpu
        Args:
            batch: batch data; dict()
        Returns:
            batch: batch; dict()
        """
        pass

    @abstractmethod
    def check_apply_curriculum(self, epoch=-1):
        """ Check and apply curriculum learning
        """
        pass

    # ---- methods for checkpoint ----

    def load_checkpoint(self, ckpt_path, load_crit=False):
        """ Load checkpoint of the network.
        Args:
            ckpt_path: checkpoint file path; str
            load_crit: also restore the "criterion" entry when True; the
                entry is skipped otherwise
        """
        self.log("Checkpoint is loaded from {}".format(ckpt_path))
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # The map_location callback returns the storage as deserialized
        # instead of moving it to the device it was saved from.
        model_state_dict = torch.load(
            ckpt_path, map_location=lambda storage, loc: storage)
        self.log("[{}] are in checkpoint".format("|".join(
            model_state_dict.keys())))
        for m, state in model_state_dict.items():
            # the criterion is restored only on explicit request
            if m == "criterion" and not load_crit:
                continue
            if m in self.model_list:
                self[m].load_state_dict(state)
                self.log("{} is initialized from checkpoint".format(m))
            else:
                self.log("{} is not in {}".format(m,
                                                  "|".join(self.model_list)))

    def save_checkpoint(self, ckpt_path, save_crit=False):
        """ Save checkpoint of the network.
        Args:
            ckpt_path: checkpoint file path
            save_crit: also store the state of self.criterion when True
        """
        model_state_dict = {
            m: self[m].state_dict()
            for m in self.model_list if m != "criterion"
        }
        if save_crit:
            model_state_dict["criterion"] = self["criterion"].state_dict()
        torch.save(model_state_dict, ckpt_path)

        self.log("Checkpoint [{}] is saved in {}".format(
            " | ".join(model_state_dict.keys()), ckpt_path))

    # ---- methods for status (metrics) ----

    def reset_status(self, init_reset=False):
        """ Reset (initialize) metric scores or losses (status).
        Args:
            init_reset: when True, (re)create self.status with "total_loss"
                plus one entry per item of self.criterion.get_items();
                otherwise zero the existing entries
        """
        if init_reset:
            self.status = OrderedDict()
            self.status["total_loss"] = 0
            for k, _ in self.criterion.get_items():
                self.status[k] = 0
        else:
            for k in self.status:
                self.status[k] = 0

    @abstractmethod
    def compute_status(self, logits, gts):
        """ Compute metric scores or losses (status).
            You may need to implement this method.
        Args:
            logits: output logits of network.
            gts: ground-truth
        """
        pass

    def print_status(self, epoch, mode="Train", enter_every=2):
        """ Print current metric scores or losses (status).
            You are encouraged to implement this method.
        Args:
            epoch: current epoch
            mode: not used by this implementation
            enter_every: number of status entries logged per line
        """
        # print status information
        txt = "epoch {} step {} ".format(epoch, self.it)
        for i, (k, v) in enumerate(self.status.items()):
            txt += "{} = {:.4f}, ".format(k, float(v))
            # flush a full line every `enter_every` entries
            if (i + 1) % enter_every == 0:
                self.log(txt)
                txt = ""
        if txt:
            self.log(txt)

    # ---- methods for counters ----

    def _create_counters(self):
        """ Create one accumulator for "total_loss" and one per criterion item """
        self.counters = OrderedDict()
        self.counters["total_loss"] = accumulator.Accumulator("total_loss")
        for k, _ in self.criterion.get_items():
            self.counters[k] = accumulator.Accumulator(k)

    def reset_counters(self):
        """ Reset every accumulator in self.counters """
        for counter in self.counters.values():
            counter.reset()

    def print_counters_info(self, epoch, logger, mode="Train"):
        """ Log the averages of all counters, then reset them.
        Outside "Train" mode nothing is logged while epoch is below
        self.evaluate_after; the counters are still reset.
        Args:
            epoch: current epoch
            logger: object whose info() receives the summary line
            mode: label for the line and the tensorboard tag prefix
        """
        if mode != "Train" and epoch < self.evaluate_after:
            self.reset_counters()
            return

        txt = "[{}] {} epoch".format(mode, epoch)
        for v in self.counters.values():
            txt += ", {} = {:.4f}".format(v.get_name(), v.get_average())
        logger.info(txt)

        if self.use_tf_summary:
            self.write_counter_summary(epoch, mode)

        # reset counters
        self.reset_counters()

    # ---- methods for tensorboard ----

    def create_tensorboard_summary(self, tensorboard_dir):
        """ Create a PytorchSummary for tensorboard_dir and enable summaries """
        self.use_tf_summary = True
        self.summary = PytorchSummary(tensorboard_dir)

    def set_tensorboard_summary(self, summary):
        """ Use an externally created summary object and enable summaries """
        self.use_tf_summary = True
        self.summary = summary

    def write_params_summary(self, epoch):
        """ Write a histogram per parameter, with epoch as global step.
        Covers all parameters of this module when self.models_to_update is
        None, otherwise only those of the listed sub-models.
        """
        if self.models_to_update is None:
            for name, param in self.named_parameters():
                self.summary.add_histogram("model/{}".format(name),
                                           param,
                                           global_step=epoch)
        else:
            for m in self.models_to_update:
                for name, param in self[m].named_parameters():
                    self.summary.add_histogram("model/{}/{}".format(m, name),
                                               param,
                                               global_step=epoch)

    def write_status_summary(self):
        """ Write every status entry as a scalar at the current iteration """
        for k, v in self.status.items():
            self.summary.add_scalar('status/' + k, v, global_step=self.it)

    def write_counter_summary(self, epoch, mode):
        """ Write the average of every counter as a scalar at the given epoch """
        for v in self.counters.values():
            self.summary.add_scalar(mode + '/counters/' + v.get_name(),
                                    v.get_average(),
                                    global_step=epoch)

    @abstractmethod
    def bring_loader_info(self, dataset):
        """ Pull information from the dataset; to be implemented by subclasses """
        pass

    # ---- wrapper methods of nn.Modules ----

    def _get_parameter(self, net):
        """ Yield the parameters of net, descending into (nested) dicts.
        Args:
            net: an object with named_parameters(), or a dict of such
                objects (possibly nested)
        """
        if isinstance(net, dict):
            for v in net.values():
                # `yield from` is required: a bare recursive call only builds
                # a generator and drops it, yielding nothing.
                yield from self._get_parameter(v)
        else:
            for _, param in net.named_parameters():
                yield param

    def get_parameters(self):
        """ Yield the parameters to be optimized.
        All parameters of this module when self.models_to_update is None,
        otherwise those of each listed sub-model (dicts of sub-models are
        supported).
        """
        if self.models_to_update is None:
            for _, param in self.named_parameters():
                yield param
        else:
            for m in self.models_to_update:
                yield from self._get_parameter(self[m])

    def cpu_mode(self):
        """ Log the model names and move the module to CPU """
        self.log("Setting cpu() for [{}]".format(" | ".join(self.model_list)))
        self.cpu()

    def gpu_mode(self):
        """ Log the model names and move the module to GPU
        Raises:
            NotImplementedError: if torch.cuda.is_available() is False
        """
        if not torch.cuda.is_available():
            raise NotImplementedError("Available GPU not exists")
        self.log("Setting gpu() for [{}]".format(" | ".join(self.model_list)))
        self.cuda()

    def train_mode(self):
        """ Switch to training mode and log it """
        self.train()
        self.training_mode = True
        self.log("Setting train() for [{}]".format(" | ".join(
            self.model_list)))

    def eval_mode(self):
        """ Switch to evaluation mode and log it """
        self.eval()
        self.training_mode = False
        self.log("Setting eval() for [{}]".format(" | ".join(self.model_list)))

    def _set_sample_data(self, data):
        """ Keep a deep copy of the first data passed in; later calls are no-ops """
        # identity check: `== None` can dispatch to an overloaded __eq__
        if self.sample_data is None:
            self.sample_data = copy.deepcopy(data)

    def __getitem__(self, key):
        """ Attribute access by name: self[key] is getattr(self, key) """
        return getattr(self, key)

    def __setitem__(self, key, value):
        """ Attribute assignment by name: self[key] = value """
        return setattr(self, key, value)

    def model_specific_config_update(self, config):
        """ Hook for subclasses to adjust config; returns it unchanged here """
        return config