Exemple #1
0
def test_lr_scheduler_asserts():
    """Non-scheduler arguments must be rejected with ``TypeError``.

    Exercised through both the ``LRScheduler`` constructor and
    ``LRScheduler.simulate_values``.
    """
    not_a_scheduler = 123
    with pytest.raises(TypeError):
        LRScheduler(not_a_scheduler)

    with pytest.raises(TypeError):
        LRScheduler.simulate_values(1, None)
Exemple #2
0
def test_lr_scheduler_asserts():
    """Check the argument validation of ``LRScheduler``.

    A torch scheduler whose optimizer has two param groups must be rejected
    with ``ValueError`` by both the constructor and ``simulate_values``
    (presumably because multiple param groups are unsupported -- confirm
    against LRScheduler). A non-scheduler argument must raise ``TypeError``.
    """
    # One zero tensor per param group, created in the same order as before.
    groups = [
        {"params": torch.zeros([1], requires_grad=True), "lr": 0.1}
        for _ in range(2)
    ]
    optimizer = torch.optim.SGD(groups)
    exp_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.98)

    with pytest.raises(ValueError):
        LRScheduler(exp_scheduler)

    with pytest.raises(ValueError):
        LRScheduler.simulate_values(num_events=100, lr_scheduler=exp_scheduler)

    with pytest.raises(TypeError):
        LRScheduler(123)
Exemple #3
0
def test_lr_scheduler_asserts():
    """Non-scheduler arguments must raise ``TypeError`` with a specific message.

    Both the constructor and ``simulate_values`` are checked against the same
    message pattern.
    """
    # Used as a regex by pytest.raises(match=...); the unescaped dots match any
    # character, which is harmless here.
    expected_message = (
        r"Argument lr_scheduler should be a subclass of "
        r"torch.optim.lr_scheduler._LRScheduler"
    )

    with pytest.raises(TypeError, match=expected_message):
        LRScheduler(123)

    with pytest.raises(TypeError, match=expected_message):
        LRScheduler.simulate_values(1, None)
Exemple #4
0
    def _test(torch_lr_scheduler_cls, **kwargs):
        """Compare ignite's ``LRScheduler`` with a hand-stepped torch scheduler.

        Two SGD optimizers share one parameter. The first is driven by an
        ``LRScheduler`` wrapping one scheduler instance (attached on
        ``ITERATION_STARTED``). The second is driven by another instance that
        is stepped manually on ``ITERATION_COMPLETED``. The learning rates
        recorded at each iteration start must agree with each other, and with
        ``LRScheduler.simulate_values`` run on a third, fresh scheduler.

        Args:
            torch_lr_scheduler_cls: torch scheduler class under test.
            **kwargs: extra constructor arguments for that class.
        """
        shared_param = torch.zeros([1], requires_grad=True)
        wrapped_opt = torch.optim.SGD([shared_param], lr=0.01)
        reference_opt = torch.optim.SGD([shared_param], lr=0.01)

        wrapped_sched = torch_lr_scheduler_cls(optimizer=wrapped_opt, **kwargs)
        reference_sched = torch_lr_scheduler_cls(optimizer=reference_opt, **kwargs)
        scheduler = LRScheduler(wrapped_sched)

        lrs, lrs_true = [], []

        def dummy_update(engine, batch):
            wrapped_opt.step()
            reference_opt.step()

        trainer = Engine(dummy_update)

        # Keep this registration order: both recorders are registered on
        # ITERATION_STARTED before the ignite scheduler (added further down).
        @trainer.on(Events.ITERATION_STARTED)
        def record_wrapped_lr(engine):
            lrs.append(wrapped_opt.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_STARTED)
        def record_reference_lr(engine):
            lrs_true.append(reference_opt.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_COMPLETED)
        def step_reference(engine):
            reference_sched.step()

        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        data = [0] * 10
        max_epochs = 2
        trainer.run(data, max_epochs=max_epochs)

        assert lrs_true == pytest.approx(lrs)

        fresh_opt = torch.optim.SGD([shared_param], lr=0.01)
        fresh_sched = torch_lr_scheduler_cls(optimizer=fresh_opt, **kwargs)

        num_events = len(data) * max_epochs
        simulated = LRScheduler.simulate_values(num_events=num_events, lr_scheduler=fresh_sched)
        assert lrs == pytest.approx([value for _, value in simulated])
Exemple #5
0
    def run(self):
        """Load checkpoints, set up training extras, and run the first loop.

        In ``l.TRAIN`` mode this also:

        - binds checkpoint saving when ``self.checkpoint_dir`` is set;
        - opens a ``SummaryWriter`` on ``self.log_dir`` when it is set;
        - when the experiment config contains ``LR_SCHEDULE``, attaches a
          ``MultiStepLR`` (milestones ``LR_SCHEDULE``, factor ``LR_DECAY``) to
          the first loop's engine on ``EPOCH_COMPLETED``.

        ``self.writer`` is closed after the loop finishes, including when the
        loop raises.

        Returns:
            Whatever ``self.loops[0].run`` returns.
        """
        self.load_checkpoints()

        if self.mode == l.TRAIN:
            if self.checkpoint_dir is not None:
                self.bind_checkpoints()

            if self.log_dir is not None:
                self.writer = SummaryWriter(logdir=self.log_dir)

            lr_schedule = self.experiment_config.get(LR_SCHEDULE)
            if lr_schedule is not None:
                lr_decay = self.experiment_config[LR_DECAY]
                lr_scheduler = MultiStepLR(self.optimizer_chain.optimizer,
                                           lr_schedule, lr_decay)

                self.loops[0].engine.add_event_handler(
                    Events.EPOCH_COMPLETED, LRScheduler(lr_scheduler))

        try:
            return self.loops[0].run(self.experiment_config[NUM_EPOCHS],
                                     self.experiment_config[RETURN_OUTPUT])
        finally:
            # Previously the writer was only closed on success, so a crashed or
            # interrupted run left it open.
            if self.writer is not None:
                self.writer.close()
def attach_exponential_decay_lr(trainer, optimizer, config):
    """Attach a stepwise learning-rate decay to ``trainer``.

    The decay is a ``StepLR`` with ``step_size=config['steps_per_epoch']`` and
    ``gamma=config['lr_decay']``, wrapped in an ignite ``LRScheduler``.
    Because it is stepped per iteration, the lr presumably decays once per
    epoch (assuming ``steps_per_epoch`` matches the epoch length).

    When ``config['warmup_lr']`` is truthy, the scheduler is handed to
    ``attach_lr_warmup``. Otherwise it is attached directly on
    ``ITERATION_STARTED``.
    """
    step_decay = StepLR(optimizer, config['steps_per_epoch'], gamma=config['lr_decay'])
    scheduler = LRScheduler(step_decay)

    if not config['warmup_lr']:
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
        return
    attach_lr_warmup(trainer, config, scheduler)
Exemple #7
0
    def _run(
        self,
        trainer: Engine,
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
        end_lr: float,
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ):
        """Prepare ``trainer`` for an LR range test.

        The recorded history and divergence state are reset first. Next,
        ``num_iter`` is resolved: ``None`` means every iteration of the planned
        run, and a value larger than that triggers a ``UserWarning``. Finally
        three handlers are registered on ``ITERATION_COMPLETED``, each only if
        ``trainer`` does not already have it:

        1. ``self._reached_num_iterations`` (given ``num_iter``),
        2. ``self._log_lr_and_loss`` (given ``output_transform``,
           ``smooth_f``, ``diverge_th``),
        3. the lr schedule stored in ``self._lr_schedule``.

        Args:
            trainer: engine whose ``state`` holds ``epoch_length`` and
                ``max_epochs`` for the planned run.
            optimizer: optimizer whose lr is scheduled.
            output_transform: forwarded to ``self._log_lr_and_loss``.
            num_iter: iterations to sweep over, or ``None`` for the whole run.
            end_lr: learning rate reached at the end of the sweep.
            step_mode: ``"exp"`` (case-insensitive) selects an exponential
                schedule; any other value a linear ramp from the optimizer's
                current lr (param group 0).
            smooth_f: forwarded to ``self._log_lr_and_loss``.
            diverge_th: forwarded to ``self._log_lr_and_loss``.
        """
        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        planned_iters = trainer.state.epoch_length * trainer.state.max_epochs
        if num_iter is None:
            num_iter = planned_iters
        elif num_iter > planned_iters:
            warnings.warn(
                "Desired num_iter {} is unreachable with the current run setup of {} iteration "
                "({} epochs)".format(num_iter, planned_iters, trainer.state.max_epochs),
                UserWarning,
            )

        # Stop-condition check first, then lr/loss logging -- same order as before.
        for handler, extra_args in (
            (self._reached_num_iterations, (num_iter,)),
            (self._log_lr_and_loss, (output_transform, smooth_f, diverge_th)),
        ):
            if not trainer.has_event_handler(handler):
                trainer.add_event_handler(Events.ITERATION_COMPLETED, handler, *extra_args)

        self.logger.debug("Running LR finder for {} iterations".format(num_iter))

        if step_mode.lower() == "exp":
            schedule = LRScheduler(_ExponentialLR(optimizer, end_lr, num_iter))
        else:
            initial_lr = optimizer.param_groups[0]["lr"]
            schedule = PiecewiseLinear(
                optimizer,
                param_name="lr",
                milestones_values=[(0, initial_lr), (num_iter, end_lr)],
            )
        self._lr_schedule = schedule

        # NOTE(review): num_iter is forwarded as an extra positional argument to
        # the schedule's __call__; what it binds to depends on that signature.
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)
Exemple #8
0
def test_scheduler_with_param_groups():
    """Each scheduler must drive both param groups of one optimizer identically.

    Every scheduler is run twice (restoring its initial state dict in between)
    on an engine with 10 iterations x 20 epochs. After every iteration, the lr
    of group 0 must match the lr of group 1.
    """

    def _test(lr_scheduler, optimizer):
        num_iterations = 10
        max_epochs = 20

        initial_state = lr_scheduler.state_dict()

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr():
            groups = optimizer.param_groups
            lrs.append((groups[0]["lr"], groups[1]["lr"]))

        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

        data = [0] * num_iterations

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)
            group0 = [first for first, _ in lrs]
            group1 = [second for _, second in lrs]
            assert group0 == pytest.approx(group1)
            lr_scheduler.load_state_dict(initial_state)

    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

    # Factories, so each scheduler is constructed only after the previous test
    # has finished (construction may touch the shared optimizer).
    scheduler_factories = [
        lambda: LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=10),
        lambda: PiecewiseLinear(
            optimizer, "lr", milestones_values=[(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0), (40, 0.5)]
        ),
        lambda: CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10),
        lambda: LRScheduler(ExponentialLR(optimizer, gamma=0.98)),
        lambda: LRScheduler(StepLR(optimizer, step_size=50, gamma=0.5)),
    ]
    for make_scheduler in scheduler_factories:
        _test(make_scheduler(), optimizer)
Exemple #9
0
    def _test(torch_lr_scheduler_cls, **kwargs):
        """Check LRScheduler against a hand-stepped torch scheduler across restarts.

        Two SGD optimizers share one parameter. The first is driven by an
        ``LRScheduler`` (attached on ``ITERATION_COMPLETED``). The second is
        driven by another scheduler instance stepped manually on
        ``ITERATION_COMPLETED``. The run is repeated twice, restoring all
        optimizer and scheduler state dicts in between, and the lrs recorded at
        each iteration start must agree. The last run must also match
        ``LRScheduler.simulate_values`` on a fresh scheduler.

        Args:
            torch_lr_scheduler_cls: torch scheduler class under test.
            **kwargs: extra constructor arguments for that class.
        """
        tensor = torch.zeros([1], requires_grad=True)
        optimizer1 = torch.optim.SGD([tensor], lr=0.01)
        optimizer2 = torch.optim.SGD([tensor], lr=0.01)
        opt_state_dict1 = optimizer1.state_dict()
        opt_state_dict2 = optimizer2.state_dict()

        torch_lr_scheduler1 = torch_lr_scheduler_cls(optimizer=optimizer1, **kwargs)
        scheduler = LRScheduler(torch_lr_scheduler1)
        # Snapshot for restoring between runs (the old debug print is removed).
        state_dict1 = scheduler.state_dict()

        torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
        state_dict2 = torch_lr_scheduler2.state_dict()

        def dummy_update(engine, batch):
            optimizer1.step()
            optimizer2.step()

        trainer = Engine(dummy_update)

        @trainer.on(Events.ITERATION_STARTED)
        def save_lr(engine):
            lrs.append(optimizer1.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_STARTED)
        def save_true_lr(engine):
            lrs_true.append(optimizer2.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_COMPLETED)
        def torch_lr_scheduler_step(engine):
            torch_lr_scheduler2.step()

        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

        data = [0] * 10
        max_epochs = 2
        for run_index in range(2):
            lrs = []
            lrs_true = []
            trainer.run(data, max_epochs=max_epochs)
            assert lrs_true == pytest.approx(lrs), "{}: {} ({}) vs {} ({})".format(
                run_index, lrs_true, len(lrs_true), lrs, len(lrs)
            )
            # Restore everything so the next run starts from the same point.
            optimizer1.load_state_dict(opt_state_dict1)
            scheduler.load_state_dict(state_dict1)
            optimizer2.load_state_dict(opt_state_dict2)
            torch_lr_scheduler2.load_state_dict(state_dict2)

        optimizer3 = torch.optim.SGD([tensor], lr=0.01)
        torch_lr_scheduler3 = torch_lr_scheduler_cls(optimizer=optimizer3, **kwargs)

        simulated_values = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs, lr_scheduler=torch_lr_scheduler3
        )
        assert lrs == pytest.approx([v for _, v in simulated_values])
def run(tb, vb, lr, epochs, writer):
  """Train an EfficientNet-B4 classifier with ignite and log its progress.

  Args:
    tb: forwarded to ``get_dataloaders`` (presumably the training batch size).
    vb: forwarded to ``get_dataloaders`` (presumably the validation batch size).
    lr: initial SGD learning rate.
    epochs: number of training epochs. Also used as ``T_max`` of the cosine
      annealing schedule, which is stepped at the start of every epoch.
    writer: object exposing ``add_text``/``add_scalars`` (e.g. a TensorBoard
      ``SummaryWriter``); it is not closed here.

  Reads the ``main-device``, ``run-id`` and ``savedir`` environment variables.
  The progress bar is closed even if training raises.
  """
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)
  # Inverse-count class weights normalised to sum to 1 (assuming num_of_images
  # holds per-class image counts).
  # Alternative weighting: (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  weights = weights.to(device=device)

  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b4', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)

  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
  ignite_scheduler = LRScheduler(scheduler)

  # ------------------------------------
  # 4. Define metrics
  train_metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(nn.CrossEntropyLoss(weight=weights)),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(weight=weights), device=device)

  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)

  desc = 'ITERATION - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0)
  )

  # ------------------------------------
  # 7. Create event hooks

  # Update the progress bar after each iteration.
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    # 1-based index of this iteration within the current epoch
    # (renamed from `iter`, which shadowed the builtin).
    batch_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
    if batch_in_epoch % log_interval == 0:
      pbar.desc = desc.format(engine.state.output)
      pbar.update(log_interval)

  # Reset the progress bar at the end of each epoch.
  @trainer.on(Events.EPOCH_COMPLETED)
  def refresh_pbar(engine):
    print('Epoch {} completed!'.format(engine.state.epoch))
    pbar.refresh()
    pbar.n = pbar.last_print_n = 0

  # Evaluate on the training set; wired below to fire every 200 epochs.
  def log_training_results(engine):
    pbar.refresh()
    print('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Id: {}
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(os.environ['run-id'],engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)

  # Checkpoint the model at each epoch end (save_interval=10, n_saved=2).
  save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=10, n_saved=2)
  trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})

  cpe = CustomPeriodicEvent(n_epochs=200)
  cpe.attach(trainer)
  trainer.add_event_handler(cpe.Events.EPOCHS_200_COMPLETED, log_training_results)

  # Step the learning-rate scheduler at the start of every epoch.
  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

  # ------------------------------------
  # Run; close the progress bar even if training raises.
  try:
    trainer.run(train_loader, max_epochs=epochs)
  finally:
    pbar.close()
def test_lr_scheduler():
    """End-to-end checks for ignite's ``LRScheduler`` wrapper.

    For StepLR and ExponentialLR, ``LRScheduler`` must reproduce the lrs of
    a manually stepped torch scheduler over two state-dict-restored runs, and
    must match ``simulate_values``. Additionally, ``_replicate_lr_scheduler``
    must return a copy that is unaffected by stepping the original, and must
    reject non-schedulers with ``TypeError``.
    """

    def _test(torch_lr_scheduler_cls, **kwargs):
        shared_param = torch.zeros([1], requires_grad=True)
        opt_wrapped = torch.optim.SGD([shared_param], lr=0.01)
        opt_reference = torch.optim.SGD([shared_param], lr=0.01)
        opt_wrapped_init = opt_wrapped.state_dict()
        opt_reference_init = opt_reference.state_dict()

        scheduler = LRScheduler(torch_lr_scheduler_cls(optimizer=opt_wrapped, **kwargs))
        scheduler_init = scheduler.state_dict()

        reference = torch_lr_scheduler_cls(optimizer=opt_reference, **kwargs)
        reference_init = reference.state_dict()

        def dummy_update(engine, batch):
            opt_wrapped.step()
            opt_reference.step()

        trainer = Engine(dummy_update)

        # Registration order is kept identical: the two recorders on
        # ITERATION_STARTED, the manual step, then the ignite scheduler.
        @trainer.on(Events.ITERATION_STARTED)
        def record_wrapped_lr(engine):
            lrs.append(opt_wrapped.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_STARTED)
        def record_reference_lr(engine):
            lrs_true.append(opt_reference.param_groups[0]["lr"])

        @trainer.on(Events.ITERATION_COMPLETED)
        def step_reference(engine):
            reference.step()

        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

        data = [0] * 10
        max_epochs = 2
        for attempt in range(2):
            lrs = []
            lrs_true = []
            trainer.run(data, max_epochs=max_epochs)
            assert lrs_true == pytest.approx(lrs), "{}: {} ({}) vs {} ({})".format(
                attempt, lrs_true, len(lrs_true), lrs, len(lrs)
            )
            opt_wrapped.load_state_dict(opt_wrapped_init)
            scheduler.load_state_dict(scheduler_init)
            opt_reference.load_state_dict(opt_reference_init)
            reference.load_state_dict(reference_init)

        fresh_opt = torch.optim.SGD([shared_param], lr=0.01)
        fresh_sched = torch_lr_scheduler_cls(optimizer=fresh_opt, **kwargs)

        simulated = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs, lr_scheduler=fresh_sched
        )
        assert lrs == pytest.approx([lr_value for _, lr_value in simulated])

    for scheduler_cls, ctor_kwargs in (
        (torch.optim.lr_scheduler.StepLR, {"step_size": 5, "gamma": 0.5}),
        (torch.optim.lr_scheduler.ExponentialLR, {"gamma": 0.78}),
    ):
        _test(scheduler_cls, **ctor_kwargs)

    # _replicate_lr_scheduler: the copy must keep the original's initial state.
    param = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.01)
    original = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.78)
    initial_state = dict(original.state_dict())
    replica = LRScheduler._replicate_lr_scheduler(original)
    for _ in range(10):
        optimizer.step()
        original.step()

    assert replica.state_dict() == initial_state

    with pytest.raises(TypeError):
        LRScheduler._replicate_lr_scheduler(12)
Exemple #12
0
def run(tb, vb, lr, epochs, writer):
  """Jointly train a target and a support EfficientNet-B3 with ignite.

  Each step pairs a batch from ``train_loader`` (known classes) with a batch
  from ``support_train_loader``. The loss is the support cross-entropy plus
  the class-weighted target cross-entropy plus the target membership loss.
  After every epoch the target model is evaluated on the training and
  validation sets, and the support model on the support validation set.
  Results go to ``writer`` and the log.

  Args:
    tb: forwarded to ``get_dataloaders`` (presumably the training batch size).
    vb: forwarded to ``get_dataloaders`` (presumably the validation batch size).
    lr: initial SGD learning rate for both models.
    epochs: number of training epochs.
    writer: object exposing ``add_text``/``add_scalars`` (e.g. a TensorBoard
      ``SummaryWriter``); it is not closed here.

  Reads the ``main-device`` and ``run-id`` environment variables. The
  progress bar is closed even if training raises.
  """
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping, support_train_loader, support_val_loader = get_dataloaders(tb, vb)
  # Inverse-count class weights normalised to sum to 1 (assuming num_of_images
  # holds per-class image counts).
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  weights = weights.to(device=device)

  # Pair each known-class batch with a support batch. zip stops when
  # train_loader is exhausted; cycle() repeats support batches if the support
  # loader is shorter (cycle keeps a copy of every item it yields).
  it = iter(train_loader)
  sup_it = iter(support_train_loader)
  mixup_batches = zip(it, cycle(sup_it))

  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  support_model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['supportset-info']['num-of-classes'])
  support_model = carrier(support_model)

  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  ignite_scheduler = LRScheduler(scheduler)

  support_optimizer = optim.SGD(support_model.parameters(), lr=lr, momentum=0.9)
  support_scheduler = optim.lr_scheduler.CosineAnnealingLR(support_optimizer, T_max=200)
  support_ignite_scheduler = LRScheduler(support_scheduler)

  # ------------------------------------
  # 4. Define metrics
  train_metrics = {
    'accuracy': Accuracy(),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  support_metrics = {
    'accuracy': Accuracy(),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), support_val_loader.dataset.classes)
  }

  class DeepTransPrediction(metric.Metric):
    """Accumulate thresholded predictions and labels for open-set evaluation.

    A sample is predicted as its arg-max class when that class's softmax
    probability exceeds the per-class threshold, and as -1 otherwise.
    Predictions are then translated through ``mapping``. ``compute`` returns
    the accumulated ``(predictions, labels)`` tensors.
    """

    def __init__(self, threshold=None):
      # Default built per instance (was a tensor shared by all instances via
      # the default argument): 0.5 for every training class.
      if threshold is None:
        threshold = torch.tensor([0.5]).repeat(len(train_loader.dataset.classes))
      # Stored before Metric.__init__, which may call reset() (reset reads it).
      self._initial_threshold = threshold.to(device=device)
      super(DeepTransPrediction, self).__init__()
      self.threshold = self._initial_threshold
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)

    def reset(self):
      # Restore the configured threshold. Previously this hard-coded 0.5, which
      # discarded a custom `threshold` on every reset.
      self.threshold = self._initial_threshold
      self.prediction = torch.tensor([])
      self.y = torch.tensor([])
      super(DeepTransPrediction, self).reset()

    def update(self, output):
      y_pred, y = output
      # Numerically stable equivalent of exp(x) / exp(x).sum(1).
      softmax = torch.softmax(y_pred, dim=1)
      values, inds = softmax.max(1)
      prediction = torch.where(values>self.threshold[inds], inds, torch.tensor([-1]).to(device=device))
      self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))

    def compute(self):
      return self.prediction, self.y

  val_metrics = {
    'accuracy': MetricsLambda(Labels2Acc, DeepTransPrediction()),
    'precision_recall': MetricsLambda(Labels2PrecisionRecall, DeepTransPrediction(), val_loader.dataset.classes),
    'cmatrix': MetricsLambda(Labels2CMatrix, DeepTransPrediction(), val_loader.dataset.classes)
  }

  # ------------------------------------
  # 5. Create trainer

  def membership_loss(input, target, weights):
    """Class-weighted membership loss on per-class sigmoid scores.

    With s = sigmoid(input), C classes and lambda = 5:
      sum_i w[y_i] * (1 - s[i, y_i])^2
        + lambda / (C - 1) * sum_i sum_{j != y_i} w[j] * s[i, j]^2
    """
    _lambda = 5
    classes = input.shape[1]
    rows = torch.arange(input.shape[0], device=input.device)
    sigmoid = torch.sigmoid(input)
    part1 = 1 - sigmoid[rows, target]
    part1 = (part1 * part1 * weights[target]).sum()
    # Mask out only the target-class entries. The previous
    # `sigmoid[range(n)] = 0` cleared whole rows, so part2 was always 0.
    non_target = torch.ones_like(sigmoid)
    non_target[rows, target] = 0
    part2 = (sigmoid * sigmoid * non_target * weights).sum()
    return part1 + _lambda*float(1/(classes-1))*part2

  def step(engine, batch):
    model.train()
    support_model.train()

    _alpha1 = 1
    _alpha2 = 1

    known, support = batch
    x_known, y_known = known
    x_support, y_support = support

    x_known = x_known.to(device=device)
    y_known = y_known.to(device=device)
    x_support = x_support.to(device=device)
    y_support = y_support.to(device=device)

    support_scores = support_model(x_support)
    support_cross_entropy = nn.functional.cross_entropy(support_scores, y_support)

    known_scores = model(x_known)
    known_cross_entropy = nn.functional.cross_entropy(known_scores, y_known, weights)
    known_membership = membership_loss(known_scores, y_known, weights)

    loss = support_cross_entropy + known_cross_entropy * _alpha1 + known_membership * _alpha2

    model.zero_grad()
    support_model.zero_grad()

    loss.backward()

    optimizer.step()
    support_optimizer.step()

    return {
      'Rce_loss': support_cross_entropy.item(),
      'Tce_loss': known_cross_entropy.item(),
      'Tm_loss': known_membership.item(),
      'total_loss': loss.item()
    }

  trainer = Engine(step)

  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
  support_evaluator = create_supervised_evaluator(support_model, metrics=support_metrics, device=device)

  desc = 'ITERATION - loss: {:.2f}|{:.2f}|{:.2f}|{:.2f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0,0,0,0)
  )

  # ------------------------------------
  # 7. Create event hooks

  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    # 1-based index within the epoch (renamed from `iter`, which shadowed the builtin).
    batch_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
    if batch_in_epoch % log_interval == 0:
      o = engine.state.output
      pbar.desc = desc.format(o['Rce_loss'], o['Tce_loss'], o['Tm_loss'], o['total_loss'])
      pbar.update(log_interval)

  # Re-pair fresh iterators at the start of every epoch.
  @trainer.on(Events.EPOCH_STARTED)
  def rebuild_dataloader(engine):
    pbar.clear()
    print('Rebuild dataloader!')
    it = iter(train_loader)
    sup_it = iter(support_train_loader)
    engine.state.dataloader = zip(it, cycle(sup_it))

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_training_results(engine):
    print('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Id: {}
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      
      precision_recall: \n{}
      
      confusion matrix: \n{}
      """.format(os.environ['run-id'], engine.state.epoch,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_support_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - *')
    print('Checking on support set.')
    support_evaluator.run(support_val_loader)
    metrics = support_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']

    prompt = """
    Id: {}
    Support set Results - Epoch: {}
    Avg accuracy: {:.4f}
    precision_recall: \n{}
    """.format(os.environ['run-id'], engine.state.epoch, avg_accuracy, precision_recall['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Support/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Support/Score', {'Avg precision': precision_recall['data'][0, -1], 'Avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_validation_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - *')
    print('Checking on validation set.')
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    unknown = precision_recall['pretty']['UNKNOWN']
    print(unknown)
    prompt = """
      Id: {}
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      Unknown precision: {:.4f}
      Unknown recall: {:.4f}
      
      precision_recall: \n{}
      
      confusion matrix: \n{}
      """.format(os.environ['run-id'], engine.state.epoch,avg_accuracy, unknown['Precision'], unknown['Recall'],precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg Precision': precision_recall['data'][0, -1], 'Val avg Recall': precision_recall['data'][1, -1]}, engine.state.epoch)
    writer.add_scalars('Unknown/Score', {'Unknown Precision': unknown['Precision'], 'Unknown Recall': unknown['Recall']}, engine.state.epoch)
    pbar.n = pbar.last_print_n = 0

  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)
  trainer.add_event_handler(Events.EPOCH_STARTED, support_ignite_scheduler)

  # ------------------------------------
  # Run; close the progress bar even if training raises.
  try:
    trainer.run(mixup_batches, max_epochs=epochs)
  finally:
    pbar.close()
Exemple #13
0
def run(tb, vb, lr, epochs, writer):
  """Train an EfficientNet-b3 classifier and evaluate it with entropy-based rejection.

  Args:
    tb, vb: forwarded to ``get_dataloaders`` (presumably the train / val
      batch sizes -- confirm against ``get_dataloaders``).
    lr: initial SGD learning rate (annealed by CosineAnnealingLR, T_max=200).
    epochs: number of epochs passed to ``trainer.run``.
    writer: TensorBoard-style writer exposing ``add_text`` / ``add_scalars``.

  Reads the ``main-device``, ``run-id`` and ``savedir`` environment variables.
  Training and validation reports are produced every 50 epochs; the model is
  checkpointed by ``ModelCheckpoint`` with ``save_interval=10``.
  """
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)
  # Inverse-frequency class weights, normalised so they sum to 1.
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  weights = weights.to(device=device)

  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)

  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  ignite_scheduler = LRScheduler(scheduler)

  # ------------------------------------
  # 4. Define metrics

  class CrossEntropyLoss_forceavg(nn.Module):
    """Weighted cross-entropy plus a penalty on the squared non-target logits.

    The penalty is the sum over the batch of every non-target logit squared,
    divided by the batch size.
    """
    def __init__(self, weight=None):
      super(CrossEntropyLoss_forceavg, self).__init__()
      # Honour the ``weight=None`` default instead of crashing on ``None.to``.
      self.class_weights = weight.to(device=device) if weight is not None else None

    def forward(self, input, target):
      ce = nn.functional.cross_entropy(input, target, weight=self.class_weights)
      rows = range(input.shape[0])
      non_target_sq = (input ** 2).sum() - (input[rows, target] ** 2).sum()
      return ce + non_target_sq/input.shape[0]

  class EntropyPrediction(metric.Metric):
    """Accumulate (prediction, label) pairs, rejecting high-entropy predictions.

    A sample's prediction is its argmax class when the softmax entropy
    (normalised by log(num_classes)) is below ``threshold``, otherwise -1.
    The result is passed through ``mapping`` (which must therefore accept -1).
    ``compute`` returns the accumulated ``(prediction, y)`` tensors.
    """
    def __init__(self, threshold=0.5):
      super(EntropyPrediction, self).__init__()
      self.threshold = threshold
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)

    def reset(self):
      self.prediction = torch.tensor([])
      self.y = torch.tensor([])
      super(EntropyPrediction, self).reset()

    def update(self, output):
      y_pred, y = output
      # log_softmax is numerically stable. The former exp()/sum(exp()) overflowed
      # to inf/inf = NaN for large logits, and log(p) * p gave NaN for p == 0;
      # a NaN entropy silently turned the sample into a rejected (-1) prediction.
      log_probs = torch.log_softmax(y_pred, dim=1)
      probs = log_probs.exp()
      entropy_base = math.log(y_pred.shape[1])
      # Entropy normalised to [0, 1] by the maximum possible entropy log(C).
      entropy = (-probs * log_probs).sum(1)/entropy_base
      _, inds = probs.max(1)
      prediction = torch.where(entropy<self.threshold, inds, torch.tensor([-1]).to(device=device))
      # One host transfer instead of a per-element .item() call.
      mapped = torch.tensor([mapping[x] for x in prediction.tolist()]).to(device=device)
      self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), mapped))
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))

    def compute(self):
      return self.prediction, self.y

  train_metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(CrossEntropyLoss_forceavg(weight=weights)),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  val_metrics = {
    'accuracy': MetricsLambda(Labels2Acc, EntropyPrediction(1.0)),
    'precision_recall': MetricsLambda(Labels2PrecisionRecall, EntropyPrediction(1.0), val_loader.dataset.classes),
    'cmatrix': MetricsLambda(Labels2CMatrix, EntropyPrediction(1.0), val_loader.dataset.classes)
  }

  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, CrossEntropyLoss_forceavg(weight=weights), device=device)

  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

  desc = 'ITERATION - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0)
  )

  # ------------------------------------
  # 7. Create event hooks

  # Update progress bar on each iteration completed.
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    iter_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
    if iter_in_epoch % log_interval == 0:
      pbar.desc = desc.format(engine.state.output)
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_STARTED)
  def refresh_pbar(engine):
    # NOTE(review): this runs at EPOCH_STARTED, so the number printed is the
    # epoch about to start, not the one that just finished.
    print('Finish epoch {}'.format(engine.state.epoch))
    pbar.refresh()
    pbar.n = pbar.last_print_n = 0

  # A single periodic-event source drives both reports below; handlers fire in
  # registration order (training report first, then validation).
  cpe = CustomPeriodicEvent(n_epochs=50)
  cpe.attach(trainer)

  # Compute metrics on train data every 50 epochs.
  @trainer.on(cpe.Events.EPOCHS_50_COMPLETED)
  def log_training_results(engine):
    print ('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)

  # Compute metrics on val data every 50 epochs.
  @trainer.on(cpe.Events.EPOCHS_50_COMPLETED)
  def log_validation_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - * - *')
    print ('Checking on validation set.')
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg precision': precision_recall['data'][0, -1], 'Val avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)

  # Save the model every 10 epochs, keeping the 2 most recent checkpoints.
  save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=10, n_saved=2)
  trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})

  # Step the learning-rate scheduler at the start of every epoch.
  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
def main():
    """Train (or, with ``--eval``, only evaluate) a model described by a config file.

    Command-line arguments select the config file, output directory, number of
    data-loading workers, batch size and TensorBoard log dir. ``--restart``
    resumes the most recent run directory for the config's tag, and
    ``--checkpoint`` loads a specific checkpoint file.

    Datasets, network, criterion and metrics come from
    ``task_factory(conf['task'])(conf)``. Training uses Adam with a StepLR
    schedule stepped at every epoch start. The ``CheckpointManager`` handler
    is invoked at the end of every validation run.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch training script for SUN397 dataset')
    parser.add_argument('conf_file')
    parser.add_argument('output_dir', help='Model save directory')
    parser.add_argument('-w',
                        '--workers',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-b',
                        '--batch-size',
                        default=64,
                        type=int,
                        metavar='N',
                        help='mini-batch size')
    parser.add_argument('-T',
                        '--tensor-board-dir',
                        help='Tensor board log dir',
                        default='runs')
    parser.add_argument('--restart',
                        help='Restart',
                        default=False,
                        action='store_true')
    parser.add_argument('--checkpoint', help='checkpoint file')
    # The help text was a copy-paste of --checkpoint's; this flag skips training.
    parser.add_argument('--eval',
                        default=False,
                        action='store_true',
                        help='only evaluate on the validation set (no training)')

    args = parser.parse_args()

    conf = load_conf(args.conf_file)
    train_set, val_set, net, criterion, metrics_dict, (
        score_name, score_function) = task_factory(conf['task'])(conf)

    if args.restart:
        run_id = find_recent_output_dir(conf['tag'], args.output_dir)
    else:
        run_id = '%s_%s' % (conf['tag'], datetime.now().strftime('%Y%m%d%H%M'))
    output_dir = os.path.join(args.output_dir, run_id)

    checkpoint_handler = CheckpointManager(output_dir,
                                           'model',
                                           score_name=score_name,
                                           score_function=score_function,
                                           extra={
                                               'conf': conf,
                                               'args': vars(args)
                                           })
    # Make sure the run directory exists before copying the config into it.
    os.makedirs(output_dir, exist_ok=True)
    shutil.copy(args.conf_file, os.path.join(output_dir, 'conf.json'))
    loader_pin_memory = torch.cuda.is_available()
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=loader_pin_memory,
                                               drop_last=False)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=loader_pin_memory,
                                             drop_last=False)

    writer = create_summary_writer(net, train_loader, args.tensor_board_dir,
                                   run_id)

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
        criterion = criterion.cuda()

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=conf['lr'],
                                 weight_decay=conf['weight_decay'])

    trainer = create_supervised_trainer(net,
                                        optimizer,
                                        criterion,
                                        device=device,
                                        gradient_clip=conf['clip_gradient'])
    train_evaluator = create_supervised_evaluator(net,
                                                  metrics=metrics_dict,
                                                  device=device)

    evaluator = create_supervised_evaluator(net,
                                            metrics=metrics_dict,
                                            device=device)

    step_scheduler = StepLR(optimizer,
                            step_size=conf['lr_step'],
                            gamma=conf['lr_decay'])
    scheduler = LRScheduler(step_scheduler)
    trainer.add_event_handler(Events.EPOCH_STARTED, scheduler)

    all_params = {
        'model': net,
        'optimizer': optimizer,
        'lr_scheduler': step_scheduler
    }
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                all_params)
    desc = "ITERATION - loss: {:.2f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))
    log_interval = 10

    # load checkpoint
    if args.restart and checkpoint_handler.is_checkpoint_available():
        state_dicts = checkpoint_handler.load_last()
        load_model(all_params, state_dicts)
    elif args.checkpoint is not None:
        state_dicts = checkpoint_handler.load(args.checkpoint)
        load_model(all_params, state_dicts)

    @trainer.on(Events.EPOCH_STARTED)
    def setup_engine(engine):
        # Presumably continues the epoch counter after a loaded checkpoint.
        # NOTE(review): max_epochs is unchanged, so a resumed run trains fewer epochs.
        if engine.state.epoch == 1:
            engine.state.epoch = checkpoint_handler.epoch_ + 1

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter_in_epoch % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)
            writer.add_scalar("training/loss", engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        train_evaluator.run(train_loader)
        log_results(engine, train_evaluator, "Training", writer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Record the epoch before validating so the checkpoint handler sees it.
        checkpoint_handler.epoch_ = engine.state.epoch
        evaluator.run(val_loader)
        log_results(engine, evaluator, "Validation", writer)
        pbar.n = pbar.last_print_n = 0

    if args.eval:
        evaluator.run(val_loader)
        log_results(evaluator, evaluator, "Validation", writer)
    else:
        trainer.run(train_loader, max_epochs=conf['epochs'])
    pbar.close()
    print("END")
Exemple #15
0
def test_lr_scheduler_asserts():
    """``LRScheduler`` must reject an argument that is not an LR scheduler."""
    not_a_scheduler = 123
    with pytest.raises(TypeError):
        LRScheduler(not_a_scheduler)
def run(tb, vb, lr, epochs, writer):
    """Train an EfficientNet-b3 classifier with a one-vs-rest sigmoid (DOC) loss.

    Args:
        tb, vb: forwarded to ``get_dataloaders`` (presumably the train / val
            batch sizes -- confirm against ``get_dataloaders``).
        lr: initial SGD learning rate (annealed by CosineAnnealingLR, T_max=200).
        epochs: number of epochs passed to ``trainer.run``.
        writer: TensorBoard-style writer exposing ``add_text`` / ``add_scalars``.

    Reads the ``main-device``, ``run-id`` and ``savedir`` environment variables.
    Validation predictions are re-indexed into the validation label space,
    which has one extra ``UNKNOWN`` class that the model never predicts.
    """
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
    # Inverse-frequency class weights, normalised so they sum to 1.
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics

    class DOCLoss(nn.Module):
        """One-vs-rest sigmoid loss with optional class weights.

        The target class of each sample contributes ``-log(sigmoid(x))`` and
        every other class ``-log(1 - sigmoid(x))``. The terms are multiplied by
        ``class_weights`` (presumably one weight per class, broadcast over the
        last axis) and summed over the batch.
        """

        def __init__(self, weight):
            super(DOCLoss, self).__init__()
            self.class_weights = weight

        def forward(self, input, target):
            # +1 at the target class, -1 elsewhere, because
            # log(1 - sigmoid(x)) == logsigmoid(-x).
            signs = -torch.ones_like(input)
            signs[range(0, input.shape[0]), target] = 1
            # logsigmoid is numerically stable; the former
            # log(1 - 1 / (1 + exp(-x))) underflowed to log(0) = -inf.
            log_probs = nn.functional.logsigmoid(signs * input)
            if self.class_weights is not None:
                loss = -torch.sum(log_probs * self.class_weights)
            else:
                loss = -torch.sum(log_probs)
            return loss

    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(DOCLoss(weight=weights)),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

    # Validation-label column of each training class, computed once rather
    # than on every batch.
    num_val_classes = INFO['dataset-info']['num-of-classes'] + 1
    train_to_val_col = [
        val_loader.dataset.class_to_idx[c] for c in train_loader.dataset.classes
    ]
    unknown_col = val_loader.dataset.class_to_idx['UNKNOWN']

    def val_pred_transform(output):
        """Re-index training-class logits into the validation label space.

        The ``UNKNOWN`` column is set to -inf so it is never the argmax.
        Validation columns not covered by any training class stay 0.
        """
        y_pred, y = output
        new_y_pred = torch.zeros(
            (y_pred.shape[0], num_val_classes)).to(device=device)
        for ind, new_col in enumerate(train_to_val_col):
            new_y_pred[:, new_col] += y_pred[:, ind]
        new_y_pred[:, unknown_col] = float('-inf')
        return new_y_pred, y

    val_metrics = {
        # Accuracy must use the same re-indexed predictions as the other
        # validation metrics. Validation labels ``y`` index the validation
        # class space, not the raw training-class logit columns.
        'accuracy':
        Accuracy(output_transform=val_pred_transform),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform),
                      Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(
            CMatrixTable,
            ConfusionMatrix(INFO['dataset-info']['num-of-classes'] + 1,
                            output_transform=val_pred_transform),
            val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        DOCLoss(weight=weights),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks

    # Update progress bar on each iteration completed.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        iter_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter_in_epoch % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    # Compute metrics on train data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    # Compute metrics on val data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    # Save the model every 50 epochs, keeping the 2 most recent checkpoints.
    save_model_handler = ModelCheckpoint(os.environ['savedir'],
                                         '',
                                         save_interval=50,
                                         n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler,
                              {'model': model})

    # Step the learning-rate scheduler at the start of every epoch.
    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
Exemple #17
0
def test_lr_scheduler():
    """Check ``LRScheduler`` against raw torch schedulers and ``simulate_values``.

    For each scheduler class, an ignite-wrapped scheduler and a manually
    stepped torch scheduler are driven in lockstep for 2 epochs of 10
    iterations. Their recorded learning rates must agree with each other and
    with ``LRScheduler.simulate_values``. Finally, ``_replicate_lr_scheduler``
    must return a copy that is unaffected by stepping the original.
    """

    def _check(scheduler_cls, **scheduler_kwargs):
        param = torch.zeros([1], requires_grad=True)
        wrapped_opt = torch.optim.SGD([param], lr=0.1)
        reference_opt = torch.optim.SGD([param], lr=0.1)

        wrapped_torch_sched = scheduler_cls(optimizer=wrapped_opt,
                                            **scheduler_kwargs)
        reference_sched = scheduler_cls(optimizer=reference_opt,
                                        **scheduler_kwargs)
        ignite_sched = LRScheduler(wrapped_torch_sched)

        observed_lrs = []
        expected_lrs = []

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_STARTED)
        def step_reference(engine):
            reference_sched.step()

        @trainer.on(Events.ITERATION_COMPLETED)
        def record_lrs(engine):
            observed_lrs.append(wrapped_opt.param_groups[0]['lr'])
            expected_lrs.append(reference_opt.param_groups[0]['lr'])

        trainer.add_event_handler(Events.ITERATION_STARTED, ignite_sched)

        data = [0] * 10
        max_epochs = 2
        trainer.run(data, max_epochs=max_epochs)

        assert expected_lrs == pytest.approx(observed_lrs)

        simulation_opt = torch.optim.SGD([param], lr=0.1)
        simulation_sched = scheduler_cls(optimizer=simulation_opt,
                                         **scheduler_kwargs)
        simulated = LRScheduler.simulate_values(
            num_events=len(data) * max_epochs,
            lr_scheduler=simulation_sched)
        assert observed_lrs == pytest.approx([lr for _, lr in simulated])

    cases = (
        (torch.optim.lr_scheduler.StepLR, {'step_size': 5, 'gamma': 0.5}),
        (torch.optim.lr_scheduler.ExponentialLR, {'gamma': 0.78}),
    )
    for scheduler_cls, scheduler_kwargs in cases:
        _check(scheduler_cls, **scheduler_kwargs)

    # _replicate_lr_scheduler must produce an independent copy.
    param = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.1)
    original = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,
                                                      gamma=0.78)
    snapshot = dict(original.state_dict())
    replica = LRScheduler._replicate_lr_scheduler(original)
    for _ in range(10):
        original.step()

    assert replica.state_dict() == snapshot
Exemple #18
0
def main(
    dataset_root,
    init_load_file,
    train_dataset_size,
    num_sbs_layers,
    bond_dim_size,
    tb_log_dir,
    models_dir,
    learning_rate,
    momentum,
    batch_size,
    initialization,
    initialization_std,
    scale_layers_using_batch,
    epochs,
    device,
    seed,
    early_stopping_patience_num_epochs,
    warmup_num_epochs,
    warmup_initial_multiplier,
    cos_sin_squared,
    make_input_window_std_one,
    input_multiplier,
    optimizer_type,
    rmsprop_alpha,
    weight_decay,
    shuffle_pixels,
):
    """Train a ``DCTNMnistModel`` on MNIST and log extensively to TensorBoard.

    The MNIST training split is randomly divided into ``train_dataset_size``
    training samples and the rest for validation. Pixels are optionally
    permuted by a fixed random permutation (``shuffle_pixels``).

    ``initialization`` selects one of "dumb-normal", "khrulkov-normal",
    "normal-preserving-output-std" or "min-random-eye"; anything else raises
    ``ValueError``. ``optimizer_type`` is "sgd" or "rmsprop" (``ValueError``
    otherwise).

    The learning rate is scaled by
    ``warmup_initial_multiplier ** ((warmup_num_epochs - epoch) / warmup_num_epochs)``
    for the first ``warmup_num_epochs`` epochs, and by 1.0 afterwards.

    ``device`` must be a ``torch.device`` (its ``.type`` is read).
    NOTE(review): ``seed`` is accepted but not used in this function.
    """
    if not shuffle_pixels:
        transform = MNIST_TRANSFORM
    else:
        print("Pixel shuffling is enabled")
        shuffled_pixels_indices = tuple(shuffled(range(h * w)))
        logger.info(f"{hash(shuffled_pixels_indices)=}")
        pixel_shuffle_transform = transforms.Lambda(
            partial(permute_pixels, shuffled_pixels_indices))
        transform = transforms.Compose(
            (MNIST_TRANSFORM, pixel_shuffle_transform))
    dataset = MNIST(dataset_root,
                    train=True,
                    download=True,
                    transform=transform)
    assert len(dataset) == MNIST_DATASET_SIZE
    train_dataset, val_dataset = random_split(
        dataset, (train_dataset_size, MNIST_DATASET_SIZE - train_dataset_size))
    logger.info(f"{hash(tuple(val_dataset.indices))=}")
    train_loader, val_loader = (DataLoader(
        dataset_,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=(device.type == "cuda"),
    ) for dataset_ in (train_dataset, val_dataset))
    if initialization == "dumb-normal":
        assert initialization_std is not None
        init = DumbNormalInitialization(initialization_std)
    elif initialization == "khrulkov-normal":
        init = KhrulkovNormalInitialization(initialization_std)
    elif initialization == "normal-preserving-output-std":
        assert initialization_std is None
        init = NormalPreservingOutputStdInitialization()
    elif initialization == "min-random-eye":
        assert initialization_std is not None
        init = MinRandomEyeInitialization(initialization_std)
    else:
        raise ValueError(f"Invalid value: {initialization=}")
    assert not make_input_window_std_one or input_multiplier is None
    if make_input_window_std_one:
        # Choose the multiplier so that the std of 3x3 input windows, computed
        # over the whole dataset in one batch, becomes 1.
        kernel_size = 3
        window_std = calc_std_of_coordinates_of_windows(
            next(
                iter(
                    DataLoader(dataset,
                               batch_size=MNIST_DATASET_SIZE,
                               shuffle=False)))[0],
            kernel_size=kernel_size,
            cos_sin_squared=cos_sin_squared,
        ).item()
        logger.info(f"{window_std=}")
        input_multiplier = (1.0 / window_std)**(1 / kernel_size**2)
    elif input_multiplier is None:
        input_multiplier = 1.0
    logger.info(f"{input_multiplier=}")
    model = DCTNMnistModel(
        num_sbs_layers,
        bond_dim_size,
        False,
        init,
        cos_sin_squared,
        input_multiplier,
    )
    if init_load_file:
        model.load_state_dict(torch.load(init_load_file, map_location=device))
    elif scale_layers_using_batch is not None:
        model.scale_layers_using_batch(
            next(
                iter(
                    DataLoader(dataset,
                               batch_size=scale_layers_using_batch,
                               shuffle=True)))[0])
        logger.info("Done model.scale_layers_using_batch")
    assert rmsprop_alpha is None or optimizer_type == "rmsprop"
    if optimizer_type == "sgd":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    elif optimizer_type == "rmsprop":
        optimizer = torch.optim.RMSprop(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            alpha=rmsprop_alpha,
            weight_decay=weight_decay,
        )
    else:
        # f-prefix was missing, so the offending value was never shown.
        raise ValueError(f"Invalid optimizer_type: {optimizer_type}")

    prepare_batch_for_trainer = make_standard_prepare_batch_with_events(device)
    trainer = setup_trainer(
        model,
        optimizer,
        tnnf.cross_entropy,
        device=device,
        prepare_batch=prepare_batch_for_trainer,
    )

    # Geometric warm-up: starts at warmup_initial_multiplier and reaches 1.0
    # at epoch == warmup_num_epochs.
    scheduler = LRScheduler(
        torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda epoch: (warmup_initial_multiplier**(
                (warmup_num_epochs - epoch) / warmup_num_epochs)
                           if epoch < warmup_num_epochs else 1.0),
        ))
    trainer.add_event_handler(Events.EPOCH_STARTED, scheduler)
    metrics = {
        "cross_entropy_loss": Loss(tnnf.cross_entropy),
        "accuracy": Accuracy()
    }
    prepare_batch_for_val_evaluator = make_standard_prepare_batch_with_events(
        device)
    val_evaluator = setup_evaluator(
        model,
        trainer,
        val_loader,
        metrics,
        device=device,
        prepare_batch=prepare_batch_for_val_evaluator,
    )
    add_checkpointing(
        models_dir,
        "cross_entropy_loss",
        val_evaluator,
        objects_to_save={"model": model},
        model=model,
    )
    add_checkpointing_of_last_models(
        models_dir,
        val_evaluator,
        {"model": model},
        model,
        num_checkpoints=10,
        save_interval=20,
    )
    if early_stopping_patience_num_epochs is not None:
        add_early_stopping(
            trainer,
            val_evaluator,
            "cross_entropy_loss",
            patience_num_evaluations=early_stopping_patience_num_epochs,
        )
    with setup_tensorboard_logger(tb_log_dir,
                                  trainer,
                                  metrics.keys(), {"val": val_evaluator},
                                  model=model) as tb_logger:
        add_weights_and_grads_logging(trainer, tb_logger, model)
        add_optimizer_params_logging(optimizer, tb_logger, trainer)
        is_string = lambda _, module: isinstance(module, ConvSBS)
        create_every_n_iters_intermediate_outputs_logger(
            model,
            tb_logger.writer,
            is_string,
            trainer,
            "train",
            every_n_iters=20,
            loggers=(
                log_dumb_mean_of_abs,
                log_dumb_min_of_abs,
                log_dumb_max_of_abs,
                log_dumb_mean,
                log_dumb_std,
                log_dumb_histogram,  # maybe remove this later for performance's sake
            ),
        )
        add_conv_sbs_tt_tensor_statistics_logging(model, tb_logger.writer,
                                                  trainer, 20)
        create_every_n_iters_intermediate_outputs_logger(
            model,
            tb_logger.writer,
            lambda _, module: module is model,
            trainer,
            "train_outputs_of_the_whole_model",
            every_n_iters=20,
            loggers=(
                log_logits_as_probabilities,
                log_dumb_min,
                log_dumb_max,
                log_dumb_mean,
                log_dumb_std,
            ),
        )
        add_quantum_inputs_statistics_logging(model, trainer, tb_logger.writer,
                                              20)
        create_every_n_iters_intermediate_outputs_logger(
            model,
            tb_logger.writer,
            lambda _, module: module is model,
            trainer,
            "train_input",
            20,
            loggers=((
                "std_of_coordinates_of_windows",
                RecordType.SCALAR,
                partial(
                    calc_std_of_coordinates_of_windows,
                    kernel_size=3,
                    cos_sin_squared=cos_sin_squared,
                    multiplier=input_multiplier,
                ),
            ), ),
            use_input=True,
        )
        trainer.run(train_loader, max_epochs=epochs)
Exemple #19
0
def run(tb, vb, lr, epochs, writer):
  """Train an EfficientNet-B3 classifier and evaluate it at the end of training.

  Args:
    tb: forwarded to ``get_dataloaders`` (presumably the train batch size —
      confirm against get_dataloaders).
    vb: forwarded to ``get_dataloaders`` (presumably the validation batch size).
    lr: initial SGD learning rate (annealed by CosineAnnealingLR, T_max=200).
    epochs: number of training epochs. The evaluation hooks fire every
      ``epochs`` epochs, i.e. once, when training finishes.
    writer: summary writer used through ``add_text`` / ``add_scalars``.

  Reads the ``main-device``, ``run-id`` and ``savedir`` environment variables.
  """
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping, _ = get_dataloaders(tb, vb)
  # Inverse-count class weights normalised to sum to 1
  # (assumes num_of_images is a per-class count tensor — confirm).
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  weights = weights.to(device=device)

  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)

  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  ignite_scheduler = LRScheduler(scheduler)

  # ------------------------------------
  # 4. Define metrics

  class EntropyPrediction(metric.Metric):
    """Collect entropy-thresholded predictions and labels over a run.

    A sample whose normalised prediction entropy is below ``threshold`` gets
    its arg-max class, otherwise ``-1``; the result is then translated through
    ``mapping``. ``compute`` returns the accumulated ``(prediction, y)`` pair.
    """
    def __init__(self, threshold=1.0):
      super(EntropyPrediction, self).__init__()
      self.threshold = threshold
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)

    def reset(self):
      self.prediction = torch.tensor([])
      self.y = torch.tensor([])
      super(EntropyPrediction, self).reset()

    def update(self, output):
      y_pred, y = output
      # log_softmax is numerically stable; exp(x)/sum(exp(x)) overflows for
      # large logits, and log(p) with p underflowed to 0 made the entropy nan.
      log_probs = torch.log_softmax(y_pred, dim=1)
      probs = log_probs.exp()
      # Dividing by log(num_classes) scales the entropy into [0, 1].
      entropy_base = math.log(y_pred.shape[1])
      entropy = -(probs * log_probs).sum(1) / entropy_base
      inds = probs.argmax(1)
      # Uncertain samples (entropy >= threshold) are marked with -1.
      prediction = torch.where(entropy < self.threshold, inds, torch.full_like(inds, -1))
      # Explicit dtype keeps an empty batch from producing a float tensor.
      mapped = torch.tensor([mapping[x.item()] for x in prediction], dtype=torch.long)
      self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), mapped.to(device=device)))
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))

    def compute(self):
      return self.prediction, self.y

  train_metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(CrossEntropywithLS(weight=weights)),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  val_metrics = {
    'accuracy': MetricsLambda(Labels2Acc, EntropyPrediction()),
    'precision_recall': MetricsLambda(Labels2PrecisionRecall, EntropyPrediction(), val_loader.dataset.classes),
    'cmatrix': MetricsLambda(Labels2CMatrix, EntropyPrediction(), val_loader.dataset.classes)
  }

  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, CrossEntropywithLS(weight=weights), device=device)

  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

  desc = 'Epoch {} - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0, 0)
  )


  # ------------------------------------
  # 7. Create event hooks

  # Basic events on showing training procedure.
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    # 1-based position of this iteration inside the current epoch.
    batch_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
    if batch_in_epoch % log_interval == 0:
      pbar.desc = desc.format(engine.state.epoch, engine.state.output)
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_COMPLETED)
  def refresh_pbar(engine):
    pbar.refresh()
    print('Finish epoch {}!'.format(engine.state.epoch))
    pbar.n = pbar.last_print_n = 0

  # Trigger: Compute metrics on training data.
  def log_training_results(engine):
    print ('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)

  # Trigger: Compute metrics on validating data.
  def log_validation_results(engine):
    pbar.clear()
    print ('* - * - * - * - * - * - * - * - * - * - * - * - *')
    print ('Checking on validation set.')
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg precision': precision_recall['data'][0, -1], 'Val avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)

  # ------------------------------------
  # Trainer triggers settings

  # Checkpoint the model (ModelCheckpoint with save_interval=10, n_saved=2).
  save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=10, n_saved=2)
  trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})

  # Evaluate every `epochs` epochs, i.e. once when training completes.
  evaluate_interval = epochs
  cpe = CustomPeriodicEvent(n_epochs=epochs)
  cpe.attach(trainer)
  on_evaluate_event = getattr(cpe.Events, 'EPOCHS_{}_COMPLETED'.format(evaluate_interval))
  trainer.add_event_handler(on_evaluate_event, log_training_results)
  trainer.add_event_handler(on_evaluate_event, log_validation_results)

  # Update learning rate once per epoch.
  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
    json_annotations=os.path.join(cfg.CV_DIR, 'fold-1', 'train_objects_both_train.json'),
    images_dir=cfg.TEST_IMAGES,
    masks_dir=cfg.TEST_IGNORE_MASKS,
    prepare_sample_fn=D.prepare_test_sample_fn_v1
)
test_dl = DataLoader(
    test_ds,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    pin_memory=False,
    drop_last=False,
)

# Fresh float32 model on the GPU. Earlier experiments warm-started from
# model_1/stage_1/checkpoints/_model_3.pth and froze model.backbone.
model = Model().float().cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

# Multiply the LR by 0.9 every `iters_per_epoch` scheduler steps.
# A PiecewiseLinear warm-up/decay schedule was tried here previously.
scheduler = LRScheduler(
    lr_scheduler=StepLR(optimizer=optimizer, step_size=iters_per_epoch, gamma=0.9)
)

def loss_fn(y_pred, y):
    """Compute the combined training objective plus monitoring terms.

    The optimised value under ``'loss'`` is
    ``rot_loss_cos + 0.1 * trans_loss + cls_loss``. ``rot_loss_l1`` and
    ``true_distance`` are returned for logging only; they are not part of
    ``'loss'``.
    """
    pred_rot, true_rot = y_pred['rotation'], y['rotation']
    pred_trans, true_trans = y_pred['translation'], y['translation']

    classification = nn.functional.cross_entropy(y_pred['cls_score'], y['label'])

    rotation_l1 = nn.functional.l1_loss(pred_rot, true_rot)
    # One minus the batch-mean of the row-wise dot products (a cosine-style
    # similarity if rotations are unit-normalised — not enforced here).
    rotation_cos = 1 - (pred_rot * true_rot).sum(dim=1).mean()

    translation = huber_loss(pred_trans, true_trans)
    distance = mean_distance(pred_trans, true_trans)

    total = rotation_cos + 0.1 * translation + classification
    return {
        'loss': total,
        'rot_loss_l1': rotation_l1,
        'rot_loss_cos': rotation_cos,
        'trans_loss': translation,
        'true_distance': distance,
        'cls_loss': classification,
    }
Exemple #21
0
                                    'epochs': args.epochs,
                                    'lr': args.lr,
                                    'step_size': args.step_size,
                                    'gamma': args.gamma,
                                    'weight_decay': args.weight_decay,
                                    'model': repr(model)})

optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

# Multiply the LR by `args.gamma` every `args.step_size` scheduler steps,
# driven through ignite's LRScheduler wrapper.
step_scheduler = StepLR(
    optimizer=optimizer, step_size=args.step_size, gamma=args.gamma
)
scheduler = LRScheduler(lr_scheduler=step_scheduler)

criterion = nn.CrossEntropyLoss()


def update(engine, batch):
    """Run one mixed-precision optimisation step on ``batch``.

    Args:
        engine: the ignite Engine driving training (unused here).
        batch: ``(inputs, targets)`` pair; both are moved to CUDA.

    Returns:
        The batch loss as a Python float, which ignite stores in
        ``engine.state.output``.
    """
    inputs, targets = batch
    optimizer.zero_grad()

    # Forward pass under autocast so eligible ops run in reduced precision.
    with autocast():
        outputs = model(inputs.cuda())
        loss = criterion(outputs, targets.cuda())

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    # GradScaler requires update() after every step(): it adjusts the loss
    # scale and resets per-optimizer state; without it the next step() raises.
    scaler.update()
    return loss.item()
def main(dataset_root, train_dataset_size, tb_log_dir, models_dir,
         learning_rate, batch_size, device, seed, shuffle_pixels,
         load_model: Optional[str], train: bool, test: bool):
    """Optionally train and/or test a ``TTMnistModel`` on MNIST.

    Args:
        dataset_root: MNIST storage directory (downloaded if missing).
        train_dataset_size: number of MNIST training images used for
            training; the remaining ``MNIST_DATASET_SIZE - train_dataset_size``
            form the validation split.
        tb_log_dir: TensorBoard log directory.
        models_dir: checkpoint directory.
        learning_rate: SGD learning rate (decayed by 0.8547 every 2 epochs).
        batch_size: batch size for every loader.
        device: ``torch.device`` used for training and evaluation.
        seed: accepted but not used inside this function.
        shuffle_pixels: if true, apply a fixed random pixel permutation.
        load_model: optional state-dict path loaded before anything else.
        train: run training with early stopping, then reload the best model.
        test: evaluate the (possibly reloaded) model on the MNIST test set.
    """
    if not shuffle_pixels:
        transform = MNIST_TRANSFORM
    else:
        print("Pixel shuffling is enabled")
        pixel_shuffle_transform = transforms.Lambda(
            partial(permute_pixels, shuffled(range(h * w))))
        transform = transforms.Compose(
            (MNIST_TRANSFORM, pixel_shuffle_transform))
    model = TTMnistModel((r1, r2, r3, r4))
    if load_model is not None:
        model.load_state_dict(torch.load(load_model, "cpu"))
        logger.debug(f"Loaded model from {load_model}")
    metrics = {
        "cross_entropy_loss": Loss(tnnf.cross_entropy),
        "accuracy": Accuracy()
    }
    if train:
        dataset = MNIST(dataset_root,
                        train=True,
                        download=True,
                        transform=transform)
        assert len(dataset) == MNIST_DATASET_SIZE
        train_dataset, val_dataset = random_split(
            dataset,
            (train_dataset_size, MNIST_DATASET_SIZE - train_dataset_size))
        train_loader, val_loader = (DataLoader(
            dataset_,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=(device.type == "cuda")) for dataset_ in (train_dataset,
                                                                 val_dataset))
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=learning_rate,
                                    momentum=0.95,
                                    weight_decay=0.0005)

        prepare_batch_for_trainer = make_standard_prepare_batch_with_events(
            device)
        trainer = setup_trainer(model,
                                optimizer,
                                tnnf.cross_entropy,
                                device=device,
                                prepare_batch=prepare_batch_for_trainer)
        scheduler = LRScheduler(
            torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=2,
                                            gamma=0.8547))
        trainer.add_event_handler(Events.EPOCH_STARTED, scheduler)
        prepare_batch_for_val_evaluator = make_standard_prepare_batch_with_events(
            device)
        val_evaluator = setup_evaluator(
            model,
            trainer,
            val_loader,
            metrics,
            device=device,
            prepare_batch=prepare_batch_for_val_evaluator)
        checkpointer = add_checkpointing(models_dir,
                                         "cross_entropy_loss",
                                         val_evaluator,
                                         objects_to_save={"model": model},
                                         model=model)
        add_early_stopping(trainer,
                           val_evaluator,
                           "cross_entropy_loss",
                           patience_num_evaluations=25)
        with setup_tensorboard_logger(tb_log_dir,
                                      trainer,
                                      metrics.keys(), {"val": val_evaluator},
                                      model=model) as tb_logger:
            add_weights_and_grads_logging(trainer, tb_logger, model)
            add_logging_input_images(tb_logger, trainer, "train",
                                     prepare_batch_for_trainer)
            add_logging_input_images(tb_logger,
                                     val_evaluator,
                                     "val",
                                     prepare_batch_for_val_evaluator,
                                     another_engine=trainer)
            trainer.run(train_loader, max_epochs=100)
        if len(checkpointer._saved) > 0:
            # ModelCheckpoint keeps `_saved` sorted by ascending priority
            # (score), so the best checkpoint is the LAST entry; `[0]` picked
            # the worst of the retained ones whenever n_saved > 1.
            best_model_path = checkpointer._saved[-1][1][0]
            logger.info(f"The best model is saved at '{best_model_path}'")
            model.load_state_dict(
                torch.load(best_model_path, map_location="cpu"))
    if test:
        test_dataset = MNIST(dataset_root,
                             train=False,
                             download=True,
                             transform=transform)
        # Order does not affect the aggregated metrics, so no shuffling.
        test_loader = DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 pin_memory=(device.type == "cuda"))
        test_evaluator = create_supervised_evaluator(model,
                                                     metrics=metrics,
                                                     device=device)
        test_evaluator.run(test_loader)
        print(
            f"On test dataset the best model got: {test_evaluator.state.metrics}"
        )
Exemple #23
0
def train(model_class: Model, params: Params, optim_type: OptimizerType,
          dataset: DataSet):
    """Train ``model_class`` using the hyper-parameters in ``params``.

    The test split is passed to ``model.train_model`` as its validation loader.

    Args:
        model_class: the model class itself (``SCCNN`` or ``RCCnet``), not an
            instance; it selects the test-patch size (27 or 32 pixels).
        params: hyper-parameter container (directories, batch size, lr, ...).
        optim_type: ``OptimizerType.ADAM`` (no LR schedule) or
            ``OptimizerType.SGD`` (StepLR with gamma=0.1 every
            ``params.lr_step_size`` scheduler steps).
        dataset: ``DataSet.CRC`` selects ``params.crc_root_dir``; any other
            value selects ``params.root_dir``.

    Raises:
        ValueError: for an unsupported ``model_class`` or ``optim_type``.
            ValueError subclasses Exception, so existing handlers still match.
    """
    root_dir = params.crc_root_dir if dataset == DataSet.CRC else params.root_dir
    num_classes = params.num_classes
    batch_size = params.batch_size
    num_workers = params.num_workers

    if model_class is SCCNN:
        width = height = 27
    elif model_class is RCCnet:
        width = height = 32
    else:
        raise ValueError(
            "Model_class should be either a softmax SCCNN or RCCnet class initializer"
        )

    train_ds = ClassificationDataset(root_dir=root_dir,
                                     train=True,
                                     shift=params.shift)
    train_dl = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)

    test_ds = ClassificationDataset(root_dir=root_dir,
                                    train=False,
                                    width=width,
                                    height=height)
    test_dl = DataLoader(test_ds,
                         batch_size=batch_size,
                         num_workers=num_workers)

    # class_weights.npy holds a pickled dict {class_index: count, "total": N};
    # each weight is the class's fraction of the total.
    class_weight_dict = np.load(os.path.join(root_dir, "class_weights.npy"),
                                allow_pickle=True).item()
    class_weights = torch.tensor(
        [
            class_weight_dict[i] / class_weight_dict["total"]
            for i in range(num_classes)
        ],
        dtype=torch.float,
    )
    model: Model = model_class(loss_weights=class_weights,
                               num_classes=num_classes,
                               dropout_p=params.dropout_p)

    if optim_type == OptimizerType.ADAM:
        optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
        scheduler = None
    elif optim_type == OptimizerType.SGD:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=params.lr,
            weight_decay=params.weight_decay,
            momentum=params.momentum,
        )
        step_scheduler = StepLR(optimizer,
                                step_size=params.lr_step_size,
                                gamma=0.1)
        scheduler = LRScheduler(step_scheduler)
    else:
        raise ValueError("Unsupported optimizer type")

    criterion = nn.NLLLoss(weight=model.class_weights)

    # perf_counter is monotonic, unlike time.time, so it suits durations.
    start = time.perf_counter()
    model.train_model(
        train_dl,
        optimizer,
        criterion,
        max_epochs=params.epochs,
        val_loader=test_dl,
        scheduler=scheduler,
    )
    elapsed = time.perf_counter() - start
    # Guard against ZeroDivisionError when params.epochs == 0.
    per_epoch = elapsed / params.epochs if params.epochs else 0.0
    print(
        "It took {:.5f} seconds to train {} epochs, average of {:.5f} sec/epoch"
        .format(elapsed, params.epochs, per_epoch))
Exemple #24
0
def test_simulate_values():
    """Check that each scheduler's ``simulate_values`` matches a real run.

    For every scheduler type, the LR values predicted by ``simulate_values``
    are compared with the LR actually observed on the optimizer while an
    Engine runs 2 epochs of 10 dummy iterations with the scheduler attached.
    """
    def _test(scheduler_cls, **scheduler_kwargs):
        """Simulate, then really run, ``scheduler_cls`` and compare the LRs.

        ``scheduler_kwargs`` is mutated so that the same kwargs work for both
        ``simulate_values`` and the constructor of ``scheduler_cls``.
        """

        optimizer = None
        if scheduler_cls == LRScheduler:
            # Reuse the optimizer already bound to the wrapped torch scheduler.
            scheduler_kwargs['optimizer'] = scheduler_kwargs[
                'lr_scheduler'].optimizer
            optimizer = scheduler_kwargs['optimizer']
        elif scheduler_cls == ConcatScheduler:
            # The optimizer is only used to read LRs here; it is removed from
            # the kwargs passed to ConcatScheduler / simulate_values.
            optimizer = scheduler_kwargs['optimizer']
            del scheduler_kwargs['optimizer']
        else:
            # Parameter schedulers get a fresh single-tensor SGD optimizer.
            tensor = torch.zeros([1], requires_grad=True)
            scheduler_kwargs['optimizer'] = torch.optim.SGD([tensor], lr=0.1)
            optimizer = scheduler_kwargs['optimizer']

        max_epochs = 2
        data = [0] * 10
        # One simulated value per iteration over the whole run.
        simulated_values = scheduler_cls.simulate_values(num_events=len(data) *
                                                         max_epochs,
                                                         **scheduler_kwargs)

        scheduler = scheduler_cls(**scheduler_kwargs)

        lrs = []

        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]['lr'])

        # The scheduler fires at ITERATION_STARTED and the LR is read at
        # ITERATION_COMPLETED, so `lrs` holds the LR in effect each iteration.
        trainer = Engine(lambda engine, batch: None)
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
        trainer.run(data, max_epochs=max_epochs)

        # simulated_values items are (event_index, value) pairs.
        assert lrs == pytest.approx([v for i, v in simulated_values])

        if scheduler_cls == LRScheduler or scheduler_cls == ConcatScheduler:
            # As internal state of torch lr scheduler has been changed the following checks will fail
            return

        # reexecute to check if no internal changes
        simulated_values = scheduler_cls.simulate_values(
            num_events=len(data) * max_epochs,
            save_history=True,  # this will be removed
            **scheduler_kwargs)
        assert lrs == pytest.approx([v for i, v in simulated_values])

    # LinearCyclicalScheduler
    _test(LinearCyclicalScheduler,
          param_name="lr",
          start_value=1.0,
          end_value=0.0,
          cycle_size=10)

    # CosineAnnealingScheduler
    _test(CosineAnnealingScheduler,
          param_name="lr",
          start_value=1.0,
          end_value=0.0,
          cycle_size=10)

    # LRScheduler
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.1)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=0.5)

    _test(LRScheduler, lr_scheduler=torch_lr_scheduler)

    # ConcatScheduler = [LinearCyclicalScheduler, CosineAnnealingScheduler]
    # (reuses the optimizer from the LRScheduler case above)
    scheduler_1 = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=1.0,
                                          end_value=0.0,
                                          cycle_size=20)
    scheduler_2 = CosineAnnealingScheduler(optimizer,
                                           "lr",
                                           start_value=0.0,
                                           end_value=1.0,
                                           cycle_size=10)
    # scheduler_1 runs for the first 10 events, then scheduler_2 takes over.
    durations = [
        10,
    ]
    _test(ConcatScheduler,
          optimizer=optimizer,
          schedulers=[scheduler_1, scheduler_2],
          durations=durations)

    # ConcatScheduler = [LinearCyclicalScheduler, LRScheduler]
    # (a fresh optimizer; the torch scheduler increases the LR, gamma=1.5)
    tensor = torch.ones([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.001)
    torch_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer=optimizer, gamma=1.5)
    scheduler_1 = LRScheduler(torch_lr_scheduler)
    scheduler_2 = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=0.1,
                                          end_value=0.0,
                                          cycle_size=10)
    durations = [
        10,
    ]
    _test(ConcatScheduler,
          optimizer=optimizer,
          schedulers=[scheduler_1, scheduler_2],
          durations=durations)
Exemple #25
0
def train(epochs: int, model: nn.Module, train_loader: DataLoader,
          valid_loader: DataLoader, criterion: Callable, device: str,
          lr: float, patience: int, lr_decay: float, lr_scheduler: str,
          lr_scheduler_kwargs: Dict[str, Any]):
    """Train ``model`` with Adam and early stopping; return per-epoch history.

    Args:
        epochs: maximum number of training epochs.
        model: network to train (moved to ``device``).
        train_loader: training data; also re-evaluated after every epoch.
        valid_loader: validation data, evaluated after every epoch.
        criterion: loss used for training and for the ``BCE`` metric.
        device: device string, e.g. ``"cpu"`` or ``"cuda"``.
        lr: Adam learning rate (only parameters with ``requires_grad``).
        patience: EarlyStopping patience, in validation runs.
        lr_decay: accepted for interface compatibility; not used here.
        lr_scheduler: name of a class in ``torch.optim.lr_scheduler``; it is
            stepped after every training ITERATION.
        lr_scheduler_kwargs: extra keyword arguments for that scheduler.

    Returns:
        A DataFrame with one row per completed epoch: trainer epoch, total
        trainer iterations so far, cumulative training-epoch time in seconds,
        LR, and train/valid BCE / ACC / AP.
    """
    model.to(torch.device(device))
    optimizer = optim.Adam(
        [param for param in model.parameters() if param.requires_grad], lr=lr)

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)

    torch_scheduler = getattr(optim.lr_scheduler,
                              lr_scheduler)(optimizer, **lr_scheduler_kwargs)
    scheduler = LRScheduler(torch_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)

    pbar = ProgressBar(False)
    pbar.attach(trainer)

    def _make_metrics():
        # Build fresh Metric objects per evaluator so no state is shared.
        return {
            'ACC': Accuracy(discreted_output_transform),
            'BCE': Loss(criterion),
            'AP': AveragePrecision(probability_output_transform),
        }

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=_make_metrics(),
                                                  device=device)
    valid_evaluator = create_supervised_evaluator(model,
                                                  metrics=_make_metrics(),
                                                  device=device)

    history = {
        col: list()
        for col in [
            'epoch', 'elapsed time', 'iterations', 'lr', 'train BCE',
            'valid BCE', 'train ACC', 'valid ACC', 'train AP', 'valid AP'
        ]
    }

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics

        history['train BCE'].append(metrics['BCE'])
        history['train ACC'].append(metrics['ACC'])
        history['train AP'].append(metrics['AP'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        valid_evaluator.run(valid_loader)
        metrics = valid_evaluator.state.metrics

        # Progress is read from the TRAINER (`engine`): the evaluator runs a
        # single epoch per call, so its epoch is always 1 and its length is
        # the number of validation batches.
        history['epoch'].append(engine.state.epoch)
        history['iterations'].append(engine.state.iteration)
        previous = history['elapsed time'][-1] if history['elapsed time'] else 0.0
        # Cumulative training time; the first epoch's duration is included
        # rather than being recorded as 0.
        history['elapsed time'].append(
            previous + engine.state.times['EPOCH_COMPLETED'])
        history['lr'].append(scheduler.get_param())

        history['valid BCE'].append(metrics['BCE'])
        history['valid ACC'].append(metrics['ACC'])
        history['valid AP'].append(metrics['AP'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_progress_bar(engine):
        keys = ('train BCE', 'train ACC', 'train AP',
                'valid BCE', 'valid ACC', 'valid AP')
        pbar.log_message(
            " ".join(f"{key}: {history[key][-1]:.2f}" for key in keys))

    # Early stopping, driven by validation results.
    handler = EarlyStopping(patience=patience,
                            score_function=score_function,
                            trainer=trainer)
    valid_evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler)

    trainer.run(train_loader, max_epochs=epochs)
    return pd.DataFrame(history)
def run(tb, vb, lr, epochs, writer):
    """Train EfficientNet-B3 and evaluate on train/val after every epoch.

    Args:
        tb: forwarded to ``get_dataloaders`` (presumably the train batch size).
        vb: forwarded to ``get_dataloaders`` (presumably the val batch size).
        lr: initial SGD learning rate (CosineAnnealingLR, T_max=200, stepped
            at the start of every epoch).
        epochs: number of training epochs.
        writer: summary writer used through ``add_text`` / ``add_scalars``.

    Reads the ``main-device`` and ``run-id`` environment variables.
    """
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
    # Inverse-count class weights normalised to sum to 1
    # (assumes num_of_images is a per-class count tensor — confirm).
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics
    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(nn.CrossEntropyLoss(weight=weights)),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

    def val_pred_transform(output):
        """Fold per-training-class scores into the validation label space.

        ``mapping[c]`` is the validation column of training class ``c``.
        Classes mapped to the ``UNKNOWN`` column are combined with an
        element-wise max; every other class is summed into its column.
        """
        y_pred, y = output
        num_val_classes = len(INFO['dataset-info']['known-classes']) + 1
        unknown_idx = val_loader.dataset.class_to_idx['UNKNOWN']
        new_y_pred = torch.zeros((y_pred.shape[0], num_val_classes),
                                 device=device)
        unknown_initialised = False
        for c in range(y_pred.shape[1]):
            target = mapping[c]
            if target != unknown_idx:
                new_y_pred[:, target] += y_pred[:, c]
            elif not unknown_initialised:
                # The first contributor seeds the column; taking a max against
                # the zero fill would wrongly win over all-negative scores.
                new_y_pred[:, target] = y_pred[:, c]
                unknown_initialised = True
            else:
                new_y_pred[:, target] = torch.max(new_y_pred[:, target],
                                                  y_pred[:, c])
        return new_y_pred, y

    val_metrics = {
        'accuracy':
        Accuracy(val_pred_transform),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform),
                      Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(
            CMatrixTable,
            ConfusionMatrix(len(INFO['dataset-info']['known-classes']) + 1,
                            output_transform=val_pred_transform),
            val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        nn.CrossEntropyLoss(weight=weights),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        # 1-based position of this iteration inside the current epoch.
        batch_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if batch_in_epoch % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      <Training> Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      <Validating> Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)
        # Reset the progress bar for the next epoch.
        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
def run(tb, vb, lr, epochs, writer):
    """Train an EfficientNet-B3 classifier with ignite and report metrics.

    Training uses SGD (momentum 0.9) with a cosine-annealed learning rate and
    a combined class-weighted cross-entropy + center loss (``MixLoss``).
    Train/validation metrics are logged when training starts and every 50
    epochs; checkpoints are written by ``ModelCheckpoint``
    (``save_interval=10``, ``n_saved=2``) on epoch completion.

    Args:
        tb: First argument forwarded to ``get_dataloaders`` (presumably the
            training batch size -- TODO confirm).
        vb: Second argument forwarded to ``get_dataloaders`` (presumably the
            validation batch size -- TODO confirm).
        lr: Initial learning rate for SGD.
        epochs: Number of epochs to train (``max_epochs`` of the trainer).
        writer: Object exposing ``add_text`` and ``add_scalars`` (looks like
            a TensorBoard ``SummaryWriter`` -- TODO confirm).

    Reads the environment variables ``main-device``, ``run-id`` and
    ``savedir``.
    """
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
    # Inverse-frequency class weights normalised to sum to 1 (assumes
    # num_of_images is a per-class image-count tensor -- TODO confirm).
    inv_counts = 1 / num_of_images
    weights = inv_counts / inv_counts.sum().item()
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # NOTE(review): T_max is fixed at 200 regardless of `epochs`; the
    # scheduler is stepped once per epoch (EPOCH_STARTED below), so the
    # cosine cycle only matches the run length when epochs == 200 --
    # confirm this is intended.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics

    class CenterLoss(nn.Module):
        """Center loss.

        Reference:
            Wen et al. A Discriminative Feature Learning Approach for Deep
            Face Recognition. ECCV 2016.

        Args:
            num_classes (int): number of classes.
            feat_dim (int): feature dimension.

        The centers are drawn once from a standard normal in ``__init__``
        and stored as a plain tensor (not an ``nn.Parameter``), so no
        optimizer updates them.
        """

        def __init__(self, num_classes=10, feat_dim=2):
            super(CenterLoss, self).__init__()
            self.num_classes = num_classes
            self.feat_dim = feat_dim
            self.centers = torch.randn(self.num_classes,
                                       self.feat_dim).to(device=device)

        def forward(self, x, labels):
            """Return the squared distance to each sample's class center, summed and divided by batch size.

            Args:
                x: feature matrix with shape (batch_size, feat_dim).
                labels: ground truth labels with shape (batch_size).
            """
            batch_size = x.size(0)
            # ||x||^2 + ||c||^2 - 2 x.c : squared Euclidean distance between
            # every sample and every center, shape (batch_size, num_classes).
            distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                      torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
            # Keyword beta/alpha: the positional addmm_(beta, alpha, m1, m2)
            # overload is deprecated in PyTorch.
            distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
            classes = torch.arange(self.num_classes).long()
            classes = classes.to(device=device)
            labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
            mask = labels.eq(classes.expand(batch_size, self.num_classes))
            # Keep only each sample's distance to its own class center.
            dist = distmat * mask.float()
            loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
            return loss

    class MixLoss(nn.Module):
        """Class-weighted cross-entropy plus a scaled center-loss term.

        Args:
            center_weight (float): multiplier applied to the center loss
                (default 0.5, the value previously hard-coded).
        """

        def __init__(self, center_weight=0.5):
            super(MixLoss, self).__init__()
            self.center_weight = center_weight
            # Built once instead of on every forward call; uses the
            # inverse-frequency `weights` computed above.
            self.ce = nn.CrossEntropyLoss(weight=weights)

        def forward(self, x, y):
            # NOTE(review): a new CenterLoss is built on every call, so its
            # centers are fresh torch.randn draws that are never learned, and
            # num_classes is taken from the batch size (x.shape[0]) rather
            # than the class count. Kept unchanged to preserve training
            # behaviour -- confirm whether persistent, learned centers were
            # intended.
            center = CenterLoss(x.shape[0], x.shape[1])(x, y)
            return self.center_weight * center + self.ce(x, y)

    class EntropyPrediction(metric.Metric):
        """Collect entropy-gated predictions and targets over an evaluation run.

        A sample keeps its argmax class only if the entropy of its softmax
        distribution, divided by log(num_classes), is below ``threshold``;
        otherwise it is marked ``-1``. Predictions are then remapped through
        ``mapping``. ``compute`` returns ``(predictions, targets)`` as
        concatenated long tensors on ``device``.

        Args:
            threshold (float): cut-off on the normalised entropy.
        """

        def __init__(self, threshold=1.0):
            super(EntropyPrediction, self).__init__()
            self.threshold = threshold
            self._pred_chunks = []
            self._target_chunks = []

        def reset(self):
            self._pred_chunks = []
            self._target_chunks = []
            super(EntropyPrediction, self).reset()

        def update(self, output):
            y_pred, y = output
            # log_softmax is numerically stable: exp(y) / exp(y).sum()
            # overflows for large logits, and log(softmax) gives
            # 0 * -inf = nan when a probability underflows to zero.
            log_probs = torch.log_softmax(y_pred, dim=1)
            probs = log_probs.exp()
            # Divide by log(num_classes) so a uniform distribution scores 1.
            entropy = -(probs * log_probs).sum(1) / math.log(y_pred.shape[1])
            inds = log_probs.argmax(dim=1)
            prediction = torch.where(entropy < self.threshold, inds,
                                     torch.full_like(inds, -1))
            mapped = torch.tensor([mapping[p] for p in prediction.tolist()])
            # Store per-batch chunks and concatenate once in compute();
            # re-concatenating the running tensor every batch (with a CPU
            # round-trip) cost O(n^2) copying over an epoch.
            self._pred_chunks.append(mapped.long().to(device=device))
            self._target_chunks.append(y.to(device=device).long())

        def compute(self):
            if not self._pred_chunks:
                return torch.tensor([]), torch.tensor([])
            return torch.cat(self._pred_chunks), torch.cat(self._target_chunks)

    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(MixLoss()),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

    # Each entry needs its own EntropyPrediction instance: a shared instance
    # would be attached (and updated) once per MetricsLambda.
    val_metrics = {
        'accuracy':
        MetricsLambda(Labels2Acc, EntropyPrediction(1.0)),
        'precision_recall':
        MetricsLambda(Labels2PrecisionRecall, EntropyPrediction(1.0),
                      val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(Labels2CMatrix, EntropyPrediction(1.0),
                      val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        MixLoss(),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks

    # Update progress bar on each iteration completed.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        # 1-based iteration index within the current epoch (named so it does
        # not shadow the builtin `iter`).
        epoch_iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if epoch_iter % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_STARTED)
    def refresh_pbar(engine):
        torch.cuda.empty_cache()
        # NOTE(review): this fires at EPOCH_STARTED, so the printed number is
        # the epoch about to start rather than the one just finished.
        print('Finish epoch {}'.format(engine.state.epoch))
        pbar.refresh()
        pbar.n = pbar.last_print_n = 0

    # Compute metrics on the train data (registered below).
    def log_training_results(engine):
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    # Compute metrics on the val data (registered below).
    def log_validation_results(engine):
        pbar.clear()
        print('* - * - * - * - * - * - * - * - * - * - * - * - *')
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)

    # Evaluate on both sets when training starts and every 50 epochs.
    cpe = CustomPeriodicEvent(n_epochs=50)
    cpe.attach(trainer)
    trainer.add_event_handler(cpe.Events.EPOCHS_50_COMPLETED,
                              log_training_results)
    trainer.add_event_handler(cpe.Events.EPOCHS_50_COMPLETED,
                              log_validation_results)
    trainer.add_event_handler(Events.STARTED, log_training_results)
    trainer.add_event_handler(Events.STARTED, log_validation_results)

    # Save model every N epochs.
    save_model_handler = ModelCheckpoint(os.environ['savedir'],
                                         '',
                                         save_interval=10,
                                         n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler,
                              {'model': model})

    # Update learning-rate due to scheduler.
    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()