def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: This example can log GPU information (used memory, utilization), "
                "but the pynvml package is not installed, so GPU information won't be logged. "
                "To enable it, install pynvml: `pip install pynvml`")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
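Net and get_data_loaders are defined elsewhere in this example. For reference, below is a minimal self-contained sketch of the same trainer / evaluator / TensorboardLogger wiring on synthetic data; it assumes ignite 0.4.x import paths (the TensorBoard handlers may live under ignite.handlers in newer releases) and illustrative model and data shapes.

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine

# Synthetic two-class data standing in for the real dataset
X, y = torch.randn(256, 10), torch.randint(0, 2, (256,))
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 2))
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(
    model, metrics={"accuracy": Accuracy(), "loss": Loss(criterion)})

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run(loader)

tb_logger = TensorboardLogger(log_dir="/tmp/tb_logs")
tb_logger.attach_output_handler(
    evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["loss", "accuracy"],
    # Use the trainer's epoch (not the evaluator's) as the x-axis
    global_step_transform=global_step_from_engine(trainer),
)

trainer.run(loader, max_epochs=2)
tb_logger.close()

The logged scalars can then be inspected with `tensorboard --logdir /tmp/tb_logs`.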
Example #2
def train(cfg: DictConfig) -> None:

    # Determine device (GPU, CPU, etc.)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model
    model = get_network(cfg)

    # Data Loaders
    train_loader, val_loader = get_dataloaders(cfg, num_workers=cfg.data_loader_workers)

    # Your training loop
    trainer = create_training_loop(model, cfg, "trainer", device=device)
    # Your evaluation loop
    evaluator = create_evaluation_loop(model, cfg, "evaluator", device=device)

    ld = LogDirector(cfg, engines=[trainer, evaluator])

    ########################################################################
    # Logging Callbacks
    ########################################################################

    # Helper to run the evaluation loop
    def run_evaluator():
        evaluator.run(val_loader)
        return evaluator  # NOTE: Must return the engine we want to log from

    ld.set_event_handlers(
        trainer,
        Events.ITERATION_COMPLETED(every=50),
        EngineStateAttr.OUTPUT,
        [
            (LOG_OP.SAVE_IMAGE, ["im"]),  # Save images to a folder
            (LOG_OP.LOG_MESSAGE, ["nll"],),  # Log fields as message in logfile
            (LOG_OP.SAVE_IN_DATA_FILE, ["nll"],),  # Log fields as separate data files
            (
                LOG_OP.NUMBER_TO_VISDOM,
                [
                    # First plot, key is "p1"
                    VisPlot(
                        var_name="nll",
                        plot_key="p1",
                        split="nll_1",
                        # Any opts that Visdom supports
                        opts={"title": "Plot 1", "xlabel": "Iters", "fillarea": True},
                    ),
                    VisPlot(var_name="nll_2", plot_key="p1", split="nll_2",),
                ],
            ),
            (
                LOG_OP.IMAGE_TO_VISDOM,
                [
                    VisImg(
                        var_name="im",
                        img_key="1",
                        env="images",
                        opts={"caption": "a current image", "title": "title"},
                    ),
                    VisImg(
                        var_name="im",
                        img_key="2",
                        env="images",
                        opts={"caption": "a current image", "title": "title"},
                    ),
                ],
            ),
        ],
    )

    ld.set_event_handlers(
        trainer,
        Events.EPOCH_COMPLETED,
        EngineStateAttr.METRICS,
        [
            (
                LOG_OP.LOG_MESSAGE,
                ["nll", "accuracy",],
            ),  # Log fields as message in logfile
            (
                LOG_OP.SAVE_IN_DATA_FILE,
                ["accuracy"],
            ),  # Log fields as separate data files
            (
                LOG_OP.NUMBER_TO_VISDOM,
                [
                    # First plot, key is "p1"
                    VisPlot(
                        var_name="accuracy",
                        plot_key="p3",
                        split="acc",
                        # Any opts that Visdom supports
                        opts={"title": "Eval Acc", "xlabel": "Iters"},
                    ),
                    # First plot, key is "p1"
                    VisPlot(
                        var_name="nll",
                        plot_key="p4",
                        split="nll",
                        # Any opts that Visdom supports
                        opts={"title": "Eval Nll", "xlabel": "Iters", "fillarea": True},
                    ),
                ],
            ),
        ],
        # Run the evaluation loop, then do log operations from the return engine
        pre_op=run_evaluator,
    )

    # Execute training
    trainer.run(train_loader, max_epochs=cfg.mode.train.max_epochs)
Example #3
def _setup_common_training_handlers(
    trainer,
    to_save=None,
    save_every_iters=1000,
    output_path=None,
    lr_scheduler=None,
    with_gpu_stats=True,
    output_names=None,
    with_pbars=True,
    with_pbar_on_iters=True,
    log_every_iters=100,
    device="cuda",
):
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      lambda engine: lr_scheduler.step())
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None:
            raise ValueError(
                "If the to_save argument is provided, then the output_path argument must also be defined"
            )
        checkpoint_handler = ModelCheckpoint(dirname=output_path,
                                             filename_prefix="training")
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(every=save_every_iters),
            checkpoint_handler, to_save)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, torch.Tensor):
                return x
            else:
                raise ValueError(
                    "Unhandled type of update_function's output. "
                    "It should be either a mapping or a sequence, "
                    "but {} was given".format(type(x)))

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform,
                                                    index=i,
                                                    name=n),
                           epoch_bound=False,
                           device=device).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True,
                    bar_format="").attach(trainer,
                                          event_name=Events.EPOCH_STARTED,
                                          closing_event_name=Events.COMPLETED)
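The output_names branch above is the interesting part: the trainer's update function returns a mapping (or sequence), and one RunningAverage per name is attached (via functools.partial here) so that the progress bar can display it. Below is a minimal self-contained sketch of that pattern; the names and hyperparameters are illustrative, and imports assume ignite 0.4.x.

import torch
import torch.nn as nn
from torch.optim import SGD
from ignite.engine import Engine
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

def update(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    # Returning a mapping is what makes the per-name output_transform work
    return {"batch_loss": loss.item()}

trainer = Engine(update)

# One RunningAverage per named output, so the progress bar can display it
RunningAverage(output_transform=lambda out: out["batch_loss"]).attach(trainer, "batch_loss")
ProgressBar(persist=False).attach(trainer, metric_names=["batch_loss"])

data = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(20)]
trainer.run(data, max_epochs=2)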
    def _train(save_iter=None, save_epoch=None, sd=None):
        w_norms = []
        grad_norms = []
        data = []
        chkpt = []

        manual_seed(12)
        arch = [
            nn.Conv2d(3, 10, 3),
            nn.ReLU(),
            nn.Conv2d(10, 10, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ]
        if with_dropout:
            arch.insert(2, nn.Dropout2d())
            arch.insert(-2, nn.Dropout())

        model = nn.Sequential(*arch).to(device)
        opt = SGD(model.parameters(), lr=0.001)

        def proc_fn(e, b):
            from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

            s = _repr_rng_state(_get_rng_states())
            model.train()
            opt.zero_grad()
            y = model(b.to(device))
            y.sum().backward()
            opt.step()
            if debug:
                print(
                    trainer.state.iteration, trainer.state.epoch, "proc_fn - b.shape", b.shape, torch.norm(y).item(), s
                )

        trainer = DeterministicEngine(proc_fn)

        if save_iter is not None:
            ev = Events.ITERATION_COMPLETED(once=save_iter)
        elif save_epoch is not None:
            ev = Events.EPOCH_COMPLETED(once=save_epoch)
            save_iter = save_epoch * (data_size // batch_size)

        @trainer.on(ev)
        def save_chkpt(_):
            if debug:
                print(trainer.state.iteration, "save_chkpt")
            fp = os.path.join(dirname, "test.pt")
            from ignite.engine.deterministic import _repr_rng_state

            tsd = trainer.state_dict()
            if debug:
                print("->", _repr_rng_state(tsd["rng_states"]))
            torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
            chkpt.append(fp)

        def log_event_filter(_, event):
            if (event // save_iter == 1) and 1 <= (event % save_iter) <= 5:
                return True
            return False

        @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
        def write_data_grads_weights(e):
            x = e.state.batch
            i = e.state.iteration
            data.append([i, x.mean().item(), x.std().item()])

            total = [0.0, 0.0]
            out1 = []
            out2 = []
            for p in model.parameters():
                n1 = torch.norm(p).item()
                n2 = torch.norm(p.grad).item()
                out1.append(n1)
                out2.append(n2)
                total[0] += n1
                total[1] += n2
            w_norms.append([i, total[0]] + out1)
            grad_norms.append([i, total[1]] + out2)

        if sd is not None:
            sd = torch.load(sd)
            model.load_state_dict(sd[0])
            opt.load_state_dict(sd[1])
            from ignite.engine.deterministic import _repr_rng_state

            if debug:
                print("-->", _repr_rng_state(sd[2]["rng_states"]))
            trainer.load_state_dict(sd[2])

        manual_seed(32)
        trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
        return {"sd": chkpt, "data": data, "grads": grad_norms, "weights": w_norms}
Example #5
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any
):
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If the to_save argument is provided, then either output_path or save_handler must also be defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    "It should be either a mapping or a sequence, but {} was given".format(type(x))
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
                trainer, n
            )

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
            )

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED
        )
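A side note on the lr_scheduler branches above: a plain torch scheduler is stepped from a lambda, whereas an ignite ParamScheduler (such as the LRScheduler wrapper) is an event handler in its own right and is attached directly. A small self-contained sketch of both attachment styles; the import path for LRScheduler assumes ignite 0.4.x contrib handlers, and the optimizers and step sizes are illustrative.

import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from ignite.engine import Engine, Events
from ignite.contrib.handlers import LRScheduler

model = nn.Linear(4, 2)
trainer = Engine(lambda engine, batch: None)  # dummy update function

# 1) A plain torch scheduler knows nothing about ignite events,
#    so it is stepped from a small lambda once per iteration:
opt_a = SGD(model.parameters(), lr=0.1)
torch_sched = StepLR(opt_a, step_size=100, gamma=0.5)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: torch_sched.step())

# 2) ignite's LRScheduler wrapper is itself an event handler and is attached directly:
opt_b = SGD(model.parameters(), lr=0.1)
wrapped = LRScheduler(StepLR(opt_b, step_size=100, gamma=0.5))
trainer.add_event_handler(Events.ITERATION_COMPLETED, wrapped)

trainer.run(range(10), max_epochs=1)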
Example #6
def train_model(model, data, batch_size=64, lr=0.01, optimizer=None):
    """
    Train function for models used in nets.py. Used to monitor and apply
    early stopping.
    """
    criterion = nn.NLLLoss()
    if optimizer == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.6)
    else:
        optimizer = OPTIMIZERS[optimizer](model.parameters(), lr=lr)
    optimizer.zero_grad()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=model.device)

    val_metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=val_metrics,
                                                  device=model.device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=model.device)

    # Registered once via the decorator; calling trainer.add_event_handler with
    # the same function again would attach it twice and print every message twice.
    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_loss(trainer):
        print(f"Epoch[{trainer.state.epoch}] Loss: {trainer.state.output:.2f}")

    size = len(data)
    lengths = [int(size * 0.75), size - int(size * 0.75)]
    trainset, valset = random_split(
        data, lengths=lengths, generator=torch.Generator().manual_seed(42))
    train_loader = DataLoader(trainset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(valset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=4)

    training_history = {"accuracy": [], "loss": []}
    validation_history = {"accuracy": [], "loss": []}
    last_epoch = []

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        accuracy = metrics["accuracy"] * 100
        loss = metrics["nll"]
        last_epoch.append(0)
        training_history["accuracy"].append(accuracy)
        training_history["loss"].append(loss)
        print(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(trainer.state.epoch, accuracy, loss))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        accuracy = metrics["accuracy"] * 100
        loss = metrics["nll"]
        validation_history["accuracy"].append(accuracy)
        validation_history["loss"].append(loss)
        print(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(trainer.state.epoch, accuracy, loss))

    # trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)
    handler = EarlyStopping(patience=10,
                            score_function=score_function,
                            trainer=trainer,
                            min_delta=0)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    trainer.run(train_loader, max_epochs=100000)

    return training_history, validation_history
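score_function is referenced above but not defined in this snippet. A hedged sketch of the usual definition and wiring: EarlyStopping treats larger scores as better, so the function returns a quantity that should increase (accuracy directly, or a negated loss). The dummy engines below only illustrate the wiring; in practice they are the real trainer and validation evaluator.

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)        # stands in for the real trainer
val_evaluator = Engine(lambda engine, batch: None)  # stands in for the real evaluator

def score_function(engine):
    # Called with the engine the handler is attached to (the evaluator),
    # after its metrics have been computed.
    return engine.state.metrics["accuracy"]
    # or: return -engine.state.metrics["nll"]

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
val_evaluator.add_event_handler(Events.COMPLETED, handler)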
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)

    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo
        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: This example can log GPU information (used memory, utilization), "
                "but the pynvml package is not installed, so GPU information won't be logged. "
                "To enable it, install pynvml: `pip install pynvml`")

    metrics = {'accuracy': Accuracy(), 'loss': Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach(trainer,
                     log_handler=OutputHandler(
                         tag="training",
                         output_transform=lambda loss: {'batchloss': loss},
                         metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(train_evaluator,
                     log_handler=OutputHandler(
                         tag="training",
                         metric_names=["loss", "accuracy"],
                         another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)

    tb_logger.attach(validation_evaluator,
                     log_handler=OutputHandler(
                         tag="validation",
                         metric_names=["loss", "accuracy"],
                         another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)

    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()
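A note on the two OutputHandler attachments above: another_engine was the pre-0.4 way to plot evaluator metrics against the trainer's progress and has since been deprecated. With ignite 0.4.x the equivalent, reusing the tb_logger, validation_evaluator and trainer defined in this example, looks roughly like this:

from ignite.contrib.handlers import global_step_from_engine

tb_logger.attach_output_handler(
    validation_evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["loss", "accuracy"],
    global_step_transform=global_step_from_engine(trainer),  # x-axis = trainer epoch
)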
Example #8
def training(local_rank, config):
    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    # Define output folder:
    config.output = "/tmp/output"

    model = idist.auto_model(config.model)
    optimizer = idist.auto_optim(config.optimizer)
    criterion = config.criterion

    train_set, val_set = config.train_set, config.val_set
    train_loader = idist.auto_dataloader(train_set,
                                         batch_size=config.train_batch_size)
    val_loader = idist.auto_dataloader(val_set,
                                       batch_size=config.val_batch_size)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED(every=config.val_interval))
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=config.output)

        tb_logger.attach_output_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
            metric_names="all",
        )

        for tag, evaluator in [("training", train_evaluator),
                               ("validation", validation_evaluator)]:
            tb_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                global_step_transform=global_step_from_engine(trainer),
            )

        tb_logger.attach_opt_params_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            optimizer=optimizer)

    model_checkpoint = ModelCheckpoint(
        config.output,
        n_saved=2,
        filename_prefix="best",
        score_name="accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if rank == 0:
        tb_logger.close()
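This training function is written to be spawned by ignite.distributed. A rough launcher sketch for the function above; the backend and process count are illustrative, and config must provide the fields accessed above (model, optimizer, criterion, train_set, val_set, batch sizes, seed, num_epochs, output).

import ignite.distributed as idist

def main(config):
    # Spawns nproc_per_node workers; each one calls training(local_rank, config)
    with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
        parallel.run(training, config)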
Example #9
def main(
    dataset,
    dataroot,
    z_dim,
    g_filters,
    d_filters,
    batch_size,
    epochs,
    learning_rate,
    beta_1,
    saved_G,
    saved_D,
    seed,
    n_workers,
    device,
    alpha,
    output_dir,
):

    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            "errD": errD.item(),
            "errG": errG.item(),
            "D_x": D_x,
            "D_G_z1": D_G_z1,
            "D_G_z2": D_G_z2,
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         CKPT_PREFIX,
                                         n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(
        trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(
        trainer, "errG")
    RunningAverage(alpha=alpha,
                   output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(
        trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(
        trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = [
            "iteration",
        ] + list(engine.state.metrics.keys())
        values = [
            str(engine.state.iteration),
        ] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)

        message = "[{epoch}/{max_epoch}][{i}/{max_i}]".format(
            epoch=engine.state.epoch,
            max_epoch=epochs,
            i=(engine.state.iteration % len(loader)),
            max_i=len(loader),
        )
        for name, value in zip(columns, values):
            message += " | {name}: {value}".format(name=name, value=value)

        pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir,
                            FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir,
                            REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={
            "netG": netG,
            "netD": netD
        },
    )

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message("Epoch {} done. Time per batch: {:.3f}[s]".format(
            engine.state.epoch, timer.value()))
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import numpy as np
            import pandas as pd

        except ImportError:
            warnings.warn(
                "Loss plots will not be generated -- pandas or matplotlib not found"
            )

        else:
            df = pd.read_csv(
                os.path.join(output_dir, LOGS_FNAME),
                delimiter="\t",
                index_col="iteration",
            )
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            create_plots(engine)
            checkpoint_handler(engine, {
                "netG_exception": netG,
                "netD_exception": netD
            })

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
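Generator, Discriminator, check_dataset and the *_FNAME/CKPT_PREFIX constants come from the surrounding module. The Timer usage is worth isolating; below is a minimal self-contained sketch of measuring the average time per batch, where the 10 ms sleep stands in for a real update step.

import time
from ignite.engine import Engine, Events
from ignite.handlers import Timer

trainer = Engine(lambda engine, batch: time.sleep(0.01))  # dummy 10 ms "batch"

timer = Timer(average=True)
timer.attach(
    trainer,
    start=Events.EPOCH_STARTED,       # (re)start the timer at each epoch
    resume=Events.ITERATION_STARTED,  # only count time spent inside iterations
    pause=Events.ITERATION_COMPLETED,
    step=Events.ITERATION_COMPLETED,  # one measurement per iteration
)

@trainer.on(Events.EPOCH_COMPLETED)
def print_times(engine):
    print(f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]")
    timer.reset()

trainer.run(range(5), max_epochs=2)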
Example #10
def train(cfg, model, train_loader, val_loader, optimizer, device):
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'nll': Loss(F.nll_loss)
                                            },
                                            device=device)

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg["log"]["interval"]))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(cfg["log"]["interval"])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

    # # Checkpoint setting
    # ./checkpoints/sample_mymodel_{step_number}
    handler = ModelCheckpoint(dirname=cfg["checkpoint"]["path"],
                              filename_prefix=cfg["checkpoint"]["prefix"],
                              n_saved=3,
                              create_dir=True,
                              require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler,
                              {'mymodel': model})

    # # Early stopping
    handler = EarlyStopping(patience=5,
                            score_function=score_function,
                            trainer=trainer)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset)
    evaluator.add_event_handler(Events.COMPLETED, handler)

    trainer.run(train_loader, max_epochs=cfg["training"]["epoch"])
    pbar.close()
Example #11
def train(model,
          model_name,
          train_dataloader,
          test_dataloader,
          trainer_name='bb_detection'):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    def _prepare_batch(batch, device=None, non_blocking=False):
        """Prepare batch for training: pass to a device with options.
        """
        images, boxes = batch
        images = [image.to(device) for image in images]
        targets = [{
            'boxes': box.to(device),
            'labels': torch.ones((1), dtype=torch.int64).to(device)
        } for box in boxes]
        return images, targets

    writer = SummaryWriter(log_dir=path.join('logs', trainer_name, model_name))
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              factor=0.5,
                                                              patience=250)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = _prepare_batch(batch, device=device)

        loss_dict = model(x, y)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        losses.backward()
        optimizer.step()
        return loss_value

    trainer = Engine(_update)
    evaluator = create_supervised_evaluator(model,
                                            prepare_batch=_prepare_batch,
                                            metrics={'iou': IOUMetric()},
                                            device=device)

    if path.exists(f'{trainer_name}_{model_name}_checkpoint.pt'):
        checkpoint = torch.load(f'{trainer_name}_{model_name}_checkpoint.pt')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        trainer.load_state_dict(checkpoint['trainer'])

    def early_stop_score_function(engine):
        val_acc = engine.state.metrics['iou']
        return val_acc

    early_stop_handler = EarlyStopping(
        patience=20, score_function=early_stop_score_function, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stop_handler)

    checkpoint_handler = ModelCheckpoint(f'models/{trainer_name}/{model_name}',
                                         model_name,
                                         n_saved=20,
                                         create_dir=True)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=100),
                              checkpoint_handler, {
                                  'model': model,
                                  'optimizer': optimizer,
                                  'trainer': trainer
                              })

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def log_training_loss(trainer):
        lr = optimizer.param_groups[0]['lr']
        print("Epoch[{}]: {} - Loss: {:.4f}, Lr: {}".format(
            trainer.state.epoch, trainer.state.iteration, trainer.state.output,
            lr))
        writer.add_scalar("training/loss", trainer.state.output,
                          trainer.state.iteration)

    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_results(trainer):
        evaluator.run(test_dataloader)
        metrics = evaluator.state.metrics
        print("Training Results - Epoch[{}]: {} - Avg IOU: {:.4f}".format(
            trainer.state.epoch, trainer.state.iteration, metrics['iou']))
        writer.add_scalar("training/avg_iou", metrics['iou'],
                          trainer.state.iteration)

        model.eval()
        test_data = iter(test_dataloader)
        x, y = _prepare_batch(next(test_data), device)
        y_pred = model(x)

        for image, output in zip(x, y_pred):
            writer.add_image_with_boxes("training/example_result", image,
                                        output['boxes'],
                                        trainer.state.iteration)
            break
        model.train()

    @trainer.on(Events.ITERATION_COMPLETED(every=10))
    def step_lr(trainer):
        lr_scheduler.step(trainer.state.output)

    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def read_lr_from_file(trainer):
        if path.exists('lr.txt'):
            with open('lr.txt', 'r', encoding='utf-8') as f:
                lr = float(f.read())
            for group in optimizer.param_groups:
                group['lr'] = lr

    trainer.run(train_dataloader, max_epochs=100)
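The custom _prepare_batch above is how non-standard batch layouts (images plus box/label dicts) are fed into ignite's built-in factories. Below is a small self-contained sketch of the same prepare_batch hook with a simpler dict-shaped batch; the field names and model are illustrative.

import torch
import torch.nn as nn
from ignite.engine import create_supervised_trainer
from ignite.utils import convert_tensor

def prepare_batch(batch, device=None, non_blocking=False):
    # Unpack a custom batch layout into the (x, y) pair expected by the
    # default update function, moving tensors to the target device.
    x = convert_tensor(batch["images"], device=device, non_blocking=non_blocking)
    y = convert_tensor(batch["labels"], device=device, non_blocking=non_blocking)
    return x, y

model = nn.Linear(8, 2)
trainer = create_supervised_trainer(
    model,
    torch.optim.SGD(model.parameters(), lr=0.1),
    nn.CrossEntropyLoss(),
    prepare_batch=prepare_batch,
)

batches = [{"images": torch.randn(4, 8), "labels": torch.randint(0, 2, (4,))} for _ in range(3)]
trainer.run(batches, max_epochs=1)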
Example #12
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    vis = visdom.Visdom()

    # if not vis.check_connection():
    #     raise RuntimeError("Visdom server not running. Please run python -m visdom.server")

    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(F.nll_loss)
                                            },
                                            device=device)

    train_loss_window = create_plot_window(vis, "#Iterations", "Loss",
                                           "Training Loss")
    train_avg_loss_window = create_plot_window(vis, "#Iterations", "Loss",
                                               "Training Average Loss")
    train_avg_accuracy_window = create_plot_window(
        vis, "#Iterations", "Accuracy", "Training Average Accuracy")
    val_avg_loss_window = create_plot_window(vis, "#Epochs", "Loss",
                                             "Validation Average Loss")
    val_avg_accuracy_window = create_plot_window(
        vis, "#Epochs", "Accuracy", "Validation Average Accuracy")

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print(
            f"Epoch[{engine.state.epoch}] Iteration[{engine.state.iteration}/{len(train_loader)}] "
            f"Loss: {engine.state.output:.2f}"
        )
        vis.line(
            X=np.array([engine.state.iteration]),
            Y=np.array([engine.state.output]),
            update="append",
            win=train_loss_window,
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            f"Training Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([avg_accuracy]),
                 win=train_avg_accuracy_window,
                 update="append")
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([avg_nll]),
                 win=train_avg_loss_window,
                 update="append")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            f"Validation Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([avg_accuracy]),
                 win=val_avg_accuracy_window,
                 update="append")
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([avg_nll]),
                 win=val_avg_loss_window,
                 update="append")

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
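create_plot_window is not shown in this snippet; a typical implementation just opens an empty Visdom line plot and returns its window handle, which the handlers above then update with update="append". A sketch, assuming a running Visdom server as in the example:

import numpy as np

def create_plot_window(vis, xlabel, ylabel, title):
    # A single NaN point creates the (empty) plot and yields a window id
    return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                    opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))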
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=-1,
    deterministic=False,
):
    # Setup seed to have same model's initialization:
    manual_seed(75)

    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Setup trainer and evaluator
    if deterministic:
        tqdm.write("Setup deterministic trainer")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device,
                                        deterministic=deterministic)

    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(criterion)
                                            },
                                            device=device)

    # Apply learning rate scheduling
    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=f"Epoch {0} - loss: {0:.4f} - lr: {lr:.4f}")

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = f"Epoch {engine.state.epoch} - loss: {engine.state.output:.4f} - lr: {lr:.4f}"
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if crash_iteration > 0:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception(f"STOP at {engine.state.iteration}")

    if resume_from is not None:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration % engine.state.epoch_length

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Compute and log validation metrics
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Setup object to checkpoint
    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_every),
                              training_checkpoint)

    # Setup logger to print and dump into file: model weights, model grads and data stats
    # - first 3 iterations
    # - 4 iterations after checkpointing
    # This helps to compare resumed training with checkpointed training
    def log_event_filter(e, event):
        if event in [1, 2, 3]:
            return True
        elif 0 <= (event % (checkpoint_every * e.state.epoch_length)) < 5:
            return True
        return False

    fp = Path(log_dir) / ("run.log"
                          if resume_from is None else "resume_run.log")
    fp = fp.as_posix()
    for h in [log_data_stats, log_model_weights, log_model_grads]:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(event_filter=log_event_filter),
            h,
            model=model,
            fp=fp)

    if resume_from is not None:
        tqdm.write(f"Resume from the checkpoint: {resume_from}")
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        # Synchronize random states
        manual_seed(15)
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
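log_data_stats, log_model_weights, log_model_grads and get_data_loaders are defined elsewhere. The checkpoint-and-resume mechanics used above can be isolated into a short self-contained sketch; the paths, update function and epoch counts below are illustrative.

import torch
import torch.nn as nn
from torch.optim import SGD
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)

def update(engine, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(update)

to_save = {"trainer": trainer, "model": model, "optimizer": optimizer}
checkpoint = Checkpoint(
    to_save,
    DiskSaver("/tmp/ckpts", require_empty=False),
    n_saved=2,
    global_step_transform=lambda *_: trainer.state.epoch,  # name files by epoch
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)

data = [torch.randn(8, 4) for _ in range(10)]
trainer.run(data, max_epochs=2)

# To resume in a later run, load the saved dict back into the same objects:
# ckpt = torch.load("/tmp/ckpts/" + checkpoint.last_checkpoint)
# Checkpoint.load_objects(to_load=to_save, checkpoint=ckpt)
# trainer.run(data, max_epochs=4)   # continues from the restored epoch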
Example #14
def training(local_rank, config, logger, with_clearml):

    rank = idist.get_rank()
    manual_seed(config.seed + local_rank)

    train_loader = config.train_loader
    val_loader = config.val_loader
    train_eval_loader = config.train_eval_loader

    model, optimizer, criterion = utils.initialize(config)

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler,
                             config, logger, with_clearml)

    # Setup evaluators
    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
    }

    if ("val_metrics" in config) and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator = create_evaluator(model,
                                 val_metrics,
                                 config,
                                 with_clearml,
                                 tag="val")
    train_evaluator = create_evaluator(model,
                                       val_metrics,
                                       config,
                                       with_clearml,
                                       tag="train")

    val_interval = config.get("val_interval", 1)

    # Run validation every val_interval epochs, at the end of training,
    # and at the beginning if config.start_by_validation is True
    event = Events.EPOCH_COMPLETED(every=val_interval)
    if config.num_epochs % val_interval != 0:
        event |= Events.COMPLETED
    if config.get("start_by_validation", False):
        event |= Events.STARTED

    @trainer.on(event)
    def run_validation():
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        utils.log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                          state.metrics)
        state = evaluator.run(val_loader)
        utils.log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                          state.metrics)

    score_metric_name = "mIoU_bg"
    if "es_patience" in config:
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Store 2 best models by validation accuracy:
    common.gen_save_best_models_by_val_score(
        save_handler=utils.get_save_handler(config.output_path.as_posix(),
                                            with_clearml),
        evaluator=evaluator,
        models=model,
        metric_name=score_metric_name,
        n_saved=2,
        trainer=trainer,
        tag="val",
    )

    # Setup Tensorboard logger
    if rank == 0:
        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )

        # Log validation predictions as images
        # We define a custom event filter to log less frequently the images (to reduce storage size)
        # - we plot images with masks of the middle validation batch
        # - once every 3 validations and
        # - at the end of the training
        def custom_event_filter(_, val_iteration):
            c1 = val_iteration == len(val_loader) // 2
            c2 = trainer.state.epoch % (config.get("val_interval", 1) * 3) == 0
            c2 |= trainer.state.epoch == config.num_epochs
            return c1 and c2

        # Image denormalization function to plot predictions with images
        mean = config.get("mean", (0.485, 0.456, 0.406))
        std = config.get("std", (0.229, 0.224, 0.225))
        img_denormalize = partial(data.denormalize, mean=mean, std=std)

        tb_logger.attach(
            evaluator,
            log_handler=vis.predictions_gt_images_handler(
                img_denormalize_fn=img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation",
            ),
            event_name=Events.ITERATION_COMPLETED(
                event_filter=custom_event_filter),
        )

    # Log confusion matrix to ClearML:
    if with_clearml:
        trainer.add_event_handler(Events.COMPLETED, compute_and_log_cm,
                                  cm_metric, trainer.state.iteration)

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if idist.get_rank() == 0:
        tb_logger.close()
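# The composed validation event and the custom image-logging event filter above
# can be tried out in isolation. A minimal self-contained sketch (the toy engine
# and the schedules below are illustrative assumptions, not part of the example):
from ignite.engine import Engine, Events

toy = Engine(lambda engine, batch: batch)  # trivial process function

# Fire every 2nd epoch, once at start, and once more when the run completes
composed = Events.EPOCH_COMPLETED(every=2) | Events.STARTED | Events.COMPLETED

@toy.on(composed)
def report():
    print("composed event fired at epoch", toy.state.epoch)

# Fire only at iterations 1, 4 and 9 -- an irregular schedule that
# `every=` / `once=` alone cannot express
@toy.on(Events.ITERATION_COMPLETED(event_filter=lambda _, it: it in (1, 4, 9)))
def report_iteration(engine):
    print("filtered event fired at iteration", engine.state.iteration)

toy.run(range(5), max_epochs=4)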
Example #15
def run(train_loader, val_loader, epochs, lr, momentum, log_interval, log_dir):

    model = Vgg16()

    writer = create_summary_writer(model, train_loader, log_dir)

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=0.001)

    lr_scheduler = ExponentialLR(optimizer, gamma=0.975)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)

    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'nll': Loss(F.nll_loss)
                                            },
                                            device=device)

    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda engine: lr_scheduler.step())

    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    # store the best models (top 3 by test accuracy)
    best_model_handler = ModelCheckpoint(dirname=log_dir,
                                         filename_prefix="best",
                                         n_saved=3,
                                         score_name="test_acc",
                                         score_function=default_score_fn)
    evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
        'model': model,
    })

    # add early stopping
    es_patience = 5
    es_handler = EarlyStopping(patience=es_patience,
                               score_function=default_score_fn,
                               trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, es_handler)

    def empty_cuda_cache(engine):
        torch.cuda.empty_cache()
        import gc
        gc.collect()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
    evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, engine.state.iteration,
                        len(train_loader), engine.state.output))
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
    print(args)

    if len(args.load_model) > 0:
        load_model_path = args.load_model
        print("load mode " + load_model_path)
        to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
        checkpoint = torch.load(load_model_path)
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
        print("load model complete")
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            print("change lr to ", args.lr)
    else:
        print("do not load, keep training")

    @trainer.on(Events.ITERATION_COMPLETED(every=50))
    def log_training_loss(trainer):
        timestamp = get_readable_time()
        print(timestamp + " Epoch[{}] Loss: {:.2f}".format(
            trainer.state.epoch, trainer.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        timestamp = get_readable_time()
        print(
            timestamp +
            " Training set Results - Epoch: {}  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
            .format(trainer.state.epoch, metrics['mae'], metrics['mse'],
                    metrics['loss']))
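    # The snippet above only shows the resume path. A minimal sketch of the matching
    # save side with ignite's Checkpoint/DiskSaver, reusing the trainer, model and
    # optimizer defined above (directory name and n_saved are assumptions):
    from ignite.handlers import Checkpoint, DiskSaver

    checkpoint_handler = Checkpoint(
        {'trainer': trainer, 'model': model, 'optimizer': optimizer},
        DiskSaver("checkpoints", create_dir=True, require_empty=False),
        n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)
    # A saved file can later be restored with torch.load(...) followed by
    # Checkpoint.load_objects(...), exactly as done above.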
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval,
        log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.NLLLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)

    val_metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
    evaluator = create_supervised_evaluator(model,
                                            metrics=val_metrics,
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print(
            f"Epoch[{engine.state.epoch}] Iteration[{engine.state.iteration}/{len(train_loader)}] "
            f"Loss: {engine.state.output:.2f}")
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
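# `Net` is not shown in this snippet. Because the trainer above pairs it with
# nn.NLLLoss, the model has to return log-probabilities. A minimal hypothetical
# MNIST-style definition could look like this (layer sizes are assumptions):
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)                   # flatten images to vectors
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=-1)   # log-probs for NLLLoss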
def training(local_rank, config, logger=None):
    #
    # if not getattr(config, "use_fp16", True):
    #     raise RuntimeError("This training script uses fp16 AMP by default")

    torch.backends.cudnn.benchmark = True

    set_seed(config.seed + local_rank)

    train_loader, val_loader, train_eval_loader = config.train_loader, config.val_loader, config.train_eval_loader

    # Setup model, optimizer, criterion
    model, optimizer, criterion = initialize(config)

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler,
                             config, logger)

    # Setup evaluators
    num_classes = config.num_classes
    cm_metric = ConfusionMatrix(num_classes=num_classes)

    val_metrics = {
        "IoU": IoU(cm_metric),
        "mIoU_bg": mIoU(cm_metric),
    }

    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator, train_evaluator = create_evaluators(model, val_metrics, config)

    val_interval = getattr(config, "val_interval", 1)

    @trainer.on(Events.EPOCH_COMPLETED(every=val_interval))
    def run_validation():
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    if config.num_epochs % val_interval != 0:
        trainer.add_event_handler(Events.COMPLETED, run_validation)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)

    score_metric_name = "mIoU_bg"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Store the 3 best models by the validation score (mIoU_bg):
    common.gen_save_best_models_by_val_score(
        save_handler=get_save_handler(config),
        evaluator=evaluator,
        models=model,
        metric_name=score_metric_name,
        n_saved=3,
        trainer=trainer,
        tag="val",
    )

    if idist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )

        exp_tracking_logger = tracking.setup_logging(trainer,
                                                     optimizer,
                                                     evaluators={
                                                         "training":
                                                         train_evaluator,
                                                         "validation":
                                                         evaluator
                                                     })

        # Log validation predictions as images.
        # We define a custom event filter to log images less frequently (to reduce storage size):
        # - we plot images with masks from the middle validation batch,
        # - once every 3 validations, and
        # - at the end of the training.
        def custom_event_filter(_, val_iteration):
            c1 = val_iteration == len(val_loader) // 2
            c2 = trainer.state.epoch % (getattr(config, "val_interval", 1) *
                                        3) == 0
            c2 |= trainer.state.epoch == config.num_epochs
            return c1 and c2

        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(
                event_filter=custom_event_filter),
        )

    # Log confusion matrix to Trains:

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if idist.get_rank() == 0:
        tb_logger.close()
        exp_tracking_logger.close()
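# `log_metrics` is called by run_validation above but is not defined in this
# snippet. A plausible minimal implementation (the exact format string is an
# assumption):
def log_metrics(logger, epoch, elapsed, tag, metrics):
    formatted = " - ".join(f"{name}: {value}" for name, value in metrics.items())
    logger.info(f"Epoch {epoch} | {tag} | evaluation time (s): {elapsed:.3f} | {formatted}")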
def setup_ignite(
        engine: Engine,
        params: SimpleNamespace,
        exp_source,
        run_name: str,
        model,
        optimizer,
        buffer,
        target_net,
        extra_metrics: Iterable[str] = (),
):
    simplefilter("ignore", category=UserWarning)
    handler = EndOfEpisodeHandler(exp_source,
                                  bound_avg_reward=params.stop_reward)
    handler.attach(engine)
    EpisodeFPSHandler().attach(engine)

    objects_to_checkpoint = {
        'model': model,
        'optimizer': optimizer,
        'trainer': engine,
        "buffer": buffer,
        "target_net": target_net
    }
    checkpoint_dir = Path("models backup")
    saver = LightDiskSaver(str(checkpoint_dir),
                           create_dir=True,
                           require_empty=False)
    handler = Checkpoint(objects_to_checkpoint, saver, n_saved=1)
    engine.add_event_handler(Events.ITERATION_COMPLETED(every=30000), handler)

    checkpoints_paths = list(checkpoint_dir.iterdir())
    if checkpoints_paths:
        checkpoint = joblib.load(checkpoints_paths[-1])
        print(f"Loading checkpoint {checkpoints_paths[-1].name}")
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    @engine.on(EpisodeEvents.EPISODE_COMPLETED)
    def episode_completed(trainer: Engine):
        passed = trainer.state.metrics.get("time_passed", 0)
        print(
            "Episode {}: reward={:.0f}, steps={}, speed={:.1f} f/s, elapsed={}"
            .format(trainer.state.episode, trainer.state.episode_reward,
                    trainer.state.episode_steps,
                    trainer.state.metrics.get("avg_fps", 0),
                    timedelta(seconds=int(passed))))

    @engine.on(EpisodeEvents.BOUND_REWARD_REACHED)
    def game_solved(trainer: Engine):
        passed = trainer.state.metrics["time_passed"]
        print(
            f"Game solved in {timedelta(seconds=int(passed))} after {trainer.state.episode}"
            f" episodes and {trainer.state.iteration} iterations!")
        trainer.should_terminate = True

    now = datetime.now().isoformat(timespec="minutes").replace(":", "-")
    logdir = f"runs/{now}-{params.run_name}-{run_name}"
    tb = TensorboardLogger(log_dir=logdir)
    run_avg = RunningAverage(output_transform=lambda v: v["loss"])
    run_avg.attach(engine, "avg_loss")
    metrics = ["reward", "steps", "avg_reward"]
    handler = OutputHandler(tag="episodes", metric_names=metrics)
    event = EpisodeEvents.EPISODE_COMPLETED
    tb.attach(engine, log_handler=handler, event_name=event)

    # write to tensorboard every 100 iterations
    PeriodicEvents().attach(engine)
    metrics = ["avg_loss", "avg_fps"]
    metrics.extend(extra_metrics)
    handler = OutputHandler(tag="train",
                            metric_names=metrics,
                            output_transform=lambda a: a)
    event = PeriodEvents.ITERS_100_COMPLETED
    tb.attach(engine, log_handler=handler, event_name=event)
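# `LightDiskSaver` is not a standard ignite handler and is not defined in this
# snippet. Since the checkpoints are read back above with `joblib.load`, a
# minimal compatible saver might look like the sketch below (class layout and
# compression level are assumptions):
import os
import joblib
from ignite.handlers.checkpoint import BaseSaveHandler


class LightDiskSaver(BaseSaveHandler):
    def __init__(self, dirname, create_dir=True, require_empty=False):
        self.dirname = dirname
        if create_dir:
            os.makedirs(dirname, exist_ok=True)
        if require_empty and os.listdir(dirname):
            raise ValueError(f"Directory {dirname} is not empty")

    def __call__(self, checkpoint, filename, metadata=None):
        # joblib compression keeps checkpoint files smaller, hence "light"
        joblib.dump(checkpoint, os.path.join(self.dirname, filename), compress=3)

    def remove(self, filename):
        os.remove(os.path.join(self.dirname, filename))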
Example #20
def training(local_rank, config, logger=None):

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    torch.backends.cudnn.benchmark = True

    set_seed(config.seed + local_rank)

    train_loader, val_loader, train_eval_loader = config.train_loader, config.val_loader, config.train_eval_loader

    # Setup model, optimizer, criterion
    model, optimizer, criterion = initialize(config)

    if not hasattr(config, "prepare_batch"):
        config.prepare_batch = _prepare_batch

    # Setup trainer for this specific task
    trainer = create_trainer(model, optimizer, criterion, train_loader.sampler,
                             config, logger)

    if getattr(config, "benchmark_dataflow", False):
        benchmark_dataflow_num_iters = getattr(config,
                                               "benchmark_dataflow_num_iters",
                                               1000)
        DataflowBenchmark(benchmark_dataflow_num_iters,
                          prepare_batch=config.prepare_batch).attach(
                              trainer, train_loader)

    # Setup evaluators
    val_metrics = {
        "Accuracy": Accuracy(),
        "Top-5 Accuracy": TopKCategoricalAccuracy(k=5),
    }

    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    evaluator, train_evaluator = create_evaluators(model, val_metrics, config)

    @trainer.on(
        Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1))
        | Events.COMPLETED)
    def run_validation():
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_eval_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(val_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)

    score_metric_name = "Accuracy"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    # Store the 3 best models by validation accuracy:
    common.save_best_model_by_val_score(
        config.output_path.as_posix(),
        evaluator,
        model=model,
        metric_name=score_metric_name,
        n_saved=3,
        trainer=trainer,
        tag="val",
    )

    if idist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )

        exp_tracking_logger = exp_tracking.setup_logging(trainer,
                                                         optimizer,
                                                         evaluators={
                                                             "training":
                                                             train_evaluator,
                                                             "validation":
                                                             evaluator
                                                         })

        # Log train/val predictions:
        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2),
        )

        tb_logger.attach(
            train_evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="training"),
            event_name=Events.ITERATION_COMPLETED(
                once=len(train_eval_loader) // 2),
        )

    trainer.run(train_loader, max_epochs=config.num_epochs)

    if idist.get_rank() == 0:
        tb_logger.close()
        exp_tracking_logger.close()
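# `DataflowBenchmark` is not defined in this snippet. A standalone sketch of the
# idea (timing how fast batches can be fetched and prepared, independently of the
# model) might look like this; the function name and defaults are assumptions:
import time

def benchmark_dataflow(data_loader, prepare_batch, device=None, num_iters=1000):
    data_iter = iter(data_loader)
    start = time.time()
    for _ in range(num_iters):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            batch = next(data_iter)
        prepare_batch(batch, device=device, non_blocking=True)
    elapsed = time.time() - start
    print(f"Dataflow: {num_iters} batches in {elapsed:.1f}s "
          f"({num_iters / elapsed:.1f} batches/s)")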
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval,
        log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = create_summary_writer(model, train_loader, log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(F.nll_loss)
                                            },
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, engine.state.iteration,
                        len(train_loader), engine.state.output))
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
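# `create_summary_writer` is not defined in this snippet. In ignite's MNIST
# TensorBoard example this helper typically builds a SummaryWriter and tries to
# log the model graph from one batch; a sketch along those lines (details are
# assumptions):
from torch.utils.tensorboard import SummaryWriter

def create_summary_writer(model, data_loader, log_dir):
    writer = SummaryWriter(log_dir=log_dir)
    x, y = next(iter(data_loader))
    try:
        writer.add_graph(model, x)
    except Exception as e:
        print("Failed to save model graph: {}".format(e))
    return writer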
Example #22
def run(output_path, config):
    device = "cuda"

    local_rank = config['local_rank']
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config['seed'] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config['batch_size'] // ngpus
    num_workers = int(
        (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config['data_path'],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers)

    model = get_model(config['model'])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[
                local_rank,
            ], output_device=local_rank)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [(0, 0.0),
                         (le * config['num_warmup_epochs'],
                          config['learning_rate']),
                         (le * config['num_epochs'], 0.0)]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(engine, batch):

        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            'batch loss': loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        'trainer': trainer,
        'model': model,
        'optimizer': optimizer,
        'lr_scheduler': lr_scheduler
    }
    metric_names = [
        'batch loss',
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config['checkpoint_every'],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config['display_iters'],
        log_every_iters=10)

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="train",
                                                   metric_names=metric_names),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer,
                                                            param_name="lr"),
                         event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None)
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(
        Events.EPOCH_STARTED(every=config['validate_every']), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config['display_iters']:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer)),
            event_name=Events.COMPLETED)

        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer)),
            event_name=Events.COMPLETED)

        # Store the 3 best models by validation accuracy:
        common.save_best_model_by_val_score(output_path,
                                            evaluator,
                                            model=model,
                                            metric_name='accuracy',
                                            n_saved=3,
                                            trainer=trainer,
                                            tag="test")

        if config['log_model_grads_every'] is not None:
            tb_logger.attach(trainer,
                             log_handler=GradsHistHandler(
                                 model, tag=model.__class__.__name__),
                             event_name=Events.ITERATION_COMPLETED(
                                 every=config['log_model_grads_every']))

    if config['crash_iteration'] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config['crash_iteration']))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(
                engine.state.iteration))

    resume_from = config['resume_from']
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config['num_epochs'])
    except Exception as e:
        import traceback
        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
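# The PiecewiseLinear scheduler above ramps the learning rate linearly from 0 up
# to `learning_rate` over the warmup epochs and then decays it linearly back to 0.
# A small standalone sketch of that interpolation (pure Python; the milestone
# values below are illustrative assumptions):
def piecewise_linear(step, milestones_values):
    # milestones_values: sorted list of (step, value) pairs
    for (s0, v0), (s1, v1) in zip(milestones_values, milestones_values[1:]):
        if s0 <= step <= s1:
            return v0 + (v1 - v0) * (step - s0) / (s1 - s0)
    return milestones_values[-1][1]

milestones = [(0, 0.0), (100, 0.1), (1000, 0.0)]   # warmup to 0.1, then decay
for step in (0, 50, 100, 550, 1000):
    print(step, piecewise_linear(step, milestones))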
Example #23
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.general.gpu)
    device = torch.device('cuda')
    loader_args = dict()
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)

    train_dl = create_train_loader(conf.data, **loader_args)

    if epoch_length < 1:
        epoch_length = len(train_dl)

    metric_names = list(conf.logging.stats)
    metrics = create_metrics(metric_names, device if distributed else None)

    G = instantiate(conf.model.G).to(device)
    D = instantiate(conf.model.D).to(device)
    G_loss = instantiate(conf.loss.G).to(device)
    D_loss = instantiate(conf.loss.D).to(device)
    G_opt = instantiate(conf.optim.G, G.parameters())
    D_opt = instantiate(conf.optim.D, D.parameters())
    G_ema = None

    if master_node and conf.G_smoothing.enabled:
        G_ema = instantiate(conf.model.G)
        if not conf.G_smoothing.use_cpu:
            G_ema = G_ema.to(device)
        G_ema.load_state_dict(G.state_dict())
        G_ema.requires_grad_(False)

    to_save = {
        'G': G,
        'D': D,
        'G_loss': G_loss,
        'D_loss': D_loss,
        'G_opt': G_opt,
        'D_opt': D_opt,
        'G_ema': G_ema
    }

    if master_node and conf.logging.model:
        logging.info(G)
        logging.info(D)

    if distributed:
        ddp_kwargs = dict(device_ids=[
            local_rank,
        ], output_device=local_rank)
        G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs)
        D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs)

    train_options = {
        'train': dict(conf.train),
        'snapshot': dict(conf.snapshots),
        'smoothing': dict(conf.G_smoothing)
    }
    bs_dl = int(conf.data.loader.batch_size) * num_replicas
    bs_eff = conf.train.batch_size
    if bs_eff % bs_dl:
        raise AttributeError(
            "Effective batch size should be divisible by data-loader batch size "
            "multiplied by number of devices in use"
        )  # as long as there is no separate batch size for the master node
    upd_interval = max(bs_eff // bs_dl, 1)
    train_options['train']['update_interval'] = upd_interval
    if epoch_length < len(train_dl):
        # ideally epoch_length would be tied to the effective batch size only,
        # but the ignite trainer counts data-loader iterations
        epoch_length *= upd_interval

    train_loop, sample_images = create_train_closures(G,
                                                      D,
                                                      G_loss,
                                                      D_loss,
                                                      G_opt,
                                                      D_opt,
                                                      G_ema=G_ema,
                                                      device=device,
                                                      options=train_options)
    trainer = create_trainer(train_loop, metrics, device, num_replicas)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.checkpoints
    pbar = None

    if master_node:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
        trainer.add_event_handler(log_event, log_iter, pbar, log_freq)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
        pbar.attach(trainer, metric_names=metric_names)
        setup_checkpoints(trainer, to_save, epoch_length, conf)
        setup_snapshots(trainer, sample_images, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load,
                                                      map_location=device))

    try:
        trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception as e:
        import traceback
        logging.error(traceback.format_exc())
    if pbar is not None:
        pbar.close()
Example #24
def main(data_dir, model_type, emsize, nhid, nlayers, nhead, warm_up, step_size,
         clip, epochs, batch_size, bptt, dropout, tied, seed, use_cuda, spm_path,
         log_interval, val_interval, tb_log, save_path):
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        if not use_cuda:
            print('WARNING: You have a CUDA device, so you should probably run with --cuda')
    device = torch.device('cuda' if use_cuda else 'cpu')

    tokenizer = Tokenizer(spm_path, bos_eos=True)
    ntokens = tokenizer.size()
    padding_index = tokenizer.to_id('<pad>')
    train_loader = get_loader(os.path.join(data_dir, 'train'), batch_size, padding_value=padding_index)
    val_loader = get_loader(os.path.join(data_dir, 'val'), batch_size, padding_value=padding_index)

    if model_type == 'LSTMTransformer':
        model = LSTMTransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
    elif model_type == 'Transformer':
        model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
    else:
        model = RNNModel(model_type, ntokens, emsize, nhid, nlayers, dropout, tied).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1)
    scheduler = TransformerLR(optimizer, emsize, warmup_steps=warm_up, step_size=step_size)
    criterion = nn.NLLLoss(ignore_index=padding_index)

    def _update(engine, batch):
        model.train()

        _batch = batch.to(device)
        x, y = _batch[:-1], _batch[1:]
        seq_length, batch_size = batch.size()

        if model_type == 'LSTMTransformer':
            hidden = model.init_hidden(batch_size)
            mems = None
        elif model_type == 'Transformer':
            pass
        else:
            hidden = model.init_hidden(batch_size)

        total_loss = 0
        for i in range(0, seq_length - 1, bptt):
            optimizer.zero_grad()
            _x, _y = x[i:i + bptt], y[i:i + bptt]
            if model_type == 'LSTMTransformer':
                hidden = repackage_hidden(hidden)
                mems = repackage_hidden(mems) if mems else mems
                output, hidden, mems = model(_x, hidden=hidden, mems=mems)
            elif model_type == 'Transformer':
                output = model(_x)
            else:
                hidden = repackage_hidden(hidden)
                output, hidden = model(_x, hidden)

            loss = criterion(output.view(-1, ntokens), _y.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        total_loss /= math.ceil((seq_length - 1) / bptt)
        return {'loss': total_loss, 'ppl': math.exp(total_loss)}

    def _evaluate(engine, batch):
        model.eval()

        with torch.no_grad():
            _batch = batch.to(device)
            x, y = _batch[:-1], _batch[1:]
            seq_length, batch_size = batch.size()

            if model_type == 'LSTMTransformer':
                hidden = model.init_hidden(batch_size)
                mems = None
            elif model_type == 'Transformer':
                pass
            else:
                hidden = model.init_hidden(batch_size)

            y_pred = None
            for i in range(0, seq_length - 1, bptt):
                optimizer.zero_grad()
                _x, _y = x[i:i + bptt], y[i:i + bptt]
                if model_type == 'LSTMTransformer':
                    output, hidden, mems = model(_x, hidden=hidden, mems=mems)
                elif model_type == 'Transformer':
                    output = model(_x)
                else:
                    output, hidden = model(_x, hidden)

                y_pred = output if y_pred is None else torch.cat([y_pred, output], dim=0)

        return y_pred.view(-1, ntokens), y.view(-1)

    writer = SummaryWriter(log_dir=tb_log)
    trainer = Engine(_update)
    evaluator = Engine(_evaluate)
    Loss(criterion).attach(evaluator, 'loss')

    @trainer.on(Events.STARTED)
    def assign_var(engine):
        engine.state.min_val_loss = None

    @trainer.on(Events.ITERATION_COMPLETED(every=1))
    def training_loss_tb(engine):
        writer.add_scalar('train/loss', engine.state.output['loss'], engine.state.iteration)
        writer.add_scalar('train/ppl', engine.state.output['ppl'], engine.state.iteration)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def training_loss_print(engine):
        print('Epoch[{}] Batch[{}/{}] Loss: {:.2f} PPL: {:.2f}'
              ''.format(engine.state.epoch, engine.state.iteration, len(train_loader),
                        engine.state.output['loss'], engine.state.output['ppl']))

    @trainer.on(Events.ITERATION_COMPLETED(every=val_interval))
    @trainer.on(Events.COMPLETED)
    def validation(engine):
        evaluator.run(val_loader)
        loss = evaluator.state.metrics['loss']
        ppl = math.exp(loss)
        print('Validation - Epoch[{}] Batch[{}/{}] Loss: {:.2f} PPL: {:.2f} LR: {:02.5f}'
              ''.format(engine.state.epoch, engine.state.iteration, len(train_loader),
                        loss, ppl, scheduler.get_lr()[0]))
        writer.add_scalar('val/loss', loss, engine.state.iteration)
        writer.add_scalar('val/ppl', ppl, engine.state.iteration)

        # save model if loss decreases
        if engine.state.min_val_loss is None or loss < engine.state.min_val_loss:
            torch.save(model, save_path)
            engine.state.min_val_loss = loss

    @trainer.on(Events.COMPLETED)
    def test(engine):
        nonlocal model
        model = torch.load(save_path)
        test_loader = get_loader(os.path.join(data_dir, 'test'), batch_size, padding_value=padding_index)
        evaluator.run(test_loader)
        loss = evaluator.state.metrics['loss']
        ppl = math.exp(loss)
        print('------------------------------------------------------------')
        print('Test - Loss: {:.2f} PPL: {:.2f}'.format(loss, ppl))
        print('------------------------------------------------------------')
        writer.add_scalar('test/loss', loss, engine.state.iteration)
        writer.add_scalar('test/ppl', ppl, engine.state.iteration)

    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
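# `repackage_hidden` is not defined in this snippet. The standard helper from the
# PyTorch word-language-model example detaches hidden states so truncated BPTT
# does not backpropagate through the entire history; a sketch assumed to match
# what the training loop above expects:
import torch

def repackage_hidden(h):
    # Wrap hidden states in new tensors detached from their computation graph
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)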
Example #25
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.NLLLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("trainer")

    val_metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
    evaluator = create_supervised_evaluator(model,
                                            metrics=val_metrics,
                                            device=device)
    evaluator.logger = setup_logger("evaluator")

    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        pbar.desc = desc.format(engine.state.output)
        pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        pbar.n = pbar.last_print_n = 0

    @trainer.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
    def log_time(engine):
        tqdm.write("{} took {} seconds".format(
            trainer.last_event_name.name,
            trainer.state.times[trainer.last_event_name.name]))

    trainer.run(train_loader, max_epochs=epochs)
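# As an alternative to managing a tqdm bar by hand as above, ignite ships a
# ProgressBar handler that attaches directly to an engine. A minimal sketch with
# a toy engine standing in for the trainer above (the toy engine and its constant
# loss are assumptions):
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine

toy_trainer = Engine(lambda engine, batch: 0.0)   # stands in for the trainer above
ProgressBar(persist=False).attach(toy_trainer,
                                  output_transform=lambda output: {"loss": output})
toy_trainer.run(range(10), max_epochs=1)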
Example #26
        train_data_loader, _, _, _ = get_pytorch_dataloader(
            args, train_file_name_prefix, shuffle=True)
        optimizer = Adam(model.parameters(), lr=args.lr)
        '''Learning rate decays every 5 epochs'''
        optimizer_scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
        scheduler = LRScheduler(optimizer_scheduler)
        trainer = Engine(train)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  lambda _: evaluator.run(dev_data_loader))

        pbar = ProgressBar(persist=True, desc='Training')
        pbar.attach(trainer, metric_names=["loss"])
        RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=args.loss_log_interval), lambda engine: \
                logger.info('Loss at iteration %d is %.5f', engine.state.iteration, engine.state.metrics['loss']))
        early_stop_handler = EarlyStopping(patience=args.patience,
                                           score_function=score_function,
                                           trainer=trainer)
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda engine: after_evaluation(engine))
        evaluator.add_event_handler(Events.COMPLETED, early_stop_handler)

        trainer.run(train_data_loader, max_epochs=args.epochs)
    else:
        '''If the current run is evaluation-only, generate a prediction JSON file'''
        evaluator.run(dev_data_loader)
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda engine: logger.info('Current evaluation accuracy is %.3f',
Example #27
def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # generate 40 random (image, mask) pairs in the temporary directory
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_3d(128,
                                       128,
                                       128,
                                       num_seg_classes=1,
                                       channel_dim=-1)

        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
    train_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[:20], segs[:20])]
    val_files = [{
        "img": img,
        "seg": seg
    } for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"],
                               label_key="seg",
                               spatial_size=[96, 96, 96],
                               pos=1,
                               neg=1,
                               num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])

    # define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              collate_fn=list_data_collate,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=5,
                            num_workers=8,
                            collate_fn=list_data_collate,
                            pin_memory=torch.cuda.is_available())

    # create UNet, DiceLoss and Adam optimizer
    net = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    loss = monai.losses.DiceLoss(sigmoid=True)
    lr = 1e-3
    opt = torch.optim.Adam(net.parameters(), lr)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration;
    # users can add an output_transform to return other values, such as y_pred and y.
    def prepare_batch(batch, device=None, non_blocking=False):
        return _prepare_batch((batch["img"], batch["seg"]), device,
                              non_blocking)

    trainer = create_supervised_trainer(net,
                                        opt,
                                        loss,
                                        device,
                                        False,
                                        prepare_batch=prepare_batch)

    # adding checkpoint handler to save models (network params and optimizer stats) during training
    checkpoint_handler = ModelCheckpoint("./runs_dict/",
                                         "net",
                                         n_saved=10,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "net": net,
                                  "opt": opt
                              })

    # StatsHandler prints the loss at every iteration and metrics at every epoch.
    # We don't attach metrics to the trainer here, so only the loss is printed; users can also
    # customize the print functions and use output_transform to convert engine.state.output
    # if it is not a loss value.
    train_stats_handler = StatsHandler(name="trainer")
    train_stats_handler.attach(trainer)

    # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler
    train_tensorboard_stats_handler = TensorBoardStatsHandler()
    train_tensorboard_stats_handler.attach(trainer)

    validation_every_n_iters = 5
    # set parameters for validation
    metric_name = "Mean_Dice"
    # add evaluation metric to the evaluator engine
    val_metrics = {metric_name: MeanDice(sigmoid=True, to_onehot_y=False)}

    # The Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration;
    # users can add an output_transform to return other values.
    evaluator = create_supervised_evaluator(net,
                                            val_metrics,
                                            device,
                                            True,
                                            prepare_batch=prepare_batch)

    @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters))
    def run_validation(engine):
        evaluator.run(val_loader)

    # add early stopping handler to evaluator
    early_stopper = EarlyStopping(
        patience=4,
        score_function=stopping_fn_from_metric(metric_name),
        trainer=trainer)
    evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                handler=early_stopper)

    # add stats event handler to print validation stats via evaluator
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x:
        None,  # no need to print loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.epoch,
    )  # fetch global epoch number from trainer
    val_stats_handler.attach(evaluator)

    # add handler to record metrics to TensorBoard at every validation epoch
    val_tensorboard_stats_handler = TensorBoardStatsHandler(
        output_transform=lambda x:
        None,  # no need to plot loss value, so disable per iteration output
        global_epoch_transform=lambda x: trainer.state.iteration,
    )  # fetch global iteration number from trainer
    val_tensorboard_stats_handler.attach(evaluator)

    # add handler to draw the first image and the corresponding label and model output in the last batch
    # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations.
    val_tensorboard_image_handler = TensorBoardImageHandler(
        batch_transform=lambda batch: (batch["img"], batch["seg"]),
        output_transform=lambda output: predict_segmentation(output[0]),
        global_iter_transform=lambda x: trainer.state.epoch,
    )
    evaluator.add_event_handler(event_name=Events.ITERATION_COMPLETED(every=2),
                                handler=val_tensorboard_image_handler)

    train_epochs = 5
    state = trainer.run(train_loader, train_epochs)
    print(state)
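# `stopping_fn_from_metric` is not defined in this snippet. Since EarlyStopping
# expects a larger-is-better score, a plausible implementation simply reads the
# named metric from the evaluator state (an assumption about the helper):
def stopping_fn_from_metric(metric_name):
    def stopping_fn(engine):
        return engine.state.metrics[metric_name]
    return stopping_fn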
Example #28
def train():
    learning_rate = 0.0001
    save_on_iter_count = 100
    device = "cuda"
    envs = [
        ObservationScaler(gym.make(name))
        for name in ("Breakout-v0", "Pong-v0", "AirRaid-v0")
    ]
    discriminator = Discriminator(img_size=64).to(device)
    generator = Generator().to(device)
    objective = nn.BCELoss()
    discr_optimizer = optim.Adam(params=discriminator.parameters(),
                                 lr=learning_rate,
                                 betas=(0.5, 0.999))
    gen_optimizer = optim.Adam(params=generator.parameters(),
                               lr=learning_rate,
                               betas=(0.5, 0.999))

    def process_batch(trainer, batch):
        batch_size = batch.shape[0]
        gen_input_size = 10

        # get labels and inputs
        generator_inputs = torch.randn(
            (batch_size, gen_input_size, 1, 1)).to(device)
        fake_inputs = generator(generator_inputs).to(device)
        true_inputs = batch.to(device)
        fake_image_labels = torch.zeros((batch_size, )).to(device)
        true_image_labels = torch.ones((batch_size, )).to(device)

        # train discriminator
        discr_optimizer.zero_grad()
        discr_fake_image_output = discriminator(fake_inputs.detach())
        discr_true_image_output = discriminator(true_inputs)

        discr_loss = objective(discr_fake_image_output,
                               fake_image_labels) + objective(
                                   discr_true_image_output, true_image_labels)

        discr_loss.backward()
        discr_optimizer.step()

        # train generator
        gen_optimizer.zero_grad()
        discr_output = discriminator(fake_inputs)
        gen_loss = objective(discr_output, true_image_labels)
        gen_loss.backward()
        gen_optimizer.step()

        # save images
        if trainer.state.iteration % save_on_iter_count == 0:
            fake_img = vutils.make_grid(fake_inputs.data[:64], normalize=True)
            trainer.tb.writer.add_image("fake", fake_img,
                                        trainer.state.iteration)
            real_img = vutils.make_grid(true_inputs.data[:64], normalize=True)
            trainer.tb.writer.add_image("real", real_img,
                                        trainer.state.iteration)
            trainer.tb.writer.flush()
        return discr_loss.item(), gen_loss.item()

    engine = Engine(process_batch)
    tb = tb_logger.TensorboardLogger(log_dir=None)
    engine.tb = tb
    RunningAverage(output_transform=lambda out: out[1]).attach(
        engine, "avg_loss_gen")
    RunningAverage(output_transform=lambda out: out[0]).attach(
        engine, "avg_loss_dis")

    handler = tb_logger.OutputHandler(
        tag="train", metric_names=["avg_loss_gen", "avg_loss_dis"])
    tb.attach(engine,
              log_handler=handler,
              event_name=Events.ITERATION_COMPLETED)

    @engine.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_loss(engine):
        print(f"Epoch[{engine.state.iteration}] Loss:", engine.state.output)

    engine.run(data=generate_batch(envs))
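# `generate_batch` is not defined in this snippet. The GAN above only needs an
# endless stream of observation batches from the gym environments, so a sketch
# could look like this (random actions, classic gym reset/step API, batch size
# and dtype are assumptions):
import numpy as np
import torch

def generate_batch(envs, batch_size=32):
    for env in envs:
        env.reset()
    batch = []
    env_index = 0
    while True:
        env = envs[env_index % len(envs)]
        obs, _, done, _ = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
        batch.append(np.asarray(obs, dtype=np.float32))
        env_index += 1
        if len(batch) == batch_size:
            yield torch.from_numpy(np.stack(batch))
            batch.clear()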
Example #29
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(BatchFiltered, self).__init__(
            started=Events.EPOCH_STARTED,
            completed=Events.EPOCH_COMPLETED,
            iteration_completed=Events.ITERATION_COMPLETED(*args, **kwargs),
        )
Example #30
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    vis = visdom.Visdom()

    # if not vis.check_connection():
    #     raise RuntimeError("Visdom server not running. Please run python -m visdom.server")

    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    train_loss_window = create_plot_window(vis, '#Iterations', 'Loss', 'Training Loss')
    train_avg_loss_window = create_plot_window(vis, '#Iterations', 'Loss', 'Training Average Loss')
    train_avg_accuracy_window = create_plot_window(vis, '#Iterations', 'Accuracy', 'Training Average Accuracy')
    val_avg_loss_window = create_plot_window(vis, '#Epochs', 'Loss', 'Validation Average Loss')
    val_avg_accuracy_window = create_plot_window(vis, '#Epochs', 'Accuracy', 'Validation Average Accuracy')

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, engine.state.iteration, len(train_loader), engine.state.output))
        vis.line(X=np.array([engine.state.iteration]),
                 Y=np.array([engine.state.output]),
                 update='append', win=train_loss_window)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))
        vis.line(X=np.array([engine.state.epoch]), Y=np.array([avg_accuracy]),
                 win=train_avg_accuracy_window, update='append')
        vis.line(X=np.array([engine.state.epoch]), Y=np.array([avg_nll]),
                 win=train_avg_loss_window, update='append')

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))
        vis.line(X=np.array([engine.state.epoch]), Y=np.array([avg_accuracy]),
                 win=val_avg_accuracy_window, update='append')
        vis.line(X=np.array([engine.state.epoch]), Y=np.array([avg_nll]),
                 win=val_avg_loss_window, update='append')

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
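# `create_plot_window` is not defined in this snippet. It is expected to create
# an empty visdom line plot that the handlers above append to; a matching sketch
# (the single-NaN seed point is an assumption):
import numpy as np

def create_plot_window(vis, xlabel, ylabel, title):
    return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                    opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))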