def add_progress_bar_eval(evaluator, validation_loader):
    """
    Attach running-average loss/accuracy metrics, a TQDM progress bar and an
    epoch timer to *evaluator* ("I can't believe it's not Keras").

    :param evaluator: ignite Engine whose step output looks like (loss, y_pred, y)
        -- assumed from the output_transforms below; TODO confirm against the step fn
    :param validation_loader: kept for interface compatibility (unused here)
    """
    validation_history = {'accuracy': [], 'loss': []}

    # Running averages over the engine output; x[0] is the loss scalar,
    # (x[0], x[1]) feeds the Accuracy metric.
    RunningAverage(output_transform=lambda x: x[0]).attach(evaluator, 'loss')
    RunningAverage(Accuracy(output_transform=lambda x: (x[0], x[1]))).attach(
        evaluator, 'accuracy')

    prog_bar = ProgressBar()
    prog_bar.attach(evaluator, ['accuracy'])

    from ignite.handlers import Timer

    # Wall-clock time of a full evaluation epoch (start/pause on epoch bounds).
    timer = Timer(average=True)
    timer.attach(evaluator,
                 start=Events.EPOCH_STARTED,
                 resume=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 step=Events.EPOCH_COMPLETED)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def log_validation_results(evaluator):
        metrics = evaluator.state.metrics
        accuracy = metrics['accuracy'] * 100
        # BUG FIX: the loss metric is attached under the name 'loss' above;
        # reading 'nll' raised KeyError every epoch.
        loss = metrics['loss']
        validation_history['accuracy'].append(accuracy)
        validation_history['loss'].append(loss)
        val_msg = "Valid Epoch {}:  acc: {:.2f}% loss: {:.2f}, eval time: {:.2f}s".format(
            evaluator.state.epoch, accuracy, loss, timer.value())
        prog_bar.log_message(val_msg)
# Example 2
    def __init__(self, num_iters=100, prepare_batch=None, device="cuda"):
        """Build a no-op dataflow engine that only uploads batches to *device*.

        :param num_iters: number of iterations to benchmark
        :param prepare_batch: callable(batch, device, non_blocking) moving a
            batch to *device*; the batch is consumed untouched if None
        :param device: target device for the uploads
        """
        from ignite.handlers import Timer

        def upload_to_gpu(engine, batch):
            # The transfer itself is the benchmarked work; the result is unused.
            if prepare_batch is not None:
                x, y = prepare_batch(batch, device=device, non_blocking=False)

        self.num_iters = num_iters
        self.benchmark_dataflow = Engine(upload_to_gpu)

        # Stop after exactly num_iters iterations.
        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
        def stop_benchmark_dataflow(engine):
            engine.terminate()

        # NOTE(review): assumes torch.distributed is initialized whenever it is
        # available — get_rank() raises otherwise; confirm against callers.
        if dist.is_available() and dist.get_rank() == 0:

            # BUG FIX: `every=num_iters // 100` evaluates to 0 when
            # num_iters < 100, which ignite rejects; clamp to at least 1.
            @self.benchmark_dataflow.on(
                Events.ITERATION_COMPLETED(every=max(num_iters // 100, 1)))
            def show_progress_benchmark_dataflow(engine):
                print(".", end=" ")

        # Accumulated (not averaged) per-iteration time.
        self.timer = Timer(average=False)
        self.timer.attach(self.benchmark_dataflow,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_STARTED,
                          pause=Events.ITERATION_COMPLETED,
                          step=Events.ITERATION_COMPLETED)
    def __init__(self, model, config, evaluator, data_loader, tb_writer,
                 run_info, logger, checkpoint_dir):
        """
        Build a trainer around an ignite Engine that is driven by ``self._step``.

        :param model: model to train. Needs to inherit from the BaseModel class.
        :param config: dictionary containing the whole configuration of the experiment
        :param evaluator: Instance of the evaluator class, used to run evaluation on a specified schedule
        :param data_loader: pytorch data loader providing the training data
        :param tb_writer: tensorboardX summary writer
        :param run_info: sacred run info for loging training progress
        :param logger: python logger object
        :param checkpoint_dir: directory path for storing checkpoints
        """
        # Plain attribute wiring; the Engine drives self._step per batch.
        self.model = model
        self.config = config
        self.train_cfg = config['train']
        self.evaluator = evaluator
        self.data_loader = data_loader
        self.tb_writer = tb_writer
        self.run_info = run_info
        self.logger = logger
        self.engine = Engine(self._step)

        self.pbar = ProgressBar(ascii=True, desc='* Epoch')
        self.timer = Timer(average=True)
        # Keeps the N most recent 'last' checkpoints, written every
        # save_interval epochs/iterations under checkpoint_dir.
        self.save_last_checkpoint_handler = ModelCheckpoint(
            checkpoint_dir,
            'last',
            save_interval=self.train_cfg['save_interval'],
            n_saved=self.train_cfg['save_n_last'],
            require_empty=False)

        self.add_handler()
# Example 4
def setup_timer(engine):
    """Attach and return an averaged per-iteration Timer for *engine*."""
    per_iter_timer = Timer(average=True)
    per_iter_timer.attach(
        engine,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
    )
    return per_iter_timer
# Example 5
    def attach_events(self, description, environment=None, save_file=None):
        """Attach a timer, console loss logging and (optionally) Visdom loss
        plotting to ``self.engine``.

        :param description: text used in the Visdom plot title
        :param environment: Visdom environment name; plotting is skipped if falsy
        :param save_file: unused; kept for interface compatibility
        """
        tim = Timer()
        tim.attach(
            self.engine,
            start=Events.STARTED,
            step=Events.ITERATION_COMPLETED,
        )

        log_interval = 100
        plot_interval = 10

        @self.engine.on(Events.ITERATION_COMPLETED)
        def print_training_loss(engine):
            # `iteration` rather than `iter` to avoid shadowing the builtin.
            iteration = engine.state.iteration - 1
            if iteration % log_interval == 0:
                print("Epoch[{}] Iteration: {} Time: {} Loss: {:.2f}".format(
                    engine.state.epoch, iteration,
                    str(datetime.timedelta(seconds=int(tim.value()))),
                    engine.state.output['loss']
                ))

        if environment:

            vis = visdom.Visdom(env=environment)

            def create_plot_window(vis, xlabel, ylabel, title):
                return vis.line(X=np.array([1]), Y=np.array([np.nan]),
                                opts=dict(xlabel=xlabel, ylabel=ylabel,
                                          title=title))

            train_loss_window = create_plot_window(
                vis, '#Iterations', 'Loss',
                'Training Loss {0}'.format(description))

            @self.engine.on(Events.ITERATION_COMPLETED)
            def plot_training_loss(engine):
                iteration = engine.state.iteration - 1
                if iteration % plot_interval == 0:
                    vis.line(X=np.array([engine.state.iteration]),
                             Y=np.array([engine.state.output['loss']]),
                             update='append',
                             win=train_loss_window)
# Example 6
def do_validate(cfg, model, val_loader):
    """Run one evaluation pass over *val_loader* and print per-iteration
    progress plus the mean running-average accuracy.

    :param cfg: config exposing MODEL.DEVICE and MODEL.CUDA
    :param model: model under evaluation
    :param val_loader: validation data loader
    """
    device = cfg.MODEL.DEVICE
    if device == "cuda":
        torch.cuda.set_device(cfg.MODEL.CUDA)
    evaluator = create_evaluator(model, device=device)
    # The engine output is fed straight into the running average — presumably
    # a per-batch accuracy scalar; TODO confirm against create_evaluator.
    RunningAverage(output_transform=lambda x: x).attach(
        evaluator, 'eva_avg_acc')

    timer = Timer(average=True)

    timer.attach(evaluator,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    acc_list = list()

    @evaluator.on(Events.ITERATION_COMPLETED)
    def log_accuracy(engine):
        # 1-based iteration index within the current epoch.
        iteration = (engine.state.iteration - 1) % len(val_loader) + 1
        print("Iteration[{}/{}]".format(iteration, len(val_loader)))
        acc_list.append(engine.state.metrics['eva_avg_acc'])

    evaluator.run(val_loader)
    # BUG FIX: '{:1%}' sets a field *width* of 1 (prints 6 decimal places);
    # '{:.1%}' — one decimal place, e.g. '93.4%' — was clearly intended.
    print("Validation Accuracy: {:.1%}".format(np.array(acc_list).mean()))
# Example 7
class DataflowBenchmark:
    """Measures pure dataflow cost: runs ``num_iters`` batches through a no-op
    engine that only moves each batch to the current device."""

    def __init__(self, num_iters=100, prepare_batch=None):
        """
        :param num_iters: number of iterations to benchmark
        :param prepare_batch: callable(batch, device, non_blocking) moving a
            batch to the device; the batch is consumed untouched if None
        """
        from ignite.handlers import Timer

        device = idist.device()

        def upload_to_gpu(engine, batch):
            # The transfer itself is the benchmarked work; result is discarded.
            if prepare_batch is not None:
                x, y = prepare_batch(batch, device=device, non_blocking=False)

        self.num_iters = num_iters
        self.benchmark_dataflow = Engine(upload_to_gpu)

        # Stop after exactly num_iters iterations.
        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
        def stop_benchmark_dataflow(engine):
            engine.terminate()

        if idist.get_rank() == 0:

            # BUG FIX: `every=num_iters // 100` is 0 for num_iters < 100,
            # which ignite rejects; clamp to at least 1.
            @self.benchmark_dataflow.on(
                Events.ITERATION_COMPLETED(every=max(num_iters // 100, 1)))
            def show_progress_benchmark_dataflow(engine):
                print(".", end=" ")

        # Accumulated (not averaged) per-iteration time.
        self.timer = Timer(average=False)
        self.timer.attach(
            self.benchmark_dataflow,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )

    def attach(self, trainer, train_loader):
        """Register a pre-training hook on *trainer* that benchmarks
        *train_loader* and prints throughput stats on rank 0."""
        from torch.utils.data import DataLoader

        @trainer.on(Events.STARTED)
        def run_benchmark(_):
            if idist.get_rank() == 0:
                print("-" * 50)
                print(" - Dataflow benchmark")

            self.benchmark_dataflow.run(train_loader)
            t = self.timer.value()

            if idist.get_rank() == 0:
                print(" ")
                print(" Total time ({} iterations) : {:.5f} seconds".format(
                    self.num_iters, t))
                print(" time per iteration         : {} seconds".format(
                    t / self.num_iters))

                if isinstance(train_loader, DataLoader):
                    num_images = train_loader.batch_size * self.num_iters
                    print(" number of images / s       : {}".format(
                        num_images / t))

                print("-" * 50)
    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 lr_scheduler=None,
                 metrics=None,
                 test_metrics=None,
                 save_path=".",
                 name="Net"):
        """Set up model/optimization state, metrics, a TensorBoard writer and
        an epoch timer.

        :param model: network to train
        :param criterion: loss function
        :param optimizer: optimizer instance
        :param lr_scheduler: optional LR scheduler
        :param metrics: training metrics dict; defaults to {}
        :param test_metrics: evaluation metrics; derived from *metrics* if None
        :param save_path: root directory for trainer artifacts and run logs
        :param name: experiment name used in log paths
        """
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.metrics = metrics or {}
        self.test_metrics = test_metrics
        if test_metrics is None:
            # BUG FIX: the original called `metrics.copy()` on the raw
            # parameter, which crashes with AttributeError when both `metrics`
            # and `test_metrics` are None; use the normalized dict instead.
            self.test_metrics = self.metrics.copy()
            if 'loss' in self.metrics and isinstance(self.metrics['loss'], TrainLoss):
                # Train loss is produced by the step itself; evaluation needs
                # a criterion-based Loss metric instead.
                self.test_metrics['loss'] = Loss(criterion=criterion)
        self.save_path = os.path.join(save_path, 'trainer')
        self.name = name

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = os.path.join(save_path, 'runs', self.name, current_time)
        self.writer = SummaryWriter(log_dir)

        self.metric_history = defaultdict(list)
        self.device = 'cuda' if CUDA else 'cpu'
        self._timer = Timer()
        self._epochs = 0

        self.model.to(self.device)
# Example 9
def decorate_trainer(trainer, evaluator, scheduler, trainloader, valloader,
                     sampler, workspace_dir, logfilename, logterm,
                     max_epochs, checkpointer, model):
    """Wire logging, timing, LR scheduling, validation and checkpointing
    handlers onto *trainer* (and *evaluator*); the engines are mutated in
    place and nothing is returned.

    :param logterm: log every `logterm` global iterations
    :param sampler: distributed sampler (set_epoch is called each epoch) or None
    :param checkpointer: handler invoked with a name -> object mapping
    """
    trainer.setup_logger('trainer', workspace_dir, get_rank(), logfilename)
    trainer.setup_metric_logger()
    evaluator.setup_logger('evaluator', workspace_dir, get_rank(), logfilename)

    # Average per-iteration time; also used below for elapsed/ETA estimates.
    timer = Timer(average=True)
    timer.attach(trainer, step=E.ITERATION_COMPLETED)

    @trainer.on(E.EPOCH_STARTED)
    def start_epoch(engine):
        epoch = engine.state.epoch
        engine.logger.info(f'start training epoch {epoch}')
        if sampler is not None:
            sampler.set_epoch(epoch)

    trainer.add_event_handler(E.ITERATION_COMPLETED, scheduler)

    @trainer.on(E.ITERATION_COMPLETED)
    def log_results(engine):
        engine.metric_logger.update(**engine.state.output)
        global_iter = engine.state.iteration
        if global_iter % logterm == 0:
            epoch = engine.state.epoch
            iter_per_epoch = len(trainloader)
            local_iter = global_iter - (engine.state.epoch-1) * iter_per_epoch
            iter_remain = max_epochs * iter_per_epoch - global_iter

            # NOTE(review): Timer._elapsed() is private ignite API — confirm
            # it survives ignite upgrades.
            elapse = timer._elapsed()
            elapse = str(datetime.timedelta(seconds=int(elapse)))

            # ETA = remaining iterations * average time per iteration.
            eta = iter_remain * timer.value()
            eta = str(datetime.timedelta(seconds=int(eta)))

            logstr = f'elapse: {elapse} eta: {eta} '
            logstr += f'Epoch [{epoch}/{max_epochs}] '
            logstr += f'[{local_iter}/{iter_per_epoch}] ({global_iter}) '
            logstr += f'lr: {engine.optimizer.param_groups[0]["lr"]:.3e} '
            logstr += str(engine.metric_logger)
            engine.logger.info(logstr)

    if evaluator is not None and valloader is not None:
        @trainer.on(E.EPOCH_COMPLETED)
        def validate(engine):
            epoch = engine.state.epoch
            evaluator.logger.info(f'start evaluation epoch {epoch}')
            evaluator.run()

            result = evaluator.state.metrics
            fmtstr = f'Epoch {epoch} validation result: '
            for k, v in result.items():
                fmtstr += f'{k}: {v:.4f} '
            evaluator.logger.info(fmtstr)

    if get_rank() == 0:
        # BUG FIX: checkpoint handlers take a name -> object *mapping*; the
        # original passed the set literal {'epoch', model}.
        trainer.add_event_handler(
            E.EPOCH_COMPLETED, checkpointer, {'epoch': model}
        )
# Example 10
def main(dataset_path, batch_size=256, max_epochs=10):
    """Train wide_resnet50_2 (100 classes) on CUDA and print train/test
    metrics once training completes.

    :param dataset_path: root path handed to the loader factory
    :param batch_size: training batch size
    :param max_epochs: number of training epochs
    """
    assert torch.cuda.is_available()
    assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
    # Fixed input sizes -> let cudnn benchmark and cache the fastest kernels.
    torch.backends.cudnn.benchmark = True

    device = "cuda"

    train_loader, test_loader, eval_train_loader = get_train_eval_loaders(
        dataset_path, batch_size=batch_size)

    model = wide_resnet50_2(num_classes=100).to(device)
    optimizer = SGD(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss().to(device)

    def train_step(engine, batch):
        # Standard supervised step: move batch, forward, backward, update.
        x = convert_tensor(batch[0], device, non_blocking=True)
        y = convert_tensor(batch[1], device, non_blocking=True)

        optimizer.zero_grad()

        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()

        optimizer.step()

        return loss.item()

    trainer = Engine(train_step)
    # Mean epoch duration (stepped once per completed epoch).
    timer = Timer(average=True)
    timer.attach(trainer, step=Events.EPOCH_COMPLETED)
    ProgressBar(persist=True).attach(
        trainer, output_transform=lambda out: {"batch loss": out})

    metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)

    def log_metrics(engine, title):
        for name in metrics:
            print("\t{} {}: {:.2f}".format(title, name,
                                           engine.state.metrics[name]))

    @trainer.on(Events.COMPLETED)
    def run_validation(_):
        # Runs once after training: evaluate on the train subset, then test.
        print("- Mean elapsed time for 1 epoch: {}".format(timer.value()))
        print("- Metrics:")
        # add_event_handler used as a context manager so the logging hook is
        # detached after each run (NOTE(review): needs a recent ignite).
        with evaluator.add_event_handler(Events.COMPLETED, log_metrics,
                                         "Train"):
            evaluator.run(eval_train_loader)

        with evaluator.add_event_handler(Events.COMPLETED, log_metrics,
                                         "Test"):
            evaluator.run(test_loader)

    trainer.run(train_loader, max_epochs=max_epochs)
# Example 11
def run(cfg, train_loader, tr_comp, saver, trainer, valid_dict):
    """Attach checkpointing, timing, running-average metrics and logging to
    *trainer*, then train for ``cfg.TRAIN.MAX_EPOCHS`` epochs.

    :param tr_comp: training components bundle (scheduler, losses, state_dict)
    :param saver: provides model_dir for checkpoints
    :param valid_dict: validation datasets passed to eval_multi_dataset
    """
    # TODO resume

    # trainer = Engine(...)
    # trainer.load_state_dict(state_dict)
    # trainer.run(data)
    # checkpoint
    # Keep the 3 most recent 'train' checkpoints of tr_comp's state.
    handler = ModelCheckpoint(saver.model_dir, 'train', n_saved=3, create_dir=True)
    checkpoint_params = tr_comp.state_dict()
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              handler,
                              checkpoint_params)

    # Average per-iteration time, restarted each epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    # average metric to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    @trainer.on(Events.EPOCH_COMPLETED)
    def adjust_learning_rate(engine):
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Base Lr: {tr_comp.scheduler.get_last_lr()[0]:.2e}, "

        for loss_name in engine.state.metrics.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "

        # Optional uncertainty-weighted cross-entropy diagnostic.
        if tr_comp.xent and tr_comp.xent.learning_weight:
            message += f"xentWeight: {tr_comp.xent.uncertainty.mean().item():.4f}, "

        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        # value() * step_count == total epoch time; value() == avg per batch.
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD))
    def log_validation_results(engine):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")
        eval_multi_dataset(cfg, valid_dict, tr_comp)

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)
# Example 12
 def _start_timer(self):
     """Create, attach and return an averaged per-iteration timer."""
     iter_timer = Timer(average=True)
     iter_timer.attach(
         self._trainer,
         start=Events.EPOCH_STARTED,
         resume=Events.ITERATION_STARTED,
         pause=Events.ITERATION_COMPLETED,
     )
     return iter_timer
# Example 13
def engine_eval_geomreg(cfg, mode):
    """Run geometric-registration evaluation for *mode* and return the log
    root path.

    :param cfg: experiment configuration (eval / general / log sections)
    :param mode: evaluation mode / dataset split name
    :return: cfg.log.root_path
    """
    prepare_config_eval(cfg)

    ckpt_path = cfg.eval.general.ckpt_path
    gpu = cfg.general.gpu
    root_path = cfg.log.root_path
    seed = cfg.general.seed

    # Mirror stdout into a mode-specific log file under root_path.
    eu.redirect_stdout(root_path, 'eval_geomreg-{}'.format(mode))
    eu.print_config(cfg)

    eu.seed_random(seed)

    device = eu.get_device(gpu)

    dataloader = get_dataloader_eval_geomreg(cfg, mode)
    num_batches = len(dataloader)

    # Both models run in eval mode and load weights from the same checkpoint.
    render_model, desc_model = get_models(cfg)
    render_model.to(device)
    render_model.eval_mode()
    render_model.print_params('render_model')
    desc_model.to(device)
    desc_model.eval_mode()
    desc_model.print_params('desc_model')

    assert eu.is_not_empty(ckpt_path)
    render_model.load(ckpt_path)
    desc_model.load(ckpt_path)

    engine = Engine(
        functools.partial(step_eval_geomreg,
                          render_model=render_model,
                          desc_model=desc_model,
                          device=device,
                          cfg=cfg))

    # Average per-iteration time over the single evaluation epoch.
    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)

    engine.add_event_handler(Events.ITERATION_COMPLETED,
                             eu.print_eval_log,
                             timer=timer,
                             num_batches=num_batches)

    engine.add_event_handler(Events.EXCEPTION_RAISED, eu.handle_exception)

    # Single pass over the evaluation data.
    engine.run(dataloader, 1)

    return root_path
# Example 14
def timer_metric(engine, name='timer'):
    """Expose an averaged per-iteration Timer as ``engine.state.metrics[name]``."""
    avg_timer = Timer(average=True)
    avg_timer.attach(
        engine,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    def record_timer(_engine):
        # Publish the running average so other handlers can read it.
        _engine.state.metrics[name] = avg_timer.value()

    engine.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                             handler=record_timer)
    def create_trainer(self):
        """Build the training Engine, wiring metric/checkpoint handlers, a
        progress bar and an averaged per-iteration timer."""
        self.trainer = Engine(self.train_step)
        # Register the per-epoch / per-iteration hooks in a single pass.
        for event, hook, *extra in (
                (Events.EPOCH_COMPLETED, self.K_step),
                (Events.EPOCH_COMPLETED, self.log_metrics, 'train'),
                (Events.ITERATION_COMPLETED, self.write_metrics, 'train'),
        ):
            self.trainer.add_event_handler(event, hook, *extra)

        self.pbar = ProgressBar()

        self.timer = Timer(average=True)
        self.timer.attach(
            self.trainer,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self.print_times)
# Example 16
 def _create_timer(self):
     """
     Attach an averaged per-iteration timer to the trainer engine.
     :return: the newly created, already-attached timer
     :type: ignite.handlers.Timer
     """
     iter_timer = Timer(average=True)
     iter_timer.attach(
         self.trainer_engine,
         start=Events.EPOCH_STARTED,
         resume=Events.ITERATION_STARTED,
         pause=Events.ITERATION_COMPLETED,
         step=Events.ITERATION_COMPLETED,
     )
     return iter_timer
def attach_evaluator_events(evaluator, experiment_dir, data_set: str):
    """Wire timing plus per-iteration and end-of-run handlers onto an
    evaluation engine."""
    # Per-iteration timer: restarted at every iteration.
    iteration_timer = Timer()
    iteration_timer.attach(
        evaluator,
        start=Events.ITERATION_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # Epoch timer: accumulates iteration time over the whole epoch.
    epoch_timer = Timer()
    epoch_timer.attach(
        evaluator,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # Screen log after every iteration.
    evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                log_eval_iter_screen,
                                timer=iteration_timer)

    # Final results dump once the evaluation run completes.
    eval_dir = osp.join(experiment_dir, 'inference_results')
    evaluator.add_event_handler(Events.COMPLETED,
                                evaluate_model_without_training,
                                eval_dir=eval_dir,
                                timer=epoch_timer,
                                data_set=data_set)
# Example 18
def visdom_loss_handler(modules_dict, model_name):
    """
    Attaches plots and metrics to trainer.
    This handler creates or connects to an environment on a running Visdom dashboard and creates a line plot that tracks the loss function of a
    training loop as a function of the number of iterations. This can be attached to an Ignite Engine, and the training closure must
    have 'loss' as one of the keys in its return dict for this plot to be made.
    See documentation for Ignite (https://github.com/pytorch/ignite) and Visdom (https://github.com/facebookresearch/visdom) for more information.

    NOTE(review): `trainer`, `environment` and `description` are free
    variables here — they must be defined at module level for this function
    to run; confirm against the module this was extracted from.
    """

    tim = Timer()
    tim.attach(
        trainer,
        start=Events.STARTED,
        step=Events.ITERATION_COMPLETED,
    )

    vis = visdom.Visdom(env=environment)

    def create_plot_window(vis, xlabel, ylabel, title):
        return vis.line(X=np.array([1]),
                        Y=np.array([np.nan]),
                        opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

    train_loss_window = create_plot_window(vis, '#Iterations', 'Loss',
                                           description)
    log_interval = 10

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # NOTE(review): unlike the docstring's claim, the raw engine output
        # (not output['loss']) is printed/plotted — confirm the step returns
        # a scalar loss.
        iter = (engine.state.iteration - 1)
        if iter % log_interval == 0:
            print("Epoch[{}] Iteration: {} Time: {} Loss: {:.2f}".format(
                engine.state.epoch, iter,
                str(datetime.timedelta(seconds=int(tim.value()))),
                engine.state.output))
        vis.line(X=np.array([engine.state.iteration]),
                 Y=np.array([engine.state.output]),
                 update='append',
                 win=train_loss_window)

    save_interval = 50
    # Persist `modules_dict` every `save_interval` iterations, keeping 5.
    handler = ModelCheckpoint('/tmp/models',
                              model_name,
                              save_interval=save_interval,
                              n_saved=5,
                              create_dir=True,
                              require_empty=False)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, handler,
                              modules_dict)
# Example 19
    def __init__(self,
                 G,
                 D,
                 criterionG,
                 criterionD,
                 optimizerG,
                 optimizerD,
                 lr_schedulerG=None,
                 lr_schedulerD=None,
                 make_latent=None,
                 metrics=None,
                 save_path=".",
                 name="GAN",
                 gan_type='gan'):
        """Store GAN components and bookkeeping state, move both networks to
        the active device and pick the trainer factory for *gan_type*."""
        self.G = G
        self.D = D
        self.criterionG = criterionG
        self.criterionD = criterionD
        self.optimizerG = optimizerG
        self.optimizerD = optimizerD
        self.lr_schedulerG = lr_schedulerG
        self.lr_schedulerD = lr_schedulerD
        self.make_latent = make_latent
        self.metrics = metrics or {}
        self.name = name
        root = Path(save_path).expanduser().absolute()
        self.save_path = root / 'gan_trainer' / self.name

        self.metric_history = defaultdict(list)
        self.device = 'cuda' if CUDA else 'cpu'
        self._timer = Timer()
        self._iterations = 0

        self.G.to(self.device)
        self.D.to(self.device)

        # Dispatch table replaces the if/elif chain; the assert keeps the
        # original failure mode for unknown gan types.
        assert gan_type in ['gan', 'acgan', 'cgan', 'infogan']
        self.create_fn = {
            'gan': create_gan_trainer,
            'acgan': create_acgan_trainer,
            'cgan': create_cgan_trainer,
            'infogan': create_infogan_trainer,
        }[gan_type]
# Example 20
    def __init__(self,
                 model,
                 criterion,
                 optimizer,
                 lr_scheduler=None,
                 metrics=None,
                 test_metrics=None,
                 save_path=".",
                 name="Net",
                 fp16=False,
                 lr_step_on_iter=None):
        """Set up (optionally mixed-precision) model/optimizer state, metrics
        and a TensorBoard writer.

        :param metrics: training metrics dict; defaults to {}
        :param test_metrics: evaluation metrics; derived from *metrics* if None
        :param fp16: wrap model/optimizer with apex amp (opt level O1)
        :param lr_step_on_iter: step scheduler per iteration instead of epoch
        """
        self.fp16 = fp16
        self.device = 'cuda' if CUDA else 'cpu'
        model.to(self.device)
        if self.fp16:
            # Lazy import keeps apex an optional dependency.
            from apex import amp
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level="O1",
                                              verbosity=0)

        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.metrics = metrics or {}
        self.test_metrics = test_metrics
        if test_metrics is None:
            # BUG FIX: the original called `metrics.copy()` on the raw
            # parameter, crashing when both `metrics` and `test_metrics` are
            # None; use the normalized dict instead.
            self.test_metrics = self.metrics.copy()
            if 'loss' in self.metrics and isinstance(self.metrics['loss'], TrainLoss):
                # Evaluation needs a criterion-based Loss metric, not the
                # step-produced train loss.
                self.test_metrics['loss'] = Loss(criterion=criterion)
        self.save_path = os.path.join(save_path, 'trainer')
        self.name = name
        self.lr_step_on_iter = lr_step_on_iter

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = os.path.join(save_path, 'runs', self.name, current_time)
        self.writer = SummaryWriter(log_dir)

        self.metric_history = defaultdict(list)
        self._timer = Timer()
        self._epochs = 0

        self._verbose = True
# Example 21
def warp_common_handler(engine, option, networks_to_save, monitoring_metrics,
                        add_message, use_folder_pathes):
    """Attach the shared handler set (progress bar, timer, checkpoints, plots,
    logging, directory setup) to *engine* and return it.

    :param engine: ignite Engine to decorate
    :param option: options object (output_dir, save_interval, n_saved, epochs,
        print_freq)
    :param networks_to_save: name -> module mapping for checkpointing
    :param monitoring_metrics: metric names shown in the progress bar
    :param add_message: extra text used by the log printer
    :param use_folder_pathes: subdirectories created under output_dir at start
    :return: the same engine, handlers attached
    """
    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(engine, metric_names=monitoring_metrics)
    # Average per-iteration time, restarted each epoch.
    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    create_plots = make_handle_create_plots(option.output_dir, LOGS_FNAME,
                                            PLOT_FNAME)
    checkpoint_handler = ModelCheckpoint(option.output_dir,
                                         CKPT_PREFIX,
                                         save_interval=option.save_interval,
                                         n_saved=option.n_saved,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=True)

    engine.add_event_handler(Events.ITERATION_COMPLETED,
                             checkpoint_handler,
                             to_save=networks_to_save)
    engine.add_event_handler(Events.ITERATION_COMPLETED, create_plots)
    # Save a checkpoint and flush plots if anything raises mid-training.
    engine.add_event_handler(
        Events.EXCEPTION_RAISED,
        make_handle_handle_exception(checkpoint_handler, networks_to_save,
                                     create_plots))
    engine.add_event_handler(
        Events.STARTED,
        make_handle_make_dirs(option.output_dir, use_folder_pathes))
    engine.add_event_handler(Events.STARTED, make_move_html(option.output_dir))
    engine.add_event_handler(Events.STARTED, make_create_option_data(option))
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             make_handle_print_times(timer, pbar))
    engine.add_event_handler(
        Events.ITERATION_COMPLETED,
        make_handle_print_logs(option.output_dir, option.epochs,
                               option.print_freq, pbar, add_message))
    return engine
    def __init__(self):
        """Create the phase timers, unset result slots, and the parallel
        event -> first-hook / last-hook registration tables."""
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None

        # One (event, first-position hook, last-position hook) triple per
        # profiled engine event; unzipped into the three parallel lists below.
        wiring = (
            (Events.EPOCH_STARTED,
             self._as_first_epoch_started, self._as_last_epoch_started),
            (Events.EPOCH_COMPLETED,
             self._as_first_epoch_completed, self._as_last_epoch_completed),
            (Events.ITERATION_STARTED,
             self._as_first_iter_started, self._as_last_iter_started),
            (Events.ITERATION_COMPLETED,
             self._as_first_iter_completed, self._as_last_iter_completed),
            (Events.GET_BATCH_STARTED,
             self._as_first_get_batch_started, self._as_last_get_batch_started),
            (Events.GET_BATCH_COMPLETED,
             self._as_first_get_batch_completed,
             self._as_last_get_batch_completed),
            (Events.COMPLETED,
             self._as_first_completed, self._as_last_completed),
        )
        self._events = [event for event, _, _ in wiring]
        self._fmethods = [first for _, first, _ in wiring]
        self._lmethods = [last for _, _, last in wiring]
# Example 23
    def __init__(self):
        """Create the three phase timers; recorded times start as None."""
        self._dataflow_timer, self._processing_timer, self._event_handlers_timer = (
            Timer(), Timer(), Timer())

        # Filled in while profiling runs.
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None
# Example 24
    def __init__(self) -> None:
        """Create the three phase timers and empty per-phase accumulators."""
        self._dataflow_timer, self._processing_timer, self._event_handlers_timer = (
            Timer() for _ in range(3))

        # Accumulators populated while profiling runs.
        self.dataflow_times = []  # type: List[float]
        self.processing_times = []  # type: List[float]
        self.event_handlers_times = {}  # type: Dict[EventEnum, Dict[str, List[float]]]
# Example 25
    def __init__(self,
                 G,
                 D,
                 criterionG,
                 criterionD,
                 optimizerG,
                 optimizerD,
                 lr_schedulerG=None,
                 lr_schedulerD=None,
                 metrics=None,
                 device=None,
                 save_path=".",
                 name="Net"):
        """Store GAN training components and move both networks to *device*.

        :param metrics: metrics dict; defaults to a fresh {} per instance
        :param device: target device; auto-detected from CUDA availability
            when None
        """
        self.G = G
        self.D = D
        self.criterionG = criterionG
        self.criterionD = criterionD
        self.optimizerG = optimizerG
        self.optimizerD = optimizerD
        self.lr_schedulerG = lr_schedulerG
        self.lr_schedulerD = lr_schedulerD
        # BUG FIX: the original signature used the mutable default
        # `metrics={}` and stored it directly, so every instance built
        # without `metrics` aliased (and could mutate) the same shared dict.
        self.metrics = {} if metrics is None else metrics
        self.device = device or ('cuda'
                                 if torch.cuda.is_available() else 'cpu')
        self.save_path = save_path
        self.name = name

        self.metric_history = defaultdict(list)
        self._print_callbacks = set([lambda msg: print(msg, end='')])
        self._weixin_logined = False
        self._timer = Timer()
        self._epochs = 0

        self.G.to(self.device)
        self.D.to(self.device)
# Example 26
    def __init__(self) -> None:
        """Create phase timers, zeroed per-phase tensors and the parallel
        event -> first-hook / last-hook registration tables."""
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

        self.dataflow_times = torch.zeros(1)
        self.processing_times = torch.zeros(1)
        self.event_handlers_times = {}  # type: Dict[EventEnum, torch.Tensor]

        # One (event, first-position hook, last-position hook) triple per
        # profiled engine event; unzipped into the three parallel lists below.
        wiring = (
            (Events.EPOCH_STARTED,
             self._as_first_epoch_started, self._as_last_epoch_started),
            (Events.EPOCH_COMPLETED,
             self._as_first_epoch_completed, self._as_last_epoch_completed),
            (Events.ITERATION_STARTED,
             self._as_first_iter_started, self._as_last_iter_started),
            (Events.ITERATION_COMPLETED,
             self._as_first_iter_completed, self._as_last_iter_completed),
            (Events.GET_BATCH_STARTED,
             self._as_first_get_batch_started, self._as_last_get_batch_started),
            (Events.GET_BATCH_COMPLETED,
             self._as_first_get_batch_completed,
             self._as_last_get_batch_completed),
            (Events.COMPLETED,
             self._as_first_completed, self._as_last_completed),
        )
        self._events = [event for event, _, _ in wiring]
        self._fmethods = [first for _, first, _ in wiring]
        self._lmethods = [last for _, _, last in wiring]
# Example 27
# loss function
loss_fn = nn.CrossEntropyLoss()

# optimizer
# NOTE(review): `model`, `init_lr`, `end_lr`, `trainloader` and `device` are
# defined earlier in the original script — not visible in this fragment.
optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=0.9, weight_decay=5e-4)

# scheduler
# Cosine annealing with 4*len(trainloader)-iteration cycles, cycle length
# growing 1.5x and peak value decaying 10x per cycle, preceded by a
# len(trainloader)-iteration linear warmup from 0 to init_lr.
scheduler = CosineAnnealingScheduler(optimizer, 'lr', init_lr, end_lr, 4*len(trainloader), cycle_mult=1.5, start_value_mult=0.1)
scheduler = create_lr_scheduler_with_warmup(scheduler, warmup_start_value=0., warmup_end_value=init_lr, warmup_duration=len(trainloader))

# create trainer
trainer = create_trainer(model, optimizer, loss_fn, device=device)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# add timer for each iteration
# NOTE(review): this timer is never attach()ed — it measures wall time since
# construction / the last reset(), which is what the per-100-iteration log uses.
timer = Timer(average=False)

# logging training loss
def log_loss(engine):
    i = engine.state.iteration
    e = engine.state.epoch

    # Every 100 iterations: print progress and restart the interval timer.
    if i % 100 == 0:
        print('[Iters {:0>7d}/{:0>2d}, {:.2f}s/100 iters, lr={:.4E}] loss={:.4f}'.format(i, e, timer.value(), optimizer.param_groups[0]['lr'], engine.state.output))
        timer.reset()
trainer.add_event_handler(Events.ITERATION_COMPLETED, log_loss)

# Evaluation
metrics = {
    'loss': Loss(loss_fn),
    'acc': Accuracy()
Ejemplo n.º 28
0
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet,
         jac_reg_lambda, affine_eps, no_warm_up, optim_name, clamp, svd_every,
         eval_only, no_actnorm, affine_scale_eps, actnorm_max_scale,
         no_conv_actnorm, affine_max_scale, actnorm_eps, init_sample, no_split,
         disc_arch, weight_entropy_reg, db):
    """Adversarially train a Glow flow model (flow-GAN style).

    Builds the train/test loaders, a Glow generator and a discriminator
    (selected by ``disc_arch``), then runs an ignite ``Engine`` whose step
    alternates a GAN discriminator/generator update with flow likelihood
    terms (prior, logdet), optional entropy regularization on samples and
    optional Jacobian regularization. Periodically saves sample grids,
    checkpoints, FID / bits-per-dim / reconstruction-PAD metrics (every
    ``eval_every`` iterations, to ``iteration_log.csv``) and Jacobian SVD
    condition numbers (every ``svd_every`` iterations, to ``svd_log.csv``).
    An evaluator computes test losses at the end of every epoch.

    NOTE(review): relies on module-level ``device`` and ``args`` globals —
    ``cuda``, ``fresh``, ``gan``, ``flowgan``, ``warmup_steps`` and
    ``ld_on_samples`` parameters are currently unused here; confirm intent.
    """

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)
    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)

    model = model.to(device)

    # Discriminator architecture is selectable; all variants score images.
    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale,
            flow_permutation, flow_coupling, LU_decomposed, num_classes,
            learn_top, y_condition, logittransform, sn, affine_eps, no_actnorm,
            affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
            affine_max_scale, actnorm_eps, no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                    discriminator.parameters()),
                             lr=disc_lr,
                             betas=(.5, .99),
                             weight_decay=0)
    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               betas=(.5, .99),
                               weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        # Linear lr ramp over the first `warmup` epochs (stepped per epoch).
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir,
                                                       'iteration_log.csv'))
    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    # Fixed real batches for FID and reconstruction diagnostics.
    test_iter = test_loader.__iter__()
    N_inception = 1000
    x_real_inception = torch.cat([
        test_iter.__next__()[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = test_iter.__next__()[0].to(device)

    def gan_step(engine, batch):
        """One training iteration: D update, G (GAN) loss, flow losses,
        optional regularizers, then periodic diagnostics/logging."""
        assert not y_condition
        # Manual global-iteration counter stored on the engine; starts at -1
        # on the first call, so it equals (number of completed calls - 1).
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            # Dequantize inputs before scoring so D sees the same noise
            # distribution the flow is trained on.
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            # ---- Discriminator update (non-saturating GAN + R1-style GP) ----
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(F.sigmoid(D_real_scores)) ==
                ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(F.sigmoid(D_fake_scores)) ==
                zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(
                D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(
                D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(
                x.detach(), fake.detach(),
                lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model,
                                       x.size(0),
                                       clamp=clamp,
                                       guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        # Flow likelihood terms on the real batch.
        z, nll, y_logits, (prior, logdet) = model.forward(x,
                                                          None,
                                                          return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior,
                      sample_logdet) = model.forward(fake,
                                                     None,
                                                     return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() +
                                                sample_logdet.mean())
        # Jac Reg: penalize forward and inverse Jacobian norms on both
        # samples and data to encourage a well-conditioned flow.
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model,
                                            args.batch_size,
                                            clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples,
                                                             all_z,
                                                             n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1)
                for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z,
                           y_onehot=None,
                           temperature=1,
                           reverse=True,
                           batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(
                all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac )
            loss = loss + jac_reg_lambda * (sample_foward_jac +
                                            sample_inverse_jac +
                                            data_foward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(),
                                                max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)

            # Replace NaN gradient with 0 (g != g is the NaN test)
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0

            optimizer.step()

        # Periodic sample grid, titled by whether the samples contain NaNs.
        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]),
                             nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(
                os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:

            # True for 1, 20, 300, 4000, ... (only the leading digit nonzero);
            # used to checkpoint at sparse, roughly-logarithmic iterations.
            def check_all_zero_except_leading(x):
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(
                    engine.iter_ind):
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()

            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon',
                                     f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size,
                                        clamp=clamp).detach(), fpath)
                print(
                    f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    raise
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1),
                                  sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None,
                                  return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        # BUGFIX: this was `engine.iter_ind + 2 % svd_every == 0`, which by
        # operator precedence is `iter_ind + (2 % svd_every) == 0` and (for
        # svd_every > 2) can never be true, so the SVD diagnostics — and the
        # `eval_only` exit below — never ran. Parenthesized to fire every
        # `svd_every` iterations (offset by 2).
        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()
            if eval_only:
                sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        """Evaluator step: per-sample (reduction='none') test losses."""
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         'glow',
                                         save_interval=5,
                                         n_saved=1,
                                         require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x:
         (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(
             evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(
            trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x:
             (x['nll'], torch.empty(x['nll'].shape[0]))).attach(
                 evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model  = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded  = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        # Epoch index is recovered from the checkpoint filename suffix.
        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        """Data-dependent actnorm initialization (skipped when resuming)."""
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []

        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                generate_from_noise(model,
                                    args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None,
                                            n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)

                init_batches = torch.cat(init_batches).to(device)

                assert init_batches.shape[0] == n_init_batches * batch_size

                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None

                # One forward pass initializes the actnorm parameters.
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        """Run the evaluator on the test set and step the warmup scheduler."""
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    # Average per-iteration wall time, reported (and reset) once per epoch.
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]'
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Ejemplo n.º 29
0
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch):
    """Run the re-id training loop with ignite.

    Wires a supervised trainer and an R1/mAP evaluator, attaches epoch
    checkpointing, a per-batch timer, running-average loss/accuracy
    metrics and logging handlers, then trains for cfg.SOLVER.MAX_EPOCHS.

    Handlers registered on the same event fire in registration order, so
    the statement order below is significant.

    NOTE(review): no value is returned; results are reported via logging
    and checkpoints on disk.
    """
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        cfg=cfg,
                                        device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query,
                             max_rank=50,
                             feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   cfg.MODEL.NAME,
                                   checkpoint_period,
                                   n_saved=epochs,
                                   require_empty=False,
                                   start_iter=start_epoch)
    # average=True: timer.value() is the mean time per step (per batch).
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model,
        'optimizer': optimizer
    })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    # trainer output is assumed to be a (loss, accuracy) pair.
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        # Support resuming: shift the epoch counter to start_epoch.
        engine.state.epoch = start_epoch
        engine.state.total_iteration = 0

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()
        # Restart the iteration counter each epoch.
        engine.state.iteration = 0

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch.
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        # avg-time * step_count = total epoch time; batch_size / avg-time = throughput.
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @evaluator.on(Events.ITERATION_COMPLETED)
    def log_evaluate_extract_features(engine):
        iter = (engine.state.iteration - 1) % len(val_loader) + 1
        if iter % log_period == 0:
            logger.info("Extract Features Iteration[{}/{}]".format(
                iter, len(val_loader)))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Evaluate every eval_period epochs, and every epoch after epoch 120.
        if engine.state.epoch % eval_period == 0 or engine.state.epoch > 120:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
Ejemplo n.º 30
0
def run(args):
    train_loader, val_loader = get_data_loaders(args.dir, args.batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = CityscapesDataset.num_instance_classes() + 1
    model = models.box2pix(num_classes=num_classes)
    model.init_from_googlenet()

    writer = create_summary_writer(model, train_loader, args.log_dir)

    if torch.cuda.device_count() > 1:
        print("Using %d GPU(s)" % torch.cuda.device_count())
        model = nn.DataParallel(model)

    model = model.to(device)

    semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
    offsets_criterion = nn.MSELoss()
    box_criterion = BoxLoss(num_classes, gamma=2)
    multitask_criterion = MultiTaskLoss().to(device)

    box_coder = BoxCoder()
    optimizer = optim.Adam([{
        'params': model.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': multitask_criterion.parameters()
    }],
                           lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            multitask_criterion.load_state_dict(checkpoint['multitask'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        x, instance, boxes, labels = batch

        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(instance,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(boxes, device=device,
                               non_blocking=non_blocking),
                convert_tensor(labels,
                               device=device,
                               non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, instance, boxes, labels = _prepare_batch(batch)
        boxes, labels = box_coder.encode(boxes, labels)

        loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)

        semantics_loss = semantics_criterion(semantics_pred, instance)
        offsets_loss = offsets_criterion(offsets_pred, instance)
        box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                            labels)

        loss = multitask_criterion(semantics_loss, offsets_loss, box_loss,
                                   conf_loss)

        loss.backward()
        optimizer.step()

        return {
            'loss': loss.item(),
            'loss_semantics': semantics_loss.item(),
            'loss_offsets': offsets_loss.item(),
            'loss_ssdbox': box_loss.item(),
            'loss_ssdclass': conf_loss.item()
        }

    trainer = Engine(_update)

    checkpoint_handler = ModelCheckpoint(args.output_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=False)
    timer = Timer(average=True)

    # attach running average metrics
    train_metrics = [
        'loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox',
        'loss_ssdclass'
    ]
    for m in train_metrics:
        transform = partial(lambda x, metric: x[metric], metric=m)
        RunningAverage(output_transform=transform).attach(trainer, m)

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=train_metrics)

    checkpoint = {
        'model': model.state_dict(),
        'epoch': trainer.state.epoch,
        'optimizer': optimizer.state_dict(),
        'multitask': multitask_criterion.state_dict()
    }
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={'checkpoint': checkpoint})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, instance, boxes, labels = _prepare_batch(batch)
            loc_preds, conf_preds, semantics, offsets_pred = model(x)
            boxes_preds, labels_preds, scores_preds = box_coder.decode(
                loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01)

            semantics_loss = semantics_criterion(semantics, instance)
            offsets_loss = offsets_criterion(offsets_pred, instance)
            box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                                labels)

            semantics_pred = semantics.argmax(dim=1)
            instances = helper.assign_pix2box(semantics_pred, offsets_pred,
                                              boxes_preds, labels_preds)

        return {
            'loss': (semantics_loss, offsets_loss, {
                'box_loss': box_loss,
                'conf_loss': conf_loss
            }),
            'objects':
            (boxes_preds, labels_preds, scores_preds, boxes, labels),
            'semantics':
            semantics_pred,
            'instances':
            instances
        }

    train_evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             train_evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              train_evaluator, 'semantics')

    # Evaluator used on the *validation* set; identical metric wiring to the
    # training evaluator above so the two sets of numbers are comparable.
    evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              evaluator, 'semantics')

    @trainer.on(Events.STARTED)
    def initialize(engine):
        """When resuming from a checkpoint, fast-forward the epoch counter."""
        if not args.resume:
            return
        engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        """Report the average per-batch time for the epoch that just ended."""
        msg = "Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(
            engine.state.epoch, engine.state.max_epochs, timer.value())
        pbar.log_message(msg)
        # Restart the timer so the next epoch's average starts from zero.
        timer.reset()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        """Write the training loss to TensorBoard every ``log_interval`` batches."""
        # engine.state.iteration counts globally; map it back to the
        # 1-based position within the current epoch.
        step_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if step_in_epoch % args.log_interval != 0:
            return
        writer.add_scalar("training/loss", engine.state.output['loss'],
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        """Evaluate on the training set after each epoch and log the metrics.

        ``engine`` is the trainer; its state supplies the epoch counters.
        """
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        # Fixed: the format arguments were out of order (loss was passed to
        # the epoch placeholder) and read the nonexistent attribute
        # ``evaluator.state.epochs``; the epoch counters come from the
        # trainer engine, and values follow in placeholder order.
        pbar.log_message(
            'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("train-val/loss", loss, engine.state.epoch)
        writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("train-val/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        """Evaluate on the validation set after each epoch and log the metrics.

        ``engine`` is the trainer; its state supplies the epoch counters.
        """
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        # Fixed: the format arguments were out of order (loss was passed to
        # the epoch placeholder) and ``evaluator.state.epochs`` does not
        # exist (would raise AttributeError); use the trainer engine's
        # epoch counters and pass values in placeholder order.
        pbar.log_message(
            'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("validation/loss", loss, engine.state.epoch)
        writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("validation/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        """Checkpoint on Ctrl-C (after the first iteration); re-raise anything else."""
        graceful = isinstance(e, KeyboardInterrupt) and engine.state.iteration > 1
        if not graceful:
            raise e
        engine.terminate()
        warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
        # Save the interrupted model so the run can be resumed later.
        checkpoint_handler(engine, {'model_exception': model})

    @trainer.on(Events.COMPLETED)
    def save_final_model(engine):
        """Persist the model weights once training has finished."""
        to_save = {'final': model}
        checkpoint_handler(engine, to_save)

    # Kick off training; close the writer afterwards so any buffered
    # TensorBoard events are flushed to disk.
    trainer.run(train_loader, max_epochs=args.epochs)
    writer.close()