Example #1
    def train(self):
        if not self._state["initialized"]:
            self.init_train()
        self._state["initialized"] = True

        self._state["epoch"] += 1
        epoch = self._state["epoch"]
        num_iterations = self._hyper_params["num_iterations"]

        # update engine_state
        self._state["max_epoch"] = self._hyper_params["max_epoch"]
        self._state["max_iteration"] = num_iterations

        self._optimizer.modify_grad(epoch)
        pbar = tqdm(range(num_iterations))
        self._state["pbar"] = pbar
        self._state["print_str"] = ""

        time_dict = OrderedDict()
        for iteration, _ in enumerate(pbar):
            self._state["iteration"] = iteration
            with Timer(name="data", output_dict=time_dict):
                training_data = next(self._dataloader)
            training_data = move_data_to_device(training_data,
                                                self._state["devices"][0])

            schedule_info = self._optimizer.schedule(epoch, iteration)
            self._optimizer.zero_grad()

            # forward propagation
            with Timer(name="fwd", output_dict=time_dict):
                predict_data = self._model(training_data)
                training_losses, extras = OrderedDict(), OrderedDict()
                for loss_name, loss in self._losses.items():
                    training_losses[loss_name], extras[loss_name] = loss(
                        predict_data, training_data)
                total_loss = sum(training_losses.values())

            # backward propagation
            with Timer(name="bwd", output_dict=time_dict):
                if self._optimizer.grad_scaler is not None:
                    self._optimizer.grad_scaler.scale(total_loss).backward()
                else:
                    total_loss.backward()
            self._optimizer.modify_grad(epoch, iteration)
            with Timer(name="optim", output_dict=time_dict):
                self._optimizer.step()

            trainer_data = dict(
                schedule_info=schedule_info,
                training_losses=training_losses,
                extras=extras,
                time_dict=time_dict,
            )

            for monitor in self._monitors:
                monitor.update(trainer_data)
            del training_data
            print_str = self._state["print_str"]
            pbar.set_description(print_str)
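For reference, the grad_scaler branch above follows PyTorch's standard automatic mixed precision (AMP) pattern; the unscale/step/update half of the cycle is presumably handled inside the optimizer wrapper's step(). A minimal standalone sketch of the full pattern with a plain torch.optim optimizer (illustrative only, requires a CUDA device; not this repository's code):

import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(4, 8, device="cuda")
targets = torch.randn(4, 1, device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    total_loss = torch.nn.functional.mse_loss(model(inputs), targets)
scaler.scale(total_loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)  # unscales gradients first; skips the step on inf/nan
scaler.update()  # adjust the scale factor for the next iteration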
Example #2
def run_dist_training(rank_id: int, world_size: int, task: str,
                      task_cfg: CfgNode, parsed_args, model, dist_url):
    """method to run on distributed process
       passed to multiprocessing.spawn
    
    Parameters
    ----------
    rank_id : int
        rank id, ith spawned process 
    world_size : int
        total number of spawned process
    task : str
        task name (passed to builder)
    task_cfg : CfgNode
        task builder (passed to builder)
    parsed_args : [type]
        parsed arguments from command line
    """
    devs = ["cuda:{}".format(rank_id)]
    # set up distributed
    setup(rank_id, world_size, dist_url)
    dist_utils.synchronize()
    # move model to device before building the optimizer:
    # a quick fix for resuming with DDP
    # TODO: needs to be refined in the future
    model.set_device(devs[0])
    # build optimizer
    optimizer = optim_builder.build(task, task_cfg.optim, model)
    # build dataloader with trainer
    with Timer(name="Dataloader building", verbose=True):
        dataloader = dataloader_builder.build(task,
                                              task_cfg.data,
                                              seed=rank_id)
    # build trainer
    trainer = engine_builder.build(task, task_cfg.trainer, "trainer",
                                   optimizer, dataloader)
    trainer.set_device(
        devs
    )  # needs to be placed after the optimizer is built (potential PyTorch issue)
    trainer.resume(parsed_args.resume)
    # trainer.init_train()
    logger.info("Start training")
    while not trainer.is_completed():
        trainer.train()
        if rank_id == 0:
            trainer.save_snapshot()
        dist_utils.synchronize()  # one synchronization per epoch

    if rank_id == 0:
        trainer.save_snapshot(model_param_only=True)
    # clean up distributed
    cleanup()
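The setup and cleanup helpers are not shown in these examples; a typical implementation with torch.distributed that matches the call sites above could look like this (an assumed sketch, not this repository's actual code):

import torch.distributed as dist

def setup(rank_id: int, world_size: int, dist_url: str):
    # join the process group; dist_url is an init-method URL,
    # e.g. "tcp://127.0.0.1:12345"
    dist.init_process_group(backend="nccl",
                            init_method=dist_url,
                            rank=rank_id,
                            world_size=world_size)

def cleanup():
    # tear down the process group before the process exits
    dist.destroy_process_group()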
Example #3
def run_dist_training(rank_id: int, world_size: int, task: str,
                      task_cfg: CfgNode, parsed_args, model):
    """method to run on distributed process
       passed to multiprocessing.spawn
    
    Parameters
    ----------
    rank_id : int
        rank id, ith spawned process 
    world_size : int
        total number of spawned process
    task : str
        task name (passed to builder)
    task_cfg : CfgNode
        task builder (passed to builder)
    parsed_args : [type]
        parsed arguments from command line
    """
    # set up distributed
    setup(rank_id, world_size)
    # build model
    # model = model_builder.build(task, task_cfg.model)
    # build optimizer
    optimizer = optim_builder.build(task, task_cfg.optim, model)
    # build dataloader with trainer
    with Timer(name="Dataloader building", verbose=True, logger=logger):
        dataloader = dataloader_builder.build(task,
                                              task_cfg.data,
                                              seed=rank_id)
    # build trainer
    trainer = engine_builder.build(task, task_cfg.trainer, "trainer",
                                   optimizer, dataloader)
    devs = ["cuda:%d" % rank_id]
    trainer.set_device(devs)
    trainer.resume(parsed_args.resume_from_epoch, parsed_args.resume_from_file)
    # trainer.init_train()
    logger.info("Start training")
    while not trainer.is_completed():
        trainer.train()
        if rank_id == 0:
            trainer.save_snapshot()
        dist.barrier()  # one synchronization per epoch

    # clean up distributed
    cleanup()
Example #4
    # experiment config
    exp_cfg_path = osp.realpath(parsed_args.config)
    root_cfg.merge_from_file(exp_cfg_path)
    logger.info("Load experiment configuration at: %s" % exp_cfg_path)
    logger.info(
        "Merged with root_cfg imported from videoanalyst.config.config.cfg")
    # resolve config
    root_cfg = root_cfg.train
    task, task_cfg = specify_task(root_cfg)
    task_cfg.data.num_workers = 2
    task_cfg.data.sampler.submodules.dataset.GOT10kDataset.check_integrity = False
    task_cfg.freeze()

    if parsed_args.target == "dataloader":
        logger.info("visualize for dataloader")
        with Timer(name="Dataloader building", verbose=True):
            dataloader = dataloader_builder.build(task, task_cfg.data)

        for batch_training_data in dataloader:
            keys = list(batch_training_data.keys())
            batch_size = len(batch_training_data[keys[0]])
            training_samples = [{
                k: v[[idx]]
                for k, v in batch_training_data.items()
            } for idx in range(batch_size)]
            for training_sample in training_samples:
                target_cfg = task_cfg.data.target
                show_img_FCOS(target_cfg[target_cfg.name], training_sample)
                scan_key()
    elif parsed_args.target == "dataset":
        logger.info("visualize for dataset")
Example #5
    def train(self):
        if not self._state["initialized"]:
            self.init_train()
        self._state["initialized"] = True

        # epoch counter +1
        self._state["epoch"] += 1
        epoch = self._state["epoch"]
        num_iterations = self._hyper_params["num_iterations"]

        # update engine_state
        self._state["max_epoch"] = self._hyper_params["max_epoch"]
        self._state["max_iteration"] = num_iterations

        self._optimizer.modify_grad(epoch)
        # TODO: build stats gathering code and reorganize tqdm
        pbar = tqdm(range(num_iterations))
        # pbar = range(num_iterations)
        self._state["pbar"] = pbar
        self._state["print_str"] = ""

        time_dict = OrderedDict()
        for iteration, _ in enumerate(pbar):
            self._state["iteration"] = iteration
            with Timer(name="data", output_dict=time_dict):
                training_data = next(self._dataloader)
            training_data = move_data_to_device(training_data,
                                                self._state["devices"][0])
            schedule_info = self._optimizer.schedule(epoch, iteration)
            self._optimizer.zero_grad()
            # forward propagation
            with Timer(name="fwd", output_dict=time_dict):
                predict_data = self._model(training_data)
                training_losses, extras = OrderedDict(), OrderedDict()
                for loss_name, loss in self._losses.items():
                    training_losses[loss_name], extras[loss_name] = loss(
                        predict_data, training_data)
                total_loss = sum(training_losses.values())
            # backward propagation
            with Timer(name="bwd", output_dict=time_dict):
                total_loss.backward()
            # TODO: no need for average_gradients() when the model is wrapped with DDP?
            # TODO: need to register _optimizer.modify_grad as a hook,
            #       see https://discuss.pytorch.org/t/distributeddataparallel-modify-gradient-before-averaging/59291
            # self._optimizer.modify_grad(epoch, iteration)
            with Timer(name="optim", output_dict=time_dict):
                self._optimizer.step()

            trainer_data = dict(
                schedule_info=schedule_info,
                training_losses=training_losses,
                extras=extras,
                time_dict=time_dict,
            )

            for monitor in self._monitors:
                monitor.update(trainer_data)
            del training_data
            print_str = self._state["print_str"]
            pbar.set_description(print_str)
        del pbar  # needs to be freed, otherwise spawn would get stuck.
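The hook-registration TODO above points at PyTorch's per-tensor gradient hooks, which fire as each parameter's gradient is computed during backward; a minimal standalone sketch of the mechanism (illustrative only, not this repository's eventual solution):

import torch

model = torch.nn.Linear(4, 2)

def clip_grad(grad):
    # example modification: clamp each gradient element to [-1, 1]
    return grad.clamp(-1.0, 1.0)

for param in model.parameters():
    # the hook runs whenever this parameter's gradient is computed
    param.register_hook(clip_grad)

loss = model(torch.randn(3, 4)).sum()
loss.backward()  # gradients arrive in .grad already clamped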
Example #6
    def train(self):
        if not self._state["initialized"]:
            self.init_train()
        self._state["initialized"] = True

        # epoch counter +1
        self._state["epoch"] += 1
        epoch = self._state["epoch"]
        num_iterations = self._hyper_params["num_iterations"]

        # update engine_state
        self._state["max_epoch"] = self._hyper_params["max_epoch"]
        self._state["max_iteration"] = num_iterations

        self._optimizer.modify_grad(epoch)
        pbar = tqdm(range(num_iterations))
        self._state["pbar"] = pbar
        self._state["print_str"] = ""

        time_dict = OrderedDict()
        for iteration, _ in enumerate(pbar):
            self._state["iteration"] = iteration
            with Timer(name="data", output_dict=time_dict):
                training_data = next(self._dataloader)
            training_data = move_data_to_device(training_data,
                                                self._state["devices"][0])

            schedule_info = self._optimizer.schedule(epoch, iteration)
            self._optimizer.zero_grad()

            # forward propagation
            with Timer(name="fwd", output_dict=time_dict):
                pred_data = self._model(training_data)

            # compute losses
            loss_extra_dict = OrderedDict()
            for k in self._losses:
                loss_extra_dict[k] = self._losses[k](pred_data, training_data)

            # split losses & extras
            training_losses, extras = OrderedDict(), OrderedDict()
            for k in self._losses:
                training_losses[k], extras[k] = loss_extra_dict[k]

            # get loss weights & sum up
            loss_weights = OrderedDict()
            for k in self._losses:
                loss_weights[k] = self._losses[k].get_hps()["weight"]
            total_loss = [
                training_losses[k] * loss_weights[k] for k in self._losses
            ]
            total_loss = sum(total_loss)

            # backward propagation
            with Timer(name="bwd", output_dict=time_dict):
                total_loss.backward()
            self._optimizer.modify_grad(epoch, iteration)
            with Timer(name="optim", output_dict=time_dict):
                self._optimizer.step()

            trainer_data = dict(
                schedule_info=schedule_info,
                training_losses=training_losses,
                extras=extras,
                time_dict=time_dict,
            )

            for monitor in self._monitors:
                monitor.update(trainer_data)
            del training_data
            print_str = self._state["print_str"]
            pbar.set_description(print_str)
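The weighted sum in this variant assumes every loss module exposes its weight through get_hps(); a minimal sketch of a loss object satisfying that contract (hypothetical class name and criterion, for illustration only):

from collections import OrderedDict
import torch

class WeightedMSELoss:
    def __init__(self, weight: float = 1.0):
        self._hyper_params = {"weight": weight}

    def get_hps(self):
        return self._hyper_params

    def __call__(self, predict_data, training_data):
        # placeholder criterion; real losses compare structured outputs
        loss = torch.nn.functional.mse_loss(predict_data, training_data)
        return loss, {}  # (loss, extras), matching the trainer's unpacking

losses = OrderedDict(cls=WeightedMSELoss(1.0), ctr=WeightedMSELoss(3.0))
pred, target = torch.randn(2, 4), torch.randn(2, 4)
training_losses = {k: losses[k](pred, target)[0] for k in losses}
total_loss = sum(training_losses[k] * losses[k].get_hps()["weight"]
                 for k in losses)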
Example #7
    logger.info(
        "Merged with root_cfg imported from videoanalyst.config.config.cfg")
    # resolve config
    root_cfg = complete_path_wt_root_in_cfg(root_cfg, ROOT_PATH)
    root_cfg = root_cfg.train
    task, task_cfg = specify_task(root_cfg)
    task_cfg.freeze()
    # backup config
    cfg_bak_dir = osp.join(task_cfg.exp_save, task_cfg.exp_name, "logs")
    ensure_dir(cfg_bak_dir)
    cfg_bak_file = osp.join(cfg_bak_dir, "%s_bak.yaml" % task_cfg.exp_name)
    with open(cfg_bak_file, "w") as f:
        f.write(task_cfg.dump())
    logger.info("Task configuration backed up at %s" % cfg_bak_file)
    # build dummy dataloader (for dataset initialization)
    with Timer(name="Dummy dataloader building", verbose=True, logger=logger):
        dataloader = dataloader_builder.build(task, task_cfg.data)
    del dataloader
    logger.info("Dummy dataloader destroyed.")
    # build model
    model = model_builder.build(task, task_cfg.model)
    # prepare to spawn
    world_size = task_cfg.num_processes
    torch.multiprocessing.set_start_method('spawn', force=True)
    # spawn trainer process
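    # note: mp.spawn prepends the process rank as the first argument,
    # which is why args omits rank_id from run_dist_training's signature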
    mp.spawn(run_dist_training,
             args=(world_size, task, task_cfg, parsed_args, model),
             nprocs=world_size,
             join=True)
    logger.info("Distributed training completed.")
Example #8
    def train(self):
        if not self._state["initialized"]:
            self.init_train()
        self._state["initialized"] = True

        self._state["epoch"] += 1
        epoch = self._state["epoch"]
        num_iterations = self._hyper_params["num_iterations"]

        # update engine_state
        self._state["max_iteration"] = num_iterations
        self._optimizer.modify_grad(epoch)
        self._state["print_str"] = ""

        time_dict = OrderedDict()
        for iteration in range(num_iterations):
            start_time = time.time()
            self._state["iteration"] = iteration
            with Timer(name="data", output_dict=time_dict):
                training_data = next(self._dataloader)
            training_data = move_data_to_device(training_data,
                                                self._state["devices"][0])
            schedule_info = self._optimizer.schedule(epoch, iteration)
            self._optimizer.zero_grad()
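            # run the tracker forward under no_grad and detach its correlation
            # feature, so gradients flow only through the segmentation model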
            with Timer(name="track_fwd", output_dict=time_dict):
                with torch.no_grad():
                    tracker_output = self.tracker(training_data, phase="train")
                corr_fea = tracker_output["corr_fea"].detach()
            # forward propagation
            with Timer(name="segfwd", output_dict=time_dict):
                predict_data = self._model(
                    training_data["seg_img"], corr_fea,
                    training_data["filtered_global_img"])
                training_losses, extras = OrderedDict(), OrderedDict()
                for loss_name, loss in self._losses.items():
                    training_losses[loss_name], extras[loss_name] = loss(
                        predict_data, training_data["seg_mask"])
                total_loss = sum(training_losses.values())
            # backward propagation
            with Timer(name="bwd", output_dict=time_dict):
                if self._optimizer.grad_scaler is not None:
                    self._optimizer.grad_scaler.scale(total_loss).backward()
                else:
                    total_loss.backward()
            with Timer(name="optim", output_dict=time_dict):
                self._optimizer.step()
            cost_time = (num_iterations - iteration) * (time.time() -
                                                        start_time)
            if dist_utils.get_rank() == 0:
                trainer_data = dict(
                    schedule_info=schedule_info,
                    training_losses=training_losses,
                    training_data=training_data,
                    extras=extras,
                    time_dict=time_dict,
                    predict_data=predict_data,
                    iter=iteration,
                )
                for monitor in self._monitors:
                    monitor.update(trainer_data)
                print_str = "{}/{} epoch {} eta ({}h {}m {}s) bs: {} ".format(
                    iteration, num_iterations, epoch, int(cost_time // (3600)),
                    int(cost_time % 3600 // 60), int(cost_time % 60),
                    training_data["im_x"].size(0)) + self._state["print_str"]
                logger.info(print_str)
            del training_data
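The eta shown in the log line is the remaining iteration count times the last iteration's wall time, split into hours, minutes, and seconds with integer division and modulo; a quick worked example:

cost_time = (5000 - 1200) * 0.25       # 3800 iterations left at 0.25 s each = 950.0 s
hours = int(cost_time // 3600)         # 0
minutes = int(cost_time % 3600 // 60)  # 15
seconds = int(cost_time % 60)          # 50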