Example #1
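# Excerpt of a GAN training driver. Project-local helpers (CustomWriter, saver,
# models, optimizers, schedulers, losses, get_DataLoader) and the imports of
# torch and importlib are assumed to be defined elsewhere in the module.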
def run_training(config):
    tb_writer = CustomWriter(config)
    logger = saver.get_logger(config)

    num_epochs = config.training.num_epochs
    model = {'G': models.get_model(config, tag='G'), 'D': models.get_model(config, tag='D')}
    model = {key: torch.nn.DataParallel(value) for key, value in model.items()}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = get_DataLoader(config, phase="train")
    val_loader = get_DataLoader(config, phase="val")

    optimizer = {'G': optimizers.get_optimizer(model['G'].parameters(), config.model.G.optimizer),
                 'D': optimizers.get_optimizer(model['D'].parameters(), config.model.D.optimizer)}

    scheduler = {'G': schedulers.get_scheduler(optimizer['G'], config.model.G.scheduler),
                 'D': schedulers.get_scheduler(optimizer['D'], config.model.D.scheduler)}

    criterion = {'G': losses.get_loss(config.model.G.criterion),
                 'D': losses.get_loss(config.model.D.criterion)}

    start_epoch_num = (
        saver.get_latest_epoch_num(config)
        if config.model.G.load_state in (-1, "latest")
        else config.model.G.load_state
    )

    # Dynamically import the protocol-specific training/evaluation routines
    epoch_module = importlib.import_module('protocols.{}.epoch'.format(config.protocol))
    train_one_epoch, test = getattr(epoch_module, 'train_one_epoch'), getattr(epoch_module, 'test')

    for epoch in range(start_epoch_num + 1, start_epoch_num + num_epochs + 1):

        train_buffer = train_one_epoch(
            config=config,
            model=model,
            device=device,
            train_loader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            criterion=criterion,
            epoch=epoch,
            logger=logger,
            log_interval=config.training.log_interval,
        )
        tb_writer.write_result(train_buffer, epoch, phase="train")

        if epoch % config.training.validation_period == 0:
            val_buffer = test(
                config=config,
                model=model,
                device=device,
                test_loader=val_loader,
                criterion=criterion,
                logger=logger,
                phase="val",
                tag=epoch,
                log_interval=8,
            )
            tb_writer.write_result(val_buffer, epoch, phase="val")
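
The protocol-specific train_one_epoch and test functions are resolved at runtime via importlib. The sketch below is a hypothetical protocols/&lt;protocol&gt;/epoch.py showing only the interface implied by the keyword arguments used above; the bodies are placeholders, not the project's actual implementation.

# Hypothetical protocols/<protocol>/epoch.py -- interface sketch only.
def train_one_epoch(config, model, device, train_loader, optimizer,
                    scheduler, criterion, epoch, logger, log_interval):
    buffer = {}
    # ... run one generator/discriminator epoch and collect metrics into `buffer` ...
    return buffer


def test(config, model, device, test_loader, criterion, logger,
         phase, tag, log_interval):
    buffer = {}
    # ... evaluate on test_loader and collect metrics into `buffer` ...
    return buffer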
Example #2
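    # Excerpt of a GAN training setup written against the TensorFlow 1.x API
    # (tf.placeholder, tf.train.AdamOptimizer). Helpers such as get_clipper,
    # get_scheduler, BasicSupervisorMNIST, train, mnist, accountant, n_samples
    # and the data loaders are assumed to be defined earlier in the file.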
    if config.adaptive_rate:
        lr = tf.placeholder(tf.float32, shape=())
    else:
        lr = config.learning_rate

    gen_optimizer = tf.train.AdamOptimizer(config.gen_learning_rate, beta1=0.5, beta2=0.9)
    disc_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)

    clipper_ret = get_clipper(config.clipper, config)
    if isinstance(clipper_ret, tuple):
        clipper, sampler = clipper_ret
        sampler.set_data_loader(sample_data_loader)
        sampler.keep_memory = False
    else:
        clipper = clipper_ret
        sampler = None

    scheduler = get_scheduler(config.scheduler, config)
    def callback_before_train(_0, _1, _2):
        print(clipper.info())
    supervisor = BasicSupervisorMNIST(config, clipper, scheduler, sampler=sampler,
                                      callback_before_train=callback_before_train)
    if config.adaptive_rate:
        supervisor.put_key("lr", lr)
    print(gan_data_loader)
    train(config, gan_data_loader, mnist.generator_forward, mnist.discriminator_forward,
          gen_optimizer=gen_optimizer,
          disc_optimizer=disc_optimizer, accountant=accountant,
          supervisor=supervisor, n_samples=n_samples)
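
When config.adaptive_rate is set, the discriminator learning rate is a tf.placeholder and must be supplied through feed_dict on every optimizer step. A minimal, self-contained sketch of that TensorFlow 1.x pattern with a hypothetical toy loss (the real loss and training op are built inside train):

import tensorflow as tf

lr = tf.placeholder(tf.float32, shape=())      # adaptive learning rate
w = tf.Variable(1.0)
loss = tf.square(w)                            # toy stand-in for the discriminator loss
train_op = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={lr: 1e-4})   # feed the current rate each step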
Example #3
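    # Excerpt: constructor of a trainer combining semantic segmentation with
    # self-supervised monocular depth estimation. Helpers such as setup_seeds,
    # build_loader, get_model, get_optimizer, get_scheduler, DepthEstimator,
    # EarlyStopping, GradScaler and infinite_iterator are assumed to be
    # imported elsewhere.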
    def __init__(self, cfg, writer, img_writer, logger, run_id):
        # Copy shared config fields
        if "monodepth_options" in cfg:
            cfg["data"].update(cfg["monodepth_options"])
            cfg["model"].update(cfg["monodepth_options"])
            cfg["training"]["monodepth_loss"].update(cfg["monodepth_options"])
        if "generated_depth_dir" in cfg["data"]:
            dataset_name = f"{cfg['data']['dataset']}_" \
                           f"{cfg['data']['width']}x{cfg['data']['height']}"
            depth_teacher = cfg["data"].get("depth_teacher", None)
            assert not (depth_teacher and cfg['model'].get('depth_estimator_weights') is not None)
            if depth_teacher is not None:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + depth_teacher + "/"
            else:
                cfg["data"]["generated_depth_dir"] += dataset_name + "/" + cfg['model']['depth_estimator_weights'] + "/"

        # Setup seeds
        setup_seeds(cfg.get("seed", 1337))
        if cfg["data"]["dataset_seed"] == "same":
            cfg["data"]["dataset_seed"] = cfg["seed"]

        # Setup device
        torch.backends.cudnn.benchmark = cfg["training"].get("benchmark", True)
        self.cfg = cfg
        self.writer = writer
        self.img_writer = img_writer
        self.logger = logger
        self.run_id = run_id
        self.mIoU = 0
        self.fwAcc = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.setup_segmentation_unlabeled()

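        # Depth-guided mix masks ("depth", "depthcomp", "depthhist") require depth
        # estimates for the unlabeled images.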
        self.unlabeled_require_depth = (
            self.cfg["training"]["unlabeled_segmentation"] is not None and
            self.cfg["training"]["unlabeled_segmentation"]["mix_mask"] in ("depth", "depthcomp", "depthhist")
        )

        # Prepare depth estimates
        do_precalculate_depth = self.cfg["training"]["segmentation_lambda"] != 0 and self.unlabeled_require_depth and \
                                self.cfg['model']['segmentation_name'] != 'mtl_pad'
        use_depth_teacher = cfg["data"].get("depth_teacher", None) is not None
        if do_precalculate_depth or use_depth_teacher:
            assert not (do_precalculate_depth and use_depth_teacher)
            if not self.cfg["training"].get("disable_depth_estimator", False):
                print("Prepare depth estimates")
                depth_estimator = DepthEstimator(cfg)
                depth_estimator.prepare_depth_estimates()
                del depth_estimator
                torch.cuda.empty_cache()
        else:
            self.cfg["data"]["generated_depth_dir"] = None

        # Setup Dataloader
        load_labels, load_sequence = True, True
        if self.cfg["training"]["monodepth_lambda"] == 0:
            load_sequence = False
        if self.cfg["training"]["segmentation_lambda"] == 0:
            load_labels = False
        train_data_cfg = deepcopy(self.cfg["data"])
        if not do_precalculate_depth and not use_depth_teacher:
            train_data_cfg["generated_depth_dir"] = None
        self.train_loader = build_loader(train_data_cfg, "train", load_labels=load_labels, load_sequence=load_sequence)
        if self.cfg["training"].get("minimize_entropy_unlabeled", False) or self.enable_unlabled_segmentation:
            unlabeled_segmentation_cfg = deepcopy(self.cfg["data"])
            if not self.only_unlabeled and self.mix_use_gt:
                unlabeled_segmentation_cfg["load_onehot"] = True
            if self.only_unlabeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": False})
            elif self.only_labeled:
                unlabeled_segmentation_cfg.update({"load_unlabeled": False, "load_labeled": True})
            else:
                unlabeled_segmentation_cfg.update({"load_unlabeled": True, "load_labeled": True})
            if self.mix_video:
                assert not self.mix_use_gt and not self.only_labeled and not self.only_unlabeled, \
                    "Video sample indices are not compatible with non-video indices."
                unlabeled_segmentation_cfg.update({"only_sequences_with_segmentation": not self.mix_video,
                                                   "restrict_to_subset": None})
            self.unlabeled_loader = build_loader(unlabeled_segmentation_cfg, "train",
                                                 load_labels=load_labels if not self.mix_video else False,
                                                 load_sequence=load_sequence)
        else:
            self.unlabeled_loader = None
        self.val_loader = build_loader(self.cfg["data"], "val", load_labels=load_labels,
                                       load_sequence=load_sequence)
        self.n_classes = self.train_loader.n_classes

        # The monodepth dataloader settings use drop_last=True and shuffle=True even for val
        self.train_data_loader = data.DataLoader(
            self.train_loader,
            batch_size=self.cfg["training"]["batch_size"],
            num_workers=self.cfg["training"]["n_workers"],
            shuffle=self.cfg["data"]["shuffle_trainset"],
            pin_memory=True,
            # Setting this to False causes a crash at the end of an epoch
            drop_last=True,
        )
        if self.unlabeled_loader is not None:
            self.unlabeled_data_loader = infinite_iterator(data.DataLoader(
                self.unlabeled_loader,
                batch_size=self.cfg["training"]["batch_size"],
                num_workers=self.cfg["training"]["n_workers"],
                shuffle=self.cfg["data"]["shuffle_trainset"],
                pin_memory=True,
                # Setting this to False causes a crash at the end of an epoch
                drop_last=True,
            ))

        self.val_batch_size = self.cfg["training"]["val_batch_size"]
        self.val_data_loader = data.DataLoader(
            self.val_loader,
            batch_size=self.val_batch_size,
            num_workers=self.cfg["training"]["n_workers"],
            pin_memory=True,
            # If using a dataset with an odd number of samples (CamVid), memory consumption suddenly
            # increases for the last batch. This can be avoided by dropping the last batch; only do so
            # if necessary for your system, as it results in an incomplete validation set.
            # drop_last=True,
        )

        # Setup Model
        self.model = get_model(cfg["model"], self.n_classes).to(self.device)
        # print(self.model)
        assert not (self.enable_unlabled_segmentation and self.cfg["training"]["save_monodepth_ema"])
        if self.enable_unlabled_segmentation and not self.only_labeled:
            print("Create segmentation ema model.")
            self.ema_model = self.create_ema_model(self.model).to(self.device)
        elif self.cfg["training"]["save_monodepth_ema"]:
            print("Create depth ema model.")
            # TODO: Try to remove unnecessary components and fit the EMA model onto the GPU for better performance
            self.ema_model = self.create_ema_model(self.model)  # .to(self.device)
        else:
            self.ema_model = None

        # Setup optimizer, lr_scheduler and loss function
        optimizer_cls = get_optimizer(cfg)
        optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if
                            k not in ["name", "backbone_lr", "pose_lr", "depth_lr", "segmentation_lr"]}
        train_params = get_train_params(self.model, self.cfg)
        self.optimizer = optimizer_cls(train_params, **optimizer_params)

        self.scheduler = get_scheduler(self.optimizer, self.cfg["training"]["lr_schedule"])

        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler(enabled=self.cfg["training"]["amp"])

        self.loss_fn = get_segmentation_loss_function(self.cfg)
        self.monodepth_loss_calculator_train = get_monodepth_loss(self.cfg, is_train=True)
        self.monodepth_loss_calculator_val = get_monodepth_loss(self.cfg, is_train=False, batch_size=self.val_batch_size)

        if cfg["training"]["early_stopping"] is None:
            logger.info("Using No Early Stopping")
            self.earlyStopping = None
        else:
            self.earlyStopping = EarlyStopping(
                patience=round(cfg["training"]["early_stopping"]["patience"] / cfg["training"]["val_interval"]),
                min_delta=cfg["training"]["early_stopping"]["min_delta"],
                cumulative_delta=cfg["training"]["early_stopping"]["cum_delta"],
                logger=logger
            )
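
    # Hypothetical helper (not part of the original class) showing the standard
    # torch.cuda.amp training-step pattern implied by the GradScaler created above;
    # `images` and `labels` are illustrative arguments.
    def _training_step(self, images, labels):
        with torch.cuda.amp.autocast(enabled=self.cfg["training"]["amp"]):
            outputs = self.model(images)
            loss = self.loss_fn(outputs, labels)
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
        self.scaler.step(self.optimizer)     # unscales gradients, then calls optimizer.step()
        self.scaler.update()                 # adjust the scale factor for the next iteration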