Example #1
def test_load_checkpoint_with_different_num_classes(dirname):
    model = DummyPretrainedModel()
    to_save_single_object = {"model": model}

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    handler(trainer, to_save_single_object)

    fname = handler.last_checkpoint
    loaded_checkpoint = torch.load(fname)

    to_load_single_object = {"pretrained_features": model.features}

    with pytest.raises(RuntimeError):
        Checkpoint.load_objects(to_load_single_object, loaded_checkpoint)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        Checkpoint.load_objects(to_load_single_object,
                                loaded_checkpoint,
                                strict=False,
                                blah="blah")

    loaded_weights = to_load_single_object["pretrained_features"].state_dict()["weight"]

    assert torch.all(model.state_dict()["features.weight"].eq(loaded_weights))
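Example #1 above exercises partial loading: passing strict=False through Checkpoint.load_objects relaxes the state_dict matching when only a submodule of the saved model is restored. As a reference point for the remaining examples, here is a minimal, self-contained sketch of the basic save/restore round trip; the toy nn.Linear model, optimizer, and file name are assumptions, not taken from any example:

import torch
import torch.nn as nn
from ignite.handlers import Checkpoint

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save a checkpoint as a plain dict of state_dicts, keyed like to_load below.
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "ckpt.pt")

# Restore both objects in place; a keyword such as strict=False would be
# passed along when the model's state_dict is loaded, as Example #1 does.
to_load = {"model": model, "optimizer": optimizer}
checkpoint = torch.load("ckpt.pt", map_location="cpu")
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)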
Example #2

def _resume_training(resume_from: Union[str, Path], to_save: Dict[str, Any]):
    if resume_from:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f'Checkpoint "{checkpoint_fp}" is not found'
        print(f'Resuming from a checkpoint: {checkpoint_fp}')
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)
Example #3
    def _load_model(self):
        model = get_model(self.config)
        model.to(self.device)
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
        model.eval()

        return model
Example #4
    def __call__(self, engine):
        checkpoint = torch.load(self.load_path)
        if len(self.load_dict) == 1:
            key = list(self.load_dict.keys())[0]
            if key not in checkpoint:
                checkpoint = {key: checkpoint}

        Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint)
        self.logger.info(f"Restored all variables from {self.load_path}")
Example #5
    def resume(self):
        d = Path(self.save_path)
        pattern = "checkpoint_*.pth"
        saves = list(d.glob(pattern))
        if len(saves) == 0:
            raise FileNotFoundError("No checkpoint to load in %s" % self.save_path)
        fp = max(saves, key=lambda f: f.stat().st_mtime)
        checkpoint = torch.load(fp)
        Checkpoint.load_objects(self.to_save(), checkpoint)
        print("Load trainer from %s" % fp)
Example #6
    def get_model(self, model_name, device, prefix='', path=None):
        if path is None:
            path = join(Constants.MODELS_PATH, model_name)
        best_file, best_loss = get_model_filename(path, prefix)
        model = copy.deepcopy(self.my_models[model_name])
        to_load = {'model': model}

        checkpoint = torch.load(join(path, best_file), map_location=device)
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

        return model, best_loss, self.get_thresholds(model_name)
Example #7
def extract_model(ckp_file,
                  device='cuda' if torch.cuda.is_available() else 'cpu'):
    tokenizer = BertTokenizer.from_pretrained(base_model)
    model = BertClassificationModel(cls=tokenizer.vocab_size,
                                    model_file=base_model)

    to_load = {'BertClassificationModel': model}
    checkpoint = torch.load(ckp_file, map_location=device)

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
    model.to(device)

    model.bert.save_pretrained('./extract_bert/')
Example #8
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }

    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)
    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix()
        )
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
Example #9

    def load_trainer_from_checkpoint(self):
        if self.hparams.checkpoint_dir is not None:
            if not self.hparams.load_model_only:
                objects_to_checkpoint = {
                    "trainer": self.trainer,
                    "model": self.model,
                    "optimizer": self.optimizer,
                    "scheduler": self.scheduler
                }
                if USE_AMP:
                    objects_to_checkpoint["amp"] = amp
            else:
                objects_to_checkpoint = {"model": self.model}
            objects_to_checkpoint = {k: v for k, v in objects_to_checkpoint.items() if v is not None}
            checkpoint = torch.load(self.hparams.checkpoint_dir, map_location="cpu")
            Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)
Example #10
def _test_checkpoint_load_objects_ddp(device):
    model = DummyModel().to(device)
    device_ids = (
        None if "cpu" in device.type else [device,]
    )
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
    opt = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

    # single object:
    to_load = {"model": ddp_model}
    checkpoint = ddp_model.module.state_dict()
    Checkpoint.load_objects(to_load, checkpoint)

    # multiple objects:
    to_load = {"model": ddp_model, "opt": opt}
    checkpoint = {"model": ddp_model.module.state_dict(), "opt": opt.state_dict()}
    Checkpoint.load_objects(to_load, checkpoint)
Example #11

def resume_from_checkpoint(to_save, conf, device=None):
    # type: (Dict[str, Any], DictConfig, Device) -> None
    to_load = {k: v for k, v in to_save.items() if v is not None}

    if conf.drop_state:
        # we might want to swap optimizer or to reset it state
        drop_keys = set(conf.drop_state)
        to_load = {k: v for k, v in to_load.items() if k not in drop_keys}

    checkpoint = torch.load(conf.load, map_location=device)
    ema_key = "model_ema"
    if ema_key in to_load and ema_key not in checkpoint:
        checkpoint[ema_key] = checkpoint["model"]
        logging.warning("There are no EMA weights in the checkpoint. "
                        "Using saved model weights as a starting point for the EMA.")

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
Example #12
def test_checkpoint_load_objects_from_saved_file(dirname):
    def _get_single_obj_to_save():
        model = DummyModel()
        to_save = {
            "model": model,
        }
        return to_save

    def _get_multiple_objs_to_save():
        model = DummyModel()
        optim = torch.optim.SGD(model.parameters(), lr=0.001)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)
        to_save = {
            "model": model,
            "optimizer": optim,
            "lr_scheduler": lr_scheduler,
        }
        return to_save

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    # case: multiple objects
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    to_save = _get_multiple_objs_to_save()
    handler(trainer, to_save)
    fname = handler.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    Checkpoint.load_objects(to_save, loaded_objects)
    os.remove(fname)

    # case: saved multiple objects, loaded single object
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    to_save = _get_multiple_objs_to_save()
    handler(trainer, to_save)
    fname = handler.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    to_load = {'model': to_save['model']}
    Checkpoint.load_objects(to_load, loaded_objects)
    os.remove(fname)

    # case: single object
    handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=1)
    to_save = _get_single_obj_to_save()
    handler(trainer, to_save)
    fname = handler.last_checkpoint
    assert isinstance(fname, str)
    assert os.path.join(dirname, _PREFIX) in fname
    assert os.path.exists(fname)
    loaded_objects = torch.load(fname)
    Checkpoint.load_objects(to_save, loaded_objects)
Example #13
    def resume(self, fp=None):
        assert self._traier_state == TrainerState.INIT

        if fp is None:
            d = Path(self.save_path)
            pattern = "checkpoint_*.pt*"
            saves = list(d.glob(pattern))
            if len(saves) == 0:
                raise FileNotFoundError("No checkpoint to load in %s" %
                                        self.save_path)
            fp = max(saves, key=lambda f: f.stat().st_mtime)

        checkpoint = torch.load(fp)

        if not self.fp16 and 'amp' in checkpoint:
            del checkpoint['amp']

        Checkpoint.load_objects(self.to_save(), checkpoint)

        self._train_engine_state = checkpoint['train_engine']
        self._eval_engine_state = checkpoint['eval_engine']
        self._traier_state = TrainerState.FITTING

        print("Load trainer from %s" % fp)
Example #14
def inference(test_loader, metrics):
    if os.listdir(opt.checkpoint_dir):
        for root, dirs, files in os.walk(opt.checkpoint_dir):
            checkpoint_file = os.path.join(root, files[-1])
        checkpoint = torch.load(checkpoint_file)
        model = tv.models.resnet50()
        model.fc = nn.Linear(2048, 2)
        object_to_checkpoint = {'model': model}
        Checkpoint.load_objects(to_load=object_to_checkpoint, checkpoint=checkpoint)
    test = create_supervised_evaluator(model, metrics)

    @test.on(Events.COMPLETED)
    def get_result(engine):
        preds, labels = test.state.metrics['test_result']
        preds = preds.reshape((-1, 1))
        labels = labels.reshape((-1, 1))
        preds = np.clip(preds, 0.005, 0.995)
        result = np.concatenate((labels, preds), axis=1)
        result = pd.DataFrame(result, columns=['id', 'label'])
        print(result)
        result.to_csv("..\\result.csv", index=None)
        return result

    test.run(test_loader)
Example #15
def test_checkpoint_load_objects():

    with pytest.raises(TypeError, match=r"Argument checkpoint should be a dictionary"):
        Checkpoint.load_objects({}, [])

    with pytest.raises(TypeError, match=r"should have `load_state_dict` method"):
        Checkpoint.load_objects({"a": None}, {"a": None})

    model = DummyModel()
    to_load = {'model': model}

    with pytest.raises(ValueError, match=r"from `to_load` is not found in the checkpoint"):
        Checkpoint.load_objects(to_load, {})

    model = DummyModel()
    to_load = {'model': model}
    model2 = DummyModel()

    chkpt = {'model': model2.state_dict()}
    Checkpoint.load_objects(to_load, chkpt)
    assert model.state_dict() == model2.state_dict()
Example #16
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
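Several of these examples (for instance #2, #8, #16, #19, and #21) repeat the same resume idiom: check that the checkpoint path exists, load it onto the CPU, and restore every object that was passed to the save handler. A condensed, reusable sketch of that idiom, under the same assumptions as the examples:

from pathlib import Path

import torch
from ignite.handlers import Checkpoint


def maybe_resume(resume_from, to_save):
    # Restore trainer/model/optimizer/lr_scheduler state from a checkpoint file,
    # mirroring the resume_from blocks in the surrounding examples.
    if resume_from is None:
        return
    checkpoint_fp = Path(resume_from)
    assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
    checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
    Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)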
Example #17
                        resume=Events.ITERATION_STARTED,
                        pause=Events.ITERATION_COMPLETED,
                        step=Events.ITERATION_COMPLETED)

    train_timer.attach(trainer,
                        start=Events.EPOCH_STARTED,
                        resume=Events.EPOCH_STARTED,
                        pause=Events.EPOCH_COMPLETED,
                        step=Events.EPOCH_COMPLETED)

    if len(args.load_model) > 0:
        load_model_path = args.load_model
        print("load mode " + load_model_path)
        to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
        checkpoint = torch.load(load_model_path)
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
        print("load model complete")
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            print("change lr to ", args.lr)
    else:
        print("do not load, keep training")


    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_loss(trainer):
        timestamp = get_readable_time()
        print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))


    @trainer.on(Events.EPOCH_COMPLETED)
Example #18
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=1000,
):

    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(criterion)
                                            },
                                            device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    desc = "ITERATION - loss: {:.4f} - lr: {:.4f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0, lr))

    if log_interval is None:
        e = Events.ITERATION_COMPLETED
        log_interval = 1
    else:
        e = Events.ITERATION_COMPLETED(every=log_interval)

    @trainer.on(e)
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = desc.format(engine.state.output, lr)
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if resume_from is None:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception("STOP at {}".format(engine.state.iteration))

    else:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    training_checkpoint = Checkpoint(to_save=objects_to_checkpoint,
                                     save_handler=DiskSaver(
                                         log_dir, require_empty=False))

    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=checkpoint_every),
        training_checkpoint)

    if resume_from is not None:
        tqdm.write("Resume from a checkpoint: {}".format(resume_from))
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
Example #19
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #20
def main():
    # region Setup
    conf = parse_args()
    setup_seeds(conf.session.seed)
    tb_logger, tb_img_logger, json_logger = setup_all_loggers(conf)
    logger.info("Parsed configuration:\n" +
                pyaml.dump(OmegaConf.to_container(conf),
                           safe=True,
                           sort_dicts=False,
                           force_embed=True))

    # region Predicate classification engines
    datasets, dataset_metadata = build_datasets(conf.dataset)
    dataloaders = build_dataloaders(conf, datasets)

    model = build_model(conf.model,
                        dataset_metadata["train"]).to(conf.session.device)
    criterion = PredicateClassificationCriterion(conf.losses)

    pred_class_trainer = Trainer(pred_class_training_step, conf)
    pred_class_trainer.model = model
    pred_class_trainer.criterion = criterion
    pred_class_trainer.optimizer, scheduler = build_optimizer_and_scheduler(
        conf.optimizer, pred_class_trainer.model)

    pred_class_validator = Validator(pred_class_validation_step, conf)
    pred_class_validator.model = model
    pred_class_validator.criterion = criterion

    pred_class_tester = Validator(pred_class_validation_step, conf)
    pred_class_tester.model = model
    pred_class_tester.criterion = criterion
    # endregion

    if "resume" in conf:
        checkpoint = Path(conf.resume.checkpoint).expanduser().resolve()
        logger.debug(f"Resuming checkpoint from {checkpoint}")
        Checkpoint.load_objects(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            checkpoint=torch.load(checkpoint,
                                  map_location=conf.session.device),
        )
        logger.info(f"Resumed from {checkpoint}, "
                    f"epoch {pred_class_trainer.state.epoch}, "
                    f"samples {pred_class_trainer.global_step()}")
    # endregion

    # region Predicate classification training callbacks
    def increment_samples(trainer: Trainer):
        images = trainer.state.batch[0]
        trainer.state.samples += len(images)

    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         increment_samples)

    ProgressBar(persist=True, desc="Pred class train").attach(
        pred_class_trainer, output_transform=itemgetter("losses"))

    tb_logger.attach(
        pred_class_trainer,
        OptimizerParamsHandler(
            pred_class_trainer.optimizer,
            param_name="lr",
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_STARTED,
    )

    pred_class_trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        PredicateClassificationMeanAveragePrecisionBatch())
    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         RecallAtBatch(sizes=(5, 10)))

    tb_logger.attach(
        pred_class_trainer,
        OutputHandler(
            "train",
            output_transform=lambda o: {
                **o["losses"],
                "pc/mAP": o["pc/mAP"].mean().item(),
                **{k: r.mean().item()
                   for k, r in o["recalls"].items()},
            },
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.ITERATION_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification training",
        "train",
        json_logger=None,
        tb_logger=tb_logger,
        global_step_fn=pred_class_trainer.global_step,
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="train",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["train"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    tb_logger.attach(
        pred_class_trainer,
        EpochHandler(
            pred_class_trainer,
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda _: pred_class_validator.run(dataloaders["val"]))
    # endregion

    # region Predicate classification validation callbacks
    ProgressBar(persist=True,
                desc="Pred class val").attach(pred_class_validator)

    if conf.losses["bce"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/bce"]).attach(
            pred_class_validator, "loss/bce")
    if conf.losses["rank"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/rank"]).attach(
            pred_class_validator, "loss/rank")
    Average(output_transform=lambda o: o["losses"]["loss/total"]).attach(
        pred_class_validator, "loss/total")

    PredicateClassificationMeanAveragePrecisionEpoch(
        itemgetter("target", "output")).attach(pred_class_validator, "pc/mAP")
    RecallAtEpoch((5, 10),
                  itemgetter("target",
                             "output")).attach(pred_class_validator,
                                               "pc/recall_at")

    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda val_engine: scheduler.step(val_engine.state.metrics["loss/total"]),
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification validation",
        "val",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="val",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["val"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(
            patience=conf.session.early_stopping.patience,
            score_function=lambda val_engine: -val_engine.state.metrics[
                "loss/total"],
            trainer=pred_class_trainer,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        Checkpoint(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            DiskSaver(
                Path(conf.checkpoint.folder).expanduser().resolve() /
                conf.fullname),
            score_function=lambda val_engine: val_engine.state.metrics[
                "pc/recall_at_5"],
            score_name="pc_recall_at_5",
            n_saved=conf.checkpoint.keep,
            global_step_transform=pred_class_trainer.global_step,
        ),
    )
    # endregion

    if "test" in conf.dataset:
        # region Predicate classification testing callbacks
        if conf.losses["bce"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/bce"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/bce")
        if conf.losses["rank"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/rank"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/rank")
        Average(
            output_transform=lambda o: o["losses"]["loss/total"],
            device=conf.session.device,
        ).attach(pred_class_tester, "loss/total")

        PredicateClassificationMeanAveragePrecisionEpoch(
            itemgetter("target", "output")).attach(pred_class_tester, "pc/mAP")
        RecallAtEpoch((5, 10),
                      itemgetter("target",
                                 "output")).attach(pred_class_tester,
                                                   "pc/recall_at")

        ProgressBar(persist=True,
                    desc="Pred class test").attach(pred_class_tester)

        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate classification test",
            "test",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            PredicateClassificationLogger(
                grid=(2, 3),
                tag="test",
                logger=tb_img_logger.writer,
                metadata=dataset_metadata["test"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        # endregion

    # region Run
    log_effective_config(conf, pred_class_trainer, tb_logger)
    if not ("resume" in conf and conf.resume.test_only):
        max_epochs = conf.session.max_epochs
        if "resume" in conf:
            max_epochs += pred_class_trainer.state.epoch
        pred_class_trainer.run(
            dataloaders["train"],
            max_epochs=max_epochs,
            seed=conf.session.seed,
            epoch_length=len(dataloaders["train"]),
        )

    if "test" in conf.dataset:
        pred_class_tester.run(dataloaders["test"])

    add_session_end(tb_logger.writer, "SUCCESS")
    tb_logger.close()
    tb_img_logger.close()
Example #21

def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        y_pred = model(x)

        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fetching loss.item() every iteration can slow training down on XLA devices,
        # so only do it on iterations that will actually be logged.
        if (config["log_every_iters"] > 0 and
            (engine.state.iteration - 1) % config["log_every_iters"] == 0):
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {"batch loss": batch_loss}

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = ["batch loss"]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #22
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=-1,
    deterministic=False,
):
    # Setup seed to have same model's initialization:
    manual_seed(75)

    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Setup trainer and evaluator
    if deterministic:
        tqdm.write("Setup deterministic trainer")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device,
                                        deterministic=deterministic)

    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(criterion)
                                            },
                                            device=device)

    # Apply learning rate scheduling
    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=f"Epoch {0} - loss: {0:.4f} - lr: {lr:.4f}")

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = f"Epoch {engine.state.epoch} - loss: {engine.state.output:.4f} - lr: {lr:.4f}"
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if crash_iteration > 0:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception(f"STOP at {engine.state.iteration}")

    if resume_from is not None:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration % engine.state.epoch_length

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Compute and log validation metrics
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Setup object to checkpoint
    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_every),
                              training_checkpoint)

    # Setup logger to print and dump into file: model weights, model grads and data stats
    # - first 3 iterations
    # - 4 iterations after checkpointing
    # This helps to compare resumed training with checkpointed training
    def log_event_filter(e, event):
        if event in [1, 2, 3]:
            return True
        elif 0 <= (event % (checkpoint_every * e.state.epoch_length)) < 5:
            return True
        return False

    fp = Path(log_dir) / ("run.log"
                          if resume_from is None else "resume_run.log")
    fp = fp.as_posix()
    for h in [log_data_stats, log_model_weights, log_model_grads]:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(event_filter=log_event_filter),
            h,
            model=model,
            fp=fp)

    if resume_from is not None:
        tqdm.write(f"Resume from the checkpoint: {resume_from}")
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        # Synchronize random states
        manual_seed(15)
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
Example #23
    def setup(self):
        self.trainer = Engine(self.train_step)
        self.evaluator = Engine(self.eval_step)

        self.logger._init_logger(self.trainer, self.evaluator)
        # TODO: Multi-gpu support
        self.model.to(self.device)
        if self.use_amp:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        if self.checkpoint_dir is not None:
            if not self.load_model_only:
                objects_to_checkpoint = {
                    "trainer": self.trainer,
                    "model": self.model,
                    "optimizer": self.optimizer,
                    "scheduler": self.scheduler
                }
                if self.use_amp:
                    objects_to_checkpoint["amp"] = amp
            else:
                objects_to_checkpoint = {"model": self.model}
            objects_to_checkpoint = {
                k: v
                for k, v in objects_to_checkpoint.items() if v is not None
            }
            checkpoint = torch.load(self.checkpoint_dir)
            Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                    checkpoint=checkpoint)

        train_handler_params = {
            "model": self.model,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler,
            "metrics": self.train_metrics,
            "add_pbar": self.add_pbar
        }

        self.logger._add_train_events(**train_handler_params)

        to_save = {
            "model": self.model,
            "trainer": self.trainer,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler
        }
        if self.use_amp:
            to_save["amp"] = amp

        eval_handler_params = {
            "metrics": self.validation_metrics,
            "validloader": self.val_loader,
            "to_save": to_save,
            "add_pbar": self.add_pbar
        }

        eval_handler_params["to_save"] = {
            k: v
            for k, v in eval_handler_params["to_save"].items() if v is not None
        }
        self.logger._add_eval_events(**eval_handler_params)

        if self.scheduler:
            self.trainer.add_event_handler(Events.ITERATION_STARTED,
                                           self.scheduler)

        self.trainer.logger = setup_logger("trainer")
        self.evaluator.logger = setup_logger("evaluator")
Example #24
def main():
  parser = argparse.ArgumentParser()

  # Required parameters
  parser.add_argument("--model", type=str, default='ffn', help="model's name")
  parser.add_argument("--mode", type=int, choices=[0, 1, 2], default=None)
  parser.add_argument("--SNRdb", type=float, default=None)
  parser.add_argument("--pilot_version", type=int, choices=[1, 2], default=1)
  parser.add_argument("--loss_type", type=str, default="BCELoss")
  parser.add_argument("--train_batch_size", type=int, default=128)
  parser.add_argument("--valid_batch_size", type=int, default=128)
  parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
  parser.add_argument("--max_norm", type=float, default=-1)
  parser.add_argument("--lr", type=float, default=1e-3)
  parser.add_argument("--noise_lambda", type=float, default=1.0)
  parser.add_argument("--lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default="linear")
  parser.add_argument("--reset_lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default=None)
  parser.add_argument("--reset_trainer", action='store_true')
  parser.add_argument("--modify_model", action='store_true')
  parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
  parser.add_argument("--eval_iter", type=int, default=10)
  parser.add_argument("--save_iter", type=int, default=10)
  parser.add_argument("--n_epochs", type=int, default=10)
  parser.add_argument("--flush_dataset", type=int, default=0)
  parser.add_argument("--no_cache", action='store_true')
  parser.add_argument("--with_pure_y", action='store_true') 
  parser.add_argument("--with_h", action='store_true') 
  parser.add_argument("--only_l1", action='store_true', help="Only loss 1")
  parser.add_argument("--interpolation", action='store_true', help="if interpolate between pure and reconstruction.") 
  parser.add_argument("--data_dir", type=str, default="data")
  parser.add_argument("--cache_dir", type=str, default="train_cache")
  parser.add_argument("--output_path", type=str, default="runs", help="model save")
  parser.add_argument("--resume_from", type=str, default=None, help="resume training.")
  parser.add_argument("--first_cache_index", type=int, default=0)
  parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                      help="Device (cuda or cpu)")
  parser.add_argument("--local_rank", type=int, default=-1,
                      help="Local rank for distributed training (-1: not distributed)")
  parser.add_argument("--seed", type=int, default=43)
  parser.add_argument("--debug", action='store_true')
  args = parser.parse_args()

  args.output_path = os.path.join(args.output_path, f'pilot_{args.pilot_version}')
  args.cache_dir = os.path.join(args.data_dir, args.cache_dir)
  # Setup CUDA, GPU & distributed training
  args.distributed = (args.local_rank != -1)
  if not args.distributed:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method='env://')
  args.n_gpu = torch.cuda.device_count() if not args.distributed else 1
  args.device = device

  # Set seed
  set_seed(args)
  logger = setup_logger("trainer", distributed_rank=args.local_rank)

  # Model construction
  model = getattr(models, args.model)(args)
  model = model.to(device)
  optimizer = AdamW(model.parameters(), lr = args.lr, weight_decay=args.wd)

  if args.loss_type == "MSELoss":
    criterion = nn.MSELoss(reduction='sum').to(device)
  else:
    criterion = getattr(nn, args.loss_type, getattr(auxiliary, args.loss_type, None))().to(device)
  criterion2 = nn.MSELoss(reduction='sum').to(device)

  if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
    )

  train_dataset = SIGDataset(args, data_type="train")
  valid_dataset = SIGDataset(args, data_type="valid")
  train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
  valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
  train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True, shuffle=(not args.distributed))
  valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, pin_memory=True, shuffle=False)
  
  lr_scheduler = None
  if args.lr_scheduler == "linear":
    lr_scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
  elif args.lr_scheduler == "cycle":
    lr_scheduler = LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
  elif args.lr_scheduler == "cosine":
    lr_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))

  # Training function and trainer
  def update(engine, batch):
      model.train()
      y, x_label, y_pure, H = train_dataset.prepare_batch(batch, device=args.device)

      if args.with_pure_y and args.with_h:
        x_pred, y_pure_pred, H_pred = model(y, pure=y_pure, H=H, opp=True)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        if args.loss_type == "MSELoss":
          loss_1 = loss_1 / x_pred.size(0)
        loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
        loss_noise_h = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
        if args.only_l1:
          loss = loss_1
        else:
          loss = loss_1 + loss_noise * args.noise_lambda + loss_noise_h
        output = (loss.item(), loss_1.item(), loss_noise.item(), loss_noise_h.item())
      elif args.with_pure_y:
        x_pred, y_pure_pred = model(y, pure=y_pure if args.interpolation else None, opp=True)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
        loss = loss_1 + loss_noise * args.noise_lambda
        output = (loss.item(), loss_1.item(), loss_noise.item())
      elif args.with_h:
        x_pred, H_pred = model(y, opp=True)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        loss_noise = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
        loss = loss_1 + loss_noise * args.noise_lambda
        output = (loss.item(), loss_1.item(), loss_noise.item())
      else:
        x_pred = model(y)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        loss = loss_1
        output = (loss.item(), loss_1.item(), torch.zeros_like(loss_1).item())

      loss.backward()
      if args.max_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
      if engine.state.iteration % args.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
      return output
  trainer = Engine(update)

  to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
  metric_names = ["loss", "l1", "ln"]
  if args.with_pure_y and args.with_h:
    metric_names.append("lnH")

  common.setup_common_training_handlers(
    trainer=trainer,
    train_sampler=train_loader.sampler,
    to_save=to_save,
    save_every_iters=len(train_loader) * args.save_iter,
    lr_scheduler=lr_scheduler,
    output_names=metric_names,
    with_pbars=False,
    clear_cuda_cache=False,
    output_path=args.output_path,
    n_saved=2,
  )

  resume_from = args.resume_from
  if resume_from is not None:
    checkpoint_fp = Path(resume_from)
    assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
    logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
    checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
    if args.reset_trainer:
      to_save.pop("trainer")
    checkpoint_to_load = to_save if 'validation' not in resume_from else {"model": model}
    Checkpoint.load_objects(to_load=checkpoint_to_load, checkpoint=checkpoint)
    if args.reset_lr_scheduler is not None:
      if args.reset_lr_scheduler == "linear":
        lr_scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
      elif args.reset_lr_scheduler == "cycle":
        lr_scheduler = LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
      elif args.reset_lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))

  metrics = {
    "accuracy": Accuracy(lambda output: (torch.round(output[0][0]), output[1][0])), 
    "loss_1": Loss(criterion, output_transform=lambda output: (output[0][0], output[1][0])),
    "loss_noise": Loss(criterion2, output_transform=lambda output: (output[0][1], output[1][1]))
  }
  if args.with_pure_y and args.with_h:
    metrics["loss_noise_h"] = Loss(criterion2, output_transform=lambda output: (output[0][2], output[1][2]))

  def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
    model.eval()
    with torch.no_grad():
      x, y, x_pure, H = valid_dataset.prepare_batch(batch, device=args.device, non_blocking=True)
      if args.with_pure_y and args.with_h:
        y_pred, x_pure_pred, h_pred = model(x, opp=True)
        outputs = (y_pred, x_pure_pred, h_pred), (y, x_pure, H)
      elif args.with_pure_y:
        y_pred, x_pure_pred = model(x, opp=True)
        outputs = (y_pred, x_pure_pred), (y, x_pure)
      elif args.with_h:
        y_pred, h_pred = model(x, opp=True)
        outputs = (y_pred, h_pred), (y, H)
      else:
        y_pred = model(x)
        x_pure_pred = x_pure
        outputs = (y_pred, x_pure_pred), (y, x_pure)       
      return outputs
  evaluator = Engine(_inference)
  for name, metric in metrics.items():
      metric.attach(evaluator, name)

  trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.eval_iter), lambda _: evaluator.run(valid_loader))

  if args.flush_dataset > 0:
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.n_epochs//args.flush_dataset), 
                  lambda _: train_loader.dataset.reset() if args.no_cache else train_loader.dataset.reload())

  # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
  if args.local_rank in [-1, 0]:
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=metric_names, output_transform=lambda _: {"lr": f"{optimizer.param_groups[0]['lr']:.2e}"})
    evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

    tb_logger = common.setup_tb_logging(args.output_path, trainer, optimizer, evaluators={'validation': evaluator}, log_every_iters=1)

  # Store 3 best models by validation accuracy:
  common.gen_save_best_models_by_val_score(
    save_handler=DiskSaver(args.output_path, require_empty=False),
    evaluator=evaluator,
    models={"model": model},
    metric_name="accuracy",
    n_saved=3,
    trainer=trainer,
    tag="validation"
  )

  # Run the training
  trainer.run(train_loader, max_epochs=args.n_epochs)

  if args.local_rank in [-1, 0]:
    tb_logger.close()
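Example #24 also shows how to resume only part of the training state: when --reset_trainer is passed, the "trainer" entry is popped from to_save, so epoch and iteration counters start fresh while the model and optimizer weights are still restored. A minimal sketch of that pattern with a toy model (the model, optimizer, and checkpoint file name below are assumptions):

import torch
import torch.nn as nn
from ignite.engine import Engine
from ignite.handlers import Checkpoint

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = Engine(lambda engine, batch: None)
to_save = {"trainer": trainer, "model": model, "optimizer": optimizer}

# Restore weights and optimizer state but keep a fresh engine state.
to_load = dict(to_save)
to_load.pop("trainer")
checkpoint = torch.load("checkpoint_12345.pt", map_location="cpu")  # hypothetical file
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)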
Example #25
def main(args):
    fix_seeds()
    # if os.path.exists('./logs'):
    #     shutil.rmtree('./logs')
    # os.mkdir('./logs')
    # writer = SummaryWriter(log_dir='./logs')
    vis = visdom.Visdom()
    val_avg_loss_window = create_plot_window(vis,
                                             '#Epochs',
                                             'Loss',
                                             'Average Loss',
                                             legend=['Train', 'Val'])
    val_avg_accuracy_window = create_plot_window(vis,
                                                 '#Epochs',
                                                 'Accuracy',
                                                 'Average Accuracy',
                                                 legend=['Val'])
    size = (args.height, args.width)
    train_transform = transforms.Compose([
        transforms.Resize(size),
        # transforms.RandomResizedCrop(size=size, scale=(0.5, 1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(10,
                                translate=(0.1, 0.1),
                                scale=(0.8, 1.2),
                                resample=PIL.Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_dataset = TextDataset(args.data_path,
                                'train.txt',
                                size=args.train_size,
                                transform=train_transform)
    val_dataset = TextDataset(args.data_path,
                              'val.txt',
                              size=args.val_size,
                              transform=val_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              shuffle=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            num_workers=args.workers,
                            shuffle=False)

    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, 16)

    # Load plain model weights only when a checkpoint path is given; the full training
    # state is restored further below via Checkpoint.load_objects.
    if args.resume_from not in [None, '']:
        model.load_state_dict(torch.load(args.resume_from)['model'])

    device = 'cpu'
    if args.cuda:
        device = 'cuda'
    print(device)
    # NOTE: criterion, lr, optimizer, scheduler and trainer are referenced below but are
    # not defined anywhere in this snippet. The lines up to the trainer are an assumed,
    # minimal reconstruction (create_supervised_trainer is assumed to be imported from
    # ignite.engine alongside create_supervised_evaluator).
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    lr = getattr(args, 'lr', 1e-3)  # assumed initial learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=len(train_loader),
                                                gamma=0.9)
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)

    metrics = {'accuracy': Accuracy(), 'loss': Loss(criterion)}
    evaluator = create_supervised_evaluator(model, metrics, device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def lr_step(engine):
        if model.training:
            scheduler.step()

    global pbar, desc
    pbar, desc = None, None

    @trainer.on(Events.EPOCH_STARTED)
    def create_train_pbar(engine):
        global desc, pbar
        if pbar is not None:
            pbar.close()
        desc = 'Train iteration - loss: {:.4f} - lr: {:.4f}'
        pbar = tqdm(initial=0,
                    leave=False,
                    total=len(train_loader),
                    desc=desc.format(0, lr))

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_val_pbar(engine):
        global desc, pbar
        if pbar is not None:
            pbar.close()
        desc = 'Validation iteration - loss: {:.4f}'
        pbar = tqdm(initial=0,
                    leave=False,
                    total=len(val_loader),
                    desc=desc.format(0))

    # desc_val = 'Validation iteration - loss: {:.4f}'
    # pbar_val = tqdm(initial=0, leave=False, total=len(val_loader), desc=desc_val.format(0))

    log_interval = 1
    e = Events.ITERATION_COMPLETED(every=log_interval)

    train_losses = []

    @trainer.on(e)
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]['lr']
        train_losses.append(engine.state.output)
        pbar.desc = desc.format(engine.state.output, lr)
        pbar.update(log_interval)
        # writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        # writer.add_scalar("lr", lr, engine.state.iteration)

    @evaluator.on(e)
    def log_validation_loss(engine):
        label = engine.state.batch[1].to(device)
        output = engine.state.output[0]
        pbar.desc = desc.format(criterion(output, label))
        pbar.update(log_interval)

    # if args.resume_from is not None:
    #     @trainer.on(Events.STARTED)
    #     def _(engine):
    #         pbar.n = engine.state.iteration

    # @trainer.on(Events.EPOCH_COMPLETED(every=1))
    # def log_train_results(engine):
    #     evaluator.run(train_loader) # eval on train set to check for overfitting
    #     metrics = evaluator.state.metrics
    #     avg_accuracy = metrics['accuracy']
    #     avg_nll = metrics['loss']
    #     tqdm.write(
    #         "Train Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
    #         .format(engine.state.epoch, avg_accuracy, avg_nll))
    #     pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        pbar.refresh()
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['loss']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        # pbar.n = pbar.last_print_n = 0

        # writer.add_scalars("avg losses", {"train": statistics.mean(train_losses),
        #                                   "valid": avg_nll}, engine.state.epoch)
        # # writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        # writer.add_scalar("avg_accuracy", avg_accuracy, engine.state.epoch)
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([avg_accuracy]),
                 win=val_avg_accuracy_window,
                 update='append')
        vis.line(X=np.column_stack(
            (np.array([engine.state.epoch]), np.array([engine.state.epoch]))),
                 Y=np.column_stack((np.array([statistics.mean(train_losses)]),
                                    np.array([avg_nll]))),
                 win=val_avg_loss_window,
                 update='append',
                 opts=dict(legend=['Train', 'Val']))
        del train_losses[:]

    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "scheduler": scheduler
    }
    training_checkpoint = Checkpoint(to_save=objects_to_checkpoint,
                                     save_handler=DiskSaver(
                                         args.snapshot_dir,
                                         require_empty=False))
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                              training_checkpoint)
    if args.resume_from not in [None, '']:
        tqdm.write("Resume from a checkpoint: {}".format(args.resume_from))
        checkpoint = torch.load(args.resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=args.epochs)
        pbar.close()
    except Exception as e:
        import traceback
        print(traceback.format_exc())
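A note on the checkpoint format used throughout these examples: a file written by Checkpoint(to_save=...) is a plain dict keyed by the names given in to_save, each value being that object's state_dict(). That is why the snippet above can read torch.load(args.resume_from)['model'] directly. A minimal sketch (the file name is hypothetical):

import torch
from ignite.handlers import Checkpoint

checkpoint = torch.load("checkpoint_10.pt", map_location="cpu")  # hypothetical file name
print(list(checkpoint.keys()))  # e.g. ['trainer', 'model', 'optimizer', 'scheduler']

# Restore a single entry by hand ...
model.load_state_dict(checkpoint["model"])

# ... or restore several saved objects in one call; extra keys in the file are ignored.
Checkpoint.load_objects(
    to_load={"model": model, "optimizer": optimizer, "scheduler": scheduler},
    checkpoint=checkpoint,
)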
Example no. 26
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    cutmix_beta = config["cutmix_beta"]
    cutmix_prob = config["cutmix_prob"]
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            r = torch.rand(1).item()
            if cutmix_beta > 0 and r < cutmix_prob:
                output, loss = utils.cutmix_forward(model, x, criterion, y,
                                                    cutmix_beta)
            else:
                output = model(x)
                loss = criterion(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        if idist.backend() == "horovod":
            optimizer.synchronize()
            with optimizer.skip_synchronize():
                scaler.step(optimizer)
                scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    if config["with_pbar"] and idist.get_rank() == 0:
        ProgressBar().attach(trainer)

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert (checkpoint_fp.exists()
                ), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
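get_save_handler used above is not part of the snippet. A minimal stand-in, assuming the config carries an output_path entry (the real helper may instead return an experiment-tracking saver, e.g. for ClearML):

from ignite.handlers import DiskSaver

def get_save_handler(config):
    # Assumed implementation: write checkpoint files into the configured output directory.
    return DiskSaver(config["output_path"], require_empty=False)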
Example no. 27
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA update on the parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
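The cycle helper used above to draw batches from the unlabeled and CTA probe loaders is not included in the snippet. A minimal version could look like this; unlike itertools.cycle it restarts the DataLoader on exhaustion, so per-epoch shuffling is preserved and batches are not cached in memory:

def cycle(iterable):
    # Iterate over a DataLoader forever, restarting it each time it is exhausted.
    while True:
        for batch in iterable:
            yield batch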
Example no. 28
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        labels = batch["label"].view(-1, 1)

        if input_ids.device != device:
            input_ids = input_ids.to(device,
                                     non_blocking=True,
                                     dtype=torch.long)
            attention_mask = attention_mask.to(device,
                                               non_blocking=True,
                                               dtype=torch.long)
            token_type_ids = token_type_ids.to(device,
                                               non_blocking=True,
                                               dtype=torch.long)
            labels = labels.to(device, non_blocking=True, dtype=torch.float)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]
    if config["log_every_iters"] == 0:
        # Disable logging training metrics:
        metric_names = None
        config["log_every_iters"] = 15

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=utils.get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        log_every_iters=config["log_every_iters"],
        with_pbars=not config["with_clearml"],
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(
        ), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
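For orientation, common.setup_common_training_handlers bundles a handful of standard handlers. A simplified sketch of roughly equivalent manual wiring (reusing trainer, to_save, lr_scheduler, config and utils.get_save_handler from the snippet; this is an illustration, not the helper's exact internals):

from ignite.engine import Events
from ignite.handlers import Checkpoint, TerminateOnNan
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar

# Stop training as soon as the loss becomes NaN/inf.
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

# Expose a running average of the "batch loss" output as a trainer metric.
RunningAverage(output_transform=lambda out: out["batch loss"]).attach(trainer, "batch loss")

# Step the learning-rate scheduler once per iteration.
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda _: lr_scheduler.step())

# Periodic "training" checkpoint of everything listed in to_save.
training_saver = Checkpoint(to_save, utils.get_save_handler(config),
                            filename_prefix="training", n_saved=2)
trainer.add_event_handler(
    Events.ITERATION_COMPLETED(every=config["checkpoint_every"]), training_saver)

# Progress bar showing the tracked output names.
ProgressBar(persist=True).attach(trainer, metric_names=["batch loss"])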
Example no. 29
def main():
  parser = argparse.ArgumentParser()

  # Required parameters
  parser.add_argument("--model", type=str, default='ffn', help="model's name")
  parser.add_argument("--checkpoint", type=str, required=True, help="checkpoint file path")
  parser.add_argument("--pilot_version", type=int, choices=[1, 2], default=1)
  parser.add_argument("--batch_size", type=int, default=1024)
  parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                      help="Device (cuda or cpu)")
  parser.add_argument("--local_rank", type=int, default=-1,
                      help="Local rank for distributed training (-1: not distributed)")
  parser.add_argument("--seed", type=int, default=43)
  parser.add_argument("--debug", action='store_true') 
  args = parser.parse_args()

  # Setup CUDA, GPU & distributed training
  args.distributed = (args.local_rank != -1)
  if not args.distributed:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method='env://')
  args.n_gpu = torch.cuda.device_count() if not args.distributed else 1
  args.device = device

  # Set seed
  set_seed(args)
  logger = setup_logger("Testing", distributed_rank=args.local_rank)

  # Model construction
  model = getattr(models, args.model)(args)
  checkpoint_fp = Path(args.checkpoint)
  assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
  logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
  checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
  Checkpoint.load_objects(to_load={"model": model}, checkpoint=checkpoint)
  model = model.to(device)

  if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
    )

  datapath = f'data/Y_{args.pilot_version}.csv'
  dataY = pd.read_csv(datapath, header=None).values

  test_dataset = torch.tensor(dataY, dtype=torch.float32)
  test_loader = DataLoader(test_dataset, batch_size=args.batch_size, pin_memory=True, shuffle=False)

  pred = []
  model.eval()
  for batch in tqdm(test_loader, desc="Running testing"):
    batch = batch.to(device)
    x_pred = model(batch)
    x_pred = x_pred > 0.5
    pred.append(x_pred.cpu().numpy())
  np.concatenate(pred).tofile(f'{os.path.split(args.checkpoint)[0]}/X_pre_{args.pilot_version}.bin')
  if args.debug:
    np.ones_like(np.concatenate(pred)).tofile(f'{os.path.split(args.checkpoint)[0]}/X_pre_2.bin') 
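set_seed used above (and fix_seeds in the earlier example) is not shown. A minimal, assumed implementation that seeds the common RNGs:

import random

import numpy as np
import torch

def set_seed(args):
    # Assumed helper: seed Python, NumPy and PyTorch RNGs for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)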
Example no. 30
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.general.gpu)
    device = torch.device('cuda')
    loader_args = dict()
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)

    train_dl = create_train_loader(conf.data, **loader_args)

    if epoch_length < 1:
        epoch_length = len(train_dl)

    metric_names = list(conf.logging.stats)
    metrics = create_metrics(metric_names, device if distributed else None)

    G = instantiate(conf.model.G).to(device)
    D = instantiate(conf.model.D).to(device)
    G_loss = instantiate(conf.loss.G).to(device)
    D_loss = instantiate(conf.loss.D).to(device)
    G_opt = instantiate(conf.optim.G, G.parameters())
    D_opt = instantiate(conf.optim.D, D.parameters())
    G_ema = None

    if master_node and conf.G_smoothing.enabled:
        G_ema = instantiate(conf.model.G)
        if not conf.G_smoothing.use_cpu:
            G_ema = G_ema.to(device)
        G_ema.load_state_dict(G.state_dict())
        G_ema.requires_grad_(False)

    to_save = {
        'G': G,
        'D': D,
        'G_loss': G_loss,
        'D_loss': D_loss,
        'G_opt': G_opt,
        'D_opt': D_opt,
        'G_ema': G_ema
    }

    if master_node and conf.logging.model:
        logging.info(G)
        logging.info(D)

    if distributed:
        ddp_kwargs = dict(device_ids=[
            local_rank,
        ], output_device=local_rank)
        G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs)
        D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs)

    train_options = {
        'train': dict(conf.train),
        'snapshot': dict(conf.snapshots),
        'smoothing': dict(conf.G_smoothing),
        'distributed': distributed
    }
    bs_dl = int(conf.data.loader.batch_size) * num_replicas
    bs_eff = conf.train.batch_size
    if bs_eff % bs_dl:
        raise AttributeError(
            "Effective batch size should be divisible by data-loader batch size "
            "multiplied by number of devices in use"
        )  # until there is no special bs for master node...
    upd_interval = max(bs_eff // bs_dl, 1)
    train_options['train']['update_interval'] = upd_interval
    if epoch_length < len(train_dl):
        # ideally epoch_length should be tied to the effective batch_size only
        # and the ignite trainer counts data-loader iterations
        epoch_length *= upd_interval
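    # Worked example with illustrative numbers: a per-device loader batch size of 32 on
    # 4 replicas gives bs_dl = 32 * 4 = 128; an effective batch size bs_eff = 512 then
    # yields upd_interval = 512 // 128 = 4, i.e. (presumably) gradients are accumulated
    # over 4 loader iterations per optimizer step, and epoch_length is multiplied by 4
    # so that one "epoch" still covers the intended number of optimizer updates.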

    train_loop, sample_images = create_train_closures(G,
                                                      D,
                                                      G_loss,
                                                      D_loss,
                                                      G_opt,
                                                      D_opt,
                                                      G_ema=G_ema,
                                                      device=device,
                                                      options=train_options)
    trainer = create_trainer(train_loop, metrics, device, num_replicas)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.checkpoints
    pbar = None

    if master_node:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
        trainer.add_event_handler(log_event, log_iter, pbar, log_freq)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
        pbar.attach(trainer, metric_names=metric_names)
        setup_checkpoints(trainer, to_save, epoch_length, conf)
        setup_snapshots(trainer, sample_images, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load,
                                                      map_location=device))

    try:
        trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception as e:
        import traceback
        logging.error(traceback.format_exc())
    if pbar is not None:
        pbar.close()
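setup_checkpoints called above is not defined in the snippet. A sketch of what it might do, assuming the hypothetical config keys cp.save_dir and cp.n_saved (only cp.load is actually referenced in the code shown):

from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

def setup_checkpoints(trainer, to_save, epoch_length, conf):
    # Assumed implementation: save all tracked objects once per epoch-worth of iterations.
    cp = conf.checkpoints
    saver = DiskSaver(cp.save_dir, create_dir=True, require_empty=False)  # hypothetical key
    handler = Checkpoint(to_save, saver,
                         filename_prefix="training",
                         n_saved=cp.n_saved)  # hypothetical key
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=epoch_length), handler)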