def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
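    # Inverse-frequency class weights, normalized to sum to 1, to counter class imbalance.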
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
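    # Cosine-anneal the LR over a fixed 200-epoch horizon (T_max below is hard-coded, independent of `epochs`).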
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics
    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(nn.CrossEntropyLoss(weight=weights)),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

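    # Collapse the model's known-class scores into the validation label space
    # (known classes + UNKNOWN): columns mapped to the same target are summed,
    # except those mapped to UNKNOWN, where the running maximum is kept.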
    def val_pred_transform(output):
        y_pred, y = output
        new_y_pred = torch.zeros(
            (y_pred.shape[0],
             len(INFO['dataset-info']['known-classes']) + 1)).to(device=device)
        for c in range(y_pred.shape[1]):
            if c == 0:
                new_y_pred[:, mapping[c]] += y_pred[:, c]
            elif mapping[c] == val_loader.dataset.class_to_idx['UNKNOWN']:
                new_y_pred[:, mapping[c]] = torch.where(
                    new_y_pred[:, mapping[c]] > y_pred[:, c],
                    new_y_pred[:, mapping[c]], y_pred[:, c])
            else:
                new_y_pred[:, mapping[c]] += y_pred[:, c]
        return new_y_pred, y

    val_metrics = {
        'accuracy':
        Accuracy(val_pred_transform),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform),
                      Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(
            CMatrixTable,
            ConfusionMatrix(len(INFO['dataset-info']['known-classes']) + 1,
                            output_transform=val_pred_transform),
            val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        nn.CrossEntropyLoss(weight=weights),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        it = (engine.state.iteration - 1) % len(train_loader) + 1
        if it % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      <Training> Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      <Validating> Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
Example #2
def run(tb, vb, lr, epochs, writer):
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader = get_dataloaders(tb, vb)
  
  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  
  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  
  # ------------------------------------
  # 4. Define metrics
  metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(nn.functional.cross_entropy),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
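    # NOTE: the class count below is hard-coded to 7; the other examples derive
    # it from INFO['dataset-info']['num-of-classes'].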
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(7), train_loader.dataset.classes)
  }
  
  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, nn.functional.cross_entropy, device=device)
  
  # ------------------------------------
  # 6. Create evaluator
  evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

  desc = 'ITERATION - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0)
  )


  # ------------------------------------
  # 7. Create event hooks
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 5
      it = (engine.state.iteration - 1) % len(train_loader) + 1
      if it % log_interval == 0:
      pbar.desc = desc.format(engine.state.output)
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_training_results(engine):
    pbar.refresh()
    print('Checking on training set.')
    evaluator.run(train4val_loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = GetTemplate('default-log').format('Training',engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)
    # writer.add_scalars('Aggregate/Score', {'Train avg precision': precision_recall['data'][0, -1], 'Train avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)
    # pbar.n = pbar.last_print_n = 0
  
  @trainer.on(Events.EPOCH_COMPLETED)
  def log_validation_results(engine):
    print('Checking on validation set.')
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = GetTemplate('default-log').format('Validating',engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Val Loss': avg_loss}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg precision': precision_recall['data'][0, -1], 'Val avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)
    pbar.n = pbar.last_print_n = 0

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
Example #3
def run(tb, vb, lr, epochs, writer):
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  # weights = (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
  weights = weights.to(device=device)
  
  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b4', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  
  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
  ignite_scheduler = LRScheduler(scheduler)
  
  # ------------------------------------
  # 4. Define metrics

  train_metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(nn.CrossEntropyLoss(weight=weights)),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }
  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(weight=weights), device=device)
  
  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)

  desc = 'ITERATION - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0)
  )

  # ------------------------------------
  # 7. Create event hooks

  # Update process bar on each iteration completed.
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
      it = (engine.state.iteration - 1) % len(train_loader) + 1
      if it % log_interval == 0:
      pbar.desc = desc.format(engine.state.output)
      pbar.update(log_interval)

  # Refresh Process bar.
  @trainer.on(Events.EPOCH_COMPLETED)
  def refresh_pbar(engine):
    print('Epoch {} completed!'.format(engine.state.epoch))
    pbar.refresh()
    pbar.n = pbar.last_print_n = 0

  # Compute metrics on train data on each epoch completed.
  # cpe = CustomPeriodicEvent(n_epochs=50)
  # cpe.attach(trainer)
  # @trainer.on(cpe.Events.EPOCHS_50_COMPLETED)
  def log_training_results(engine):
    pbar.refresh()
    print('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Id: {}
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(os.environ['run-id'],engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)
    # pbar.n = pbar.last_print_n = 0

  # Save model ever N epoch.
  save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=10, n_saved=2)
  trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})

  cpe = CustomPeriodicEvent(n_epochs=200)
  cpe.attach(trainer)
  trainer.add_event_handler(cpe.Events.EPOCHS_200_COMPLETED, log_training_results)

  # Update learning-rate due to scheduler.
  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
Example #4
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
    # train_loader, train4val_loader, val_loader, num_of_images = get_dataloaders(tb, vb)
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    # weights = (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics

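    # One-vs-rest sigmoid loss, presumably after DOC ("Deep Open Classification
    # of Text Documents", Shu et al. 2017): the target class's probability is
    # pushed up and every other class's down, so low scorers can be rejected as
    # UNKNOWN later.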
    class DOCLoss(nn.Module):
        def __init__(self, weight):
            super(DOCLoss, self).__init__()
            self.class_weights = weight

        def forward(self, input, target):
            sigmoid = 1 - 1 / (1 + torch.exp(-input))
            sigmoid[range(0, sigmoid.shape[0]),
                    target] = 1 - sigmoid[range(0, sigmoid.shape[0]), target]
            sigmoid = torch.log(sigmoid)
            if self.class_weights is not None:
                loss = -torch.sum(sigmoid * self.class_weights)
            else:
                loss = -torch.sum(sigmoid)
            return loss

    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(DOCLoss(weight=weights)),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

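    # Remap known-class scores onto the validation label space and pin the
    # UNKNOWN column to -inf, so argmax never selects UNKNOWN and unknown
    # samples are necessarily counted as errors here.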
    def val_pred_transform(output):
        y_pred, y = output
        new_y_pred = torch.zeros(
            (y_pred.shape[0],
             INFO['dataset-info']['num-of-classes'] + 1)).to(device=device)
        for ind, c in enumerate(train_loader.dataset.classes):
            new_col = val_loader.dataset.class_to_idx[c]
            new_y_pred[:, new_col] += y_pred[:, ind]
        ukn_ind = val_loader.dataset.class_to_idx['UNKNOWN']
        new_y_pred[:, ukn_ind] = float('-inf')
        return new_y_pred, y

    val_metrics = {
        'accuracy':
        Accuracy(val_pred_transform),  # remap to the known+UNKNOWN label space, matching the other val metrics
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(val_pred_transform),
                      Recall(val_pred_transform), val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(
            CMatrixTable,
            ConfusionMatrix(INFO['dataset-info']['num-of-classes'] + 1,
                            output_transform=val_pred_transform),
            val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        DOCLoss(weight=weights),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks

    # Update process bar on each iteration completed.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        it = (engine.state.iteration - 1) % len(train_loader) + 1
        if it % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    # Compute metrics on train data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    # Compute metrics on val data on each epoch completed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)
        pbar.n = pbar.last_print_n = 0

    # Save model ever N epoch.
    save_model_handler = ModelCheckpoint(os.environ['savedir'],
                                         '',
                                         save_interval=50,
                                         n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler,
                              {'model': model})

    # Update learning-rate due to scheduler.
    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
Example #5
def run(tb, vb, lr, epochs, writer):
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)
  # train_loader, train4val_loader, val_loader, num_of_images = get_dataloaders(tb, vb)
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  # weights = (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
  weights = weights.to(device=device)
  
  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  
  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  ignite_scheduler = LRScheduler(scheduler)
  
  # ------------------------------------
  # 4. Define metrics

  # class SoftCrossEntropyLoss(nn.Module):
  #   def __init__(self, weight=None):
  #     super(SoftCrossEntropyLoss, self).__init__()
  #     self.class_weights = weight
    
  #   def forward(self, input, target):
  #     softmax = torch.exp(input) / torch.exp(input).sum(1)[:, None]
  #     onehot_labels = to_onehot(target, input.shape[1])
  #     soft_labels = torch.zeros_like(onehot_labels)
  #     soft_labels = torch.where(onehot_labels.cpu() == 1, torch.tensor([0.9]), torch.tensor([0.1/(input.shape[1]-1)])).to(device=device)
  #     if self.class_weights is not None:
  #       # print(soft_labels.shape, softmax.shape)
  #       loss = -torch.sum(torch.log(softmax) * soft_labels * self.class_weights * input.shape[1])
  #     else:
  #       loss = -torch.sum(torch.log(softmax) * soft_labels)
  #     return loss

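  # Standard cross-entropy plus a per-sample average of the squared non-target
  # logits, pushing non-target scores toward zero.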
  class CrossEntropyLoss_forceavg(nn.Module):
    def __init__(self, weight=None):
      super(CrossEntropyLoss_forceavg, self).__init__()
      self.class_weights = weight.to(device=device) if weight is not None else None

    def forward(self, input, target):
      ce = nn.functional.cross_entropy(input, target, weight=self.class_weights)
      avg = (input ** 2).sum() - (input[range(input.shape[0]), target] ** 2).sum()
      # print(ce, avg/input.shape[0])
      return ce + avg/input.shape[0]

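  # Open-set prediction by entropy: keep the argmax class when the normalized
  # softmax entropy is below `threshold`, otherwise emit -1, which `mapping`
  # converts to the validation UNKNOWN index.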
  class EntropyPrediction(metric.Metric):
    def __init__(self, threshold=0.5):
      super(EntropyPrediction, self).__init__()
      self.threshold = threshold
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
    
    def reset(self):
      # self.threshold = 0.5
      self.prediction = torch.tensor([])
      self.y = torch.tensor([])
      super(EntropyPrediction, self).reset()
    
    def update(self, output):
      y_pred, y = output
      softmax = torch.softmax(y_pred, dim=1)
      entropy_base = math.log(y_pred.shape[1])
      entropy = (-softmax * torch.log(softmax)).sum(1)/entropy_base
      values, inds = softmax.max(1)
      prediction = torch.where(entropy<self.threshold, inds, torch.tensor([-1]).to(device=device))
      self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))
      # return self.prediction, self.y

    def compute(self):
      return self.prediction, self.y

  train_metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(CrossEntropyLoss_forceavg(weight=weights)),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  val_metrics = {
    'accuracy': MetricsLambda(Labels2Acc, EntropyPrediction(1.0)),
    'precision_recall': MetricsLambda(Labels2PrecisionRecall, EntropyPrediction(1.0), val_loader.dataset.classes),
    'cmatrix': MetricsLambda(Labels2CMatrix, EntropyPrediction(1.0), val_loader.dataset.classes)
  }
  
  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, CrossEntropyLoss_forceavg(weight=weights), device=device)
  
  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

  desc = 'ITERATION - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0)
  )


  # ------------------------------------
  # 7. Create event hooks

  # Update process bar on each iteration completed.
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    it = (engine.state.iteration - 1) % len(train_loader) + 1
    if it % log_interval == 0:
      pbar.desc = desc.format(engine.state.output)
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_STARTED)
  def refresh_pbar(engine):
    print('Starting epoch {}'.format(engine.state.epoch))
    pbar.refresh()
    pbar.n = pbar.last_print_n = 0

  # Compute metrics on train data on each epoch completed.
  cpe = CustomPeriodicEvent(n_epochs=50)
  cpe.attach(trainer)
  @trainer.on(cpe.Events.EPOCHS_50_COMPLETED)
  def log_training_results(engine):
    print('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)
  
  # Compute metrics on val data on each epoch completed.
  cpe = CustomPeriodicEvent(n_epochs=50)
  cpe.attach(trainer)
  @trainer.on(cpe.Events.EPOCHS_50_COMPLETED)
  def log_validation_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - * - *')
    print('Checking on validation set.')
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg precision': precision_recall['data'][0, -1], 'Val avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)

  # Save model ever N epoch.
  save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=10, n_saved=2)
  trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})

  # Update learning-rate due to scheduler.
  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
Example #6
def run(tb, vb, lr, epochs, writer):
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping, support_train_loader, support_val_loader = get_dataloaders(tb, vb)
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  weights = weights.to(device=device)

  # Build an iterable of paired batches: known-class batches zipped with the (cycled) support batches.
  it = iter(train_loader)
  sup_it = iter(support_train_loader)
  mixup_batches = zip(it, cycle(sup_it))

  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  support_model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['supportset-info']['num-of-classes'])
  support_model = carrier(support_model)
  
  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  ignite_scheduler = LRScheduler(scheduler)

  support_optimizer = optim.SGD(support_model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(support_optimizer, T_max=200)
  support_ignite_scheduler = LRScheduler(scheduler)
  
  # ------------------------------------
  # 4. Define metrics
  train_metrics = {
    'accuracy': Accuracy(),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  support_metrics = {
    'accuracy': Accuracy(),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), support_val_loader.dataset.classes)
  }

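  # Per-class confidence thresholding: accept the argmax class only when its
  # softmax score exceeds that class's threshold, otherwise predict -1
  # (UNKNOWN after remapping).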
  class DeepTransPrediction(metric.Metric):
    def __init__(self, threshold=torch.tensor([0.5]).repeat(len(train_loader.dataset.classes))):
      super(DeepTransPrediction, self).__init__()
      threshold = threshold.to(device=device)
      self.threshold = threshold
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
    
    def reset(self):
      self.threshold = torch.tensor([0.5]).repeat(len(train_loader.dataset.classes)).to(device=device)
      self.prediction = torch.tensor([])
      self.y = torch.tensor([])
      super(DeepTransPrediction, self).reset()
    
    def update(self, output):
      y_pred, y = output
      softmax = torch.softmax(y_pred, dim=1)
      values, inds = softmax.max(1)
      prediction = torch.where(values>self.threshold[inds], inds, torch.tensor([-1]).to(device=device))
      self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))
      # return self.prediction, self.y

    def compute(self):
      return self.prediction, self.y

  val_metrics = {
    'accuracy': MetricsLambda(Labels2Acc, DeepTransPrediction()),
    'precision_recall': MetricsLambda(Labels2PrecisionRecall, DeepTransPrediction(), val_loader.dataset.classes),
    'cmatrix': MetricsLambda(Labels2CMatrix, DeepTransPrediction(), val_loader.dataset.classes)
  }
  
  # ------------------------------------
  # 5. Create trainer
  # trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(weight=weights), device=device)

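  # Membership loss: squared error pushing the target-class sigmoid toward 1
  # (part1) and, scaled by _lambda/(classes-1), all non-target sigmoids toward
  # 0 (part2).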
  def membership_loss(input, target, weights):
    _lambda = 5
    classes = input.shape[1]
    sigmoid = 1 / (1 + torch.exp(-input))
    part1 = 1-sigmoid[range(sigmoid.shape[0]), target]
    part1 = (part1 * part1 * weights[target]).sum()
    sigmoid[range(sigmoid.shape[0]), target] = 0  # zero only the target entries so part2 penalizes the rest
    part2 = (sigmoid * sigmoid * weights).sum()
    return part1 + _lambda*float(1/(classes-1))*part2

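  # Custom training step: pairs one known-class batch with one (cycled)
  # support batch and updates both models with a joint loss of support
  # cross-entropy, weighted known cross-entropy, and membership loss.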
  def step(engine, batch):
    model.train()
    support_model.train()

    _alpha1 = 1
    _alpha2 = 1

    known, support = batch
    x_known, y_known = known
    x_support, y_support = support

    x_known = x_known.to(device=device)
    y_known = y_known.to(device=device)
    x_support = x_support.to(device=device)
    y_support = y_support.to(device=device)

    support_scores = support_model(x_support)
    support_cross_entropy = nn.functional.cross_entropy(support_scores, y_support)

    known_scores = model(x_known)
    known_cross_entropy = nn.functional.cross_entropy(known_scores, y_known, weights)
    known_membership = membership_loss(known_scores, y_known, weights)

    loss = support_cross_entropy + known_cross_entropy * _alpha1 + known_membership * _alpha2

    model.zero_grad()
    support_model.zero_grad()

    loss.backward()

    optimizer.step()
    support_optimizer.step()

    return {
      'Rce_loss': support_cross_entropy.item(),
      'Tce_loss': known_cross_entropy.item(),
      'Tm_loss': known_membership.item(),
      'total_loss': loss.item()
    }

  trainer = Engine(step)

  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
  support_evaluator = create_supervised_evaluator(support_model, metrics=support_metrics, device=device)

  desc = 'ITERATION - loss: {:.2f}|{:.2f}|{:.2f}|{:.2f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0,0,0,0)
  )

  # ------------------------------------
  # 7. Create event hooks

  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    it = (engine.state.iteration - 1) % len(train_loader) + 1
    if it % log_interval == 0:
      o = engine.state.output
      pbar.desc = desc.format(o['Rce_loss'], o['Tce_loss'], o['Tm_loss'], o['total_loss'])
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_STARTED)
  def rebuild_dataloader(engine):
    pbar.clear()
    print('Rebuild dataloader!')
    it = iter(train_loader)
    sup_it = iter(support_train_loader)
    engine.state.dataloader = zip(it, cycle(sup_it))

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_training_results(engine):
    print('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Id: {}
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      
      precision_recall: \n{}
      
      confusion matrix: \n{}
      """.format(os.environ['run-id'], engine.state.epoch,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_support_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - *')
    print('Checking on support set.')
    support_evaluator.run(support_val_loader)
    metrics = support_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']

    prompt = """
    Id: {}
    Support set Results - Epoch: {}
    Avg accuracy: {:.4f}
    precision_recall: \n{}
    """.format(os.environ['run-id'], engine.state.epoch, avg_accuracy, precision_recall['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Support/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Support/Score', {'Avg precision': precision_recall['data'][0, -1], 'Avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)

  @trainer.on(Events.EPOCH_COMPLETED)
  def log_validation_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - *')
    print('Checking on validation set.')
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    unknown = precision_recall['pretty']['UNKNOWN']
    print(unknown)
    prompt = """
      Id: {}
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      Unknown precision: {:.4f}
      Unknown recall: {:.4f}
      
      precision_recall: \n{}
      
      confusion matrix: \n{}
      """.format(os.environ['run-id'], engine.state.epoch,avg_accuracy, unknown['Precision'], unknown['Recall'],precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg Precision': precision_recall['data'][0, -1], 'Val avg Recall': precision_recall['data'][1, -1]}, engine.state.epoch)
    writer.add_scalars('Unknown/Score', {'Unknown Precision': unknown['Precision'], 'Unknown Recall': unknown['Recall']}, engine.state.epoch)
    pbar.n = pbar.last_print_n = 0

  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)
  trainer.add_event_handler(Events.EPOCH_STARTED, support_ignite_scheduler)

  # ------------------------------------
  # Run
  trainer.run(mixup_batches, max_epochs=epochs)
  pbar.close()
Example #7
def evaluate(tb, vb, modelpath):
    device = os.environ['main-device']
    logging.info('Evaluating program start!')
    threshold = 0.6

    # Get dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping, imgs = get_dataloaders(
        tb, vb)

    # Get Model
    model_paths = glob.glob(modelpath + '/*')
    models = []
    for modelpath in model_paths:
        model = EfficientNet.from_pretrained(
            'efficientnet-b0',
            num_classes=INFO['dataset-info']['num-of-classes'])
        model = carrier(model)
        model.load_state_dict(torch.load(modelpath, map_location=device))
        models.append(model)
        # model = torch.load(modelpath, map_location=device)['model']

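    # Accumulates per-sample softmax, normalized entropy, argmax index and
    # ground truth over the full validation run, for offline thresholding.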
    class entropy(metric.Metric):
        def __init__(self):
            super(entropy, self).__init__()
            # self.values = torch.tensor([], dtype=torch.float)
            self.entropy_rate = torch.tensor([], dtype=torch.float)
            self.inds = torch.tensor([], dtype=torch.int)
            self.y = torch.tensor([], dtype=torch.int)
            self.softmax = torch.tensor([], dtype=torch.float)

        def reset(self):
            # self.values = torch.tensor([])
            self.entropy_rate = torch.tensor([], dtype=torch.float)
            self.inds = torch.tensor([], dtype=torch.int)
            self.y = torch.tensor([], dtype=torch.int)
            self.softmax = torch.tensor([], dtype=torch.float)
            super(entropy, self).reset()

        def update(self, output):
            y_pred, y = output
            softmax = torch.softmax(y_pred, dim=1)
            # print(softmax)
            entropy_base = math.log(y_pred.shape[1])
            entropy_rate = (-softmax *
                            torch.log(softmax)).sum(1) / entropy_base
            _, inds = softmax.max(1)
            # prediction = torch.where(entropy>self.threshold, inds, torch.tensor([-1]).to(device=device))
            # self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
            self.softmax = torch.cat(
                (self.softmax.to(device=device), softmax)).to(device=device)
            self.entropy_rate = torch.cat((self.entropy_rate.to(device=device),
                                           entropy_rate)).to(device=device)
            self.y = torch.cat(
                (self.y.type(torch.LongTensor).to(device=device),
                 y.to(device=device)))
            self.inds = torch.cat(
                (self.inds.type(torch.LongTensor).to(device=device),
                 inds.to(device=device)))

        def compute(self):
            return self.softmax, self.entropy_rate, self.inds, self.y

    val_metrics = {'result': entropy()}

    metric_list = []
    for k, model in enumerate(models):
        val_evaluator = create_supervised_evaluator(model,
                                                    metrics=val_metrics,
                                                    device=device)
        val_evaluator.run(val_loader)
        softmax, er, inds, y_true = val_evaluator.state.metrics['result']
        m = {
            'model': model_paths[k],
            'softmax': softmax,
            'er': er,
            'inds': inds,
            'y': y_true
        }
        metric_list.append(m)
        print('Finished evaluating {}!'.format(model_paths[k]))

    def log_validation_results(threshold, metric):
        name = metric['model']
        entropy_rate = metric['er']
        inds = metric['inds']
        y = metric['y']
        # print(entropy)
        # print(threshold)
        # print(inds)
        prediction = torch.where(entropy_rate < threshold, inds,
                                 torch.tensor([-1]).to(device=device))
        prediction = torch.tensor([mapping[x.item()]
                                   for x in prediction]).to(device=device)

        avg_accuracy = Labels2Acc((prediction, y))
        precision_recall = Labels2PrecisionRecall((prediction, y),
                                                  val_loader.dataset.classes)
        cmatrix = Labels2CMatrix((prediction, y), val_loader.dataset.classes)
        prompt = """
      Model: {}
      Threshold: {}

      Avg accuracy: {:.4f}

      precision_recall: \n{}

      confusion matrix: \n{}
      """.format(name, threshold, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        return {'mean_recall': precision_recall['pretty']['mean']['Recall']}

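    # Ensemble by averaging the softmax outputs of all evaluated models.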
    def get_mean_softmax(metric_list):
        mean_softmax = None
        for metrics in metric_list:
            softmax = metrics['softmax']
            if mean_softmax is not None:
                mean_softmax = mean_softmax + softmax
            else:
                mean_softmax = softmax
        return mean_softmax / len(metric_list)

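    # Threshold the ensembled softmax by normalized entropy; also collect very
    # confident samples (entropy < 1e-3) and rejected samples (entropy above
    # the threshold) together with their source images, for transduction below.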
    def log_mean_results(threshold, softmax, y_true):
        entropy_base = math.log(softmax.shape[1])
        entropy_rate = (-softmax * torch.log(softmax)).sum(1) / entropy_base
        # print(entropy_rate)
        _, inds = softmax.max(1)
        prediction = torch.where(entropy_rate < threshold, inds,
                                 torch.tensor([-1]).to(device=device))
        prediction = torch.tensor([mapping[x.item()]
                                   for x in prediction]).to(device=device)

        high_confidence_inds = (entropy_rate < 1e-3).nonzero()
        low_confidence_inds = (entropy_rate > threshold).nonzero()
        high_confidence = np.array([{
            'from': int(imgs[x][1]),
            'to': inds[x].item(),
            'img': imgs[x][0],
            'er': entropy_rate[x].item()
        } for x in high_confidence_inds])
        low_confidence = np.array([{
            'from': int(imgs[x][1]),
            'to': -1,
            'img': imgs[x][0],
            'er': entropy_rate[x].item()
        } for x in low_confidence_inds])

        avg_accuracy = Labels2Acc((prediction, y_true))
        precision_recall = Labels2PrecisionRecall((prediction, y_true),
                                                  val_loader.dataset.classes)
        cmatrix = Labels2CMatrix((prediction, y_true),
                                 val_loader.dataset.classes)

        prompt = """
      Threshold: \n{}

      Avg accuracy: {:.4f}

      precision_recall: \n{}

      confusion matrix: \n{}
      """.format(threshold, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        print(prompt)
        return high_confidence, low_confidence

    scores = {}

    # test1 = log_validation_results(1.0)

    for metrics in metric_list:
        score = log_validation_results(threshold, metrics)

    high, low = log_mean_results(threshold, get_mean_softmax(metric_list),
                                 metric_list[0]['y'])

    # print(high)
    print('High confidence known: {} (correct: {})'.format(
        len(high), sum([x['from'] == x['to'] for x in high])))
    print('Low confidence known: {} (correct: {})'.format(
        len(low), sum([x['from'] == mapping[x['to']] for x in low])))

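    # Transduction: copy a random sample of the selected images into new
    # training directories under their predicted class (UNKNOWN for rejected
    # samples), producing `datasets` pseudo-labelled training sets.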
    def transduct(datasets, img_pack, rate=0.8):
        import shutil  # local import, matching the inline-import style used elsewhere in this file
        for dset_ind in range(datasets):
            class_to_idx = val_loader.dataset.class_to_idx
            classes = val_loader.dataset.classes
            idx_to_classes = {}
            for c in classes:
                idx_to_classes[class_to_idx[c]] = c

            train_base = '{}/{}/Train'.format(os.environ['datadir-base'],
                                              INFO['dataset'])
            # source_base = '{}/{}/Val'.format(os.environ['datadir-base'], INFO['dataset'])
            dist_base = '{}/{}-transduct{}'.format(os.environ['datadir-base'],
                                                   INFO['dataset'], dset_ind)
            if not os.path.exists(dist_base):
                os.mkdir(dist_base)
            dist_base = '{}/Train'.format(dist_base)
            if not os.path.exists(dist_base):
                os.mkdir(dist_base)

            for c in classes:
                if not os.path.exists('{}/{}'.format(dist_base, c)):
                    os.mkdir('{}/{}'.format(dist_base, c))

            to_move = np.random.choice(img_pack, int(rate * img_pack.shape[0]))
            for img in to_move:
                source = img['img']
                img_name = source[source.rfind('/') + 1:]
                class_name = idx_to_classes[
                    img['to']] if img['to'] != -1 else 'UNKNOWN'
                dist = '{}/{}/{}'.format(dist_base, class_name, img_name)
                print('Copy "{}" to "{}"!'.format(source, dist))
                shutil.copy(source, dist)

            # train_imgs = glob.glob('{}/**/*'.format(train_base))
            # for img in train_imgs:
            #   source = img
            #   dist = img.replace(train_base, dist_base)
            #   print('Move "{}" to "{}"!'.format(source, dist))
            #   os.popen('cp "{}" "{}"'.format(source, dist))

    transduct(2, high, 0.9)
    transduct(2, low, 0.6)
Example #8
def evaluate(tb, vb, modelpath):
  device = os.environ['main-device']
  logging.info('Evaluating program start!')
  
  train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(tb, vb)

  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  model.load_state_dict(torch.load(modelpath, map_location=device))

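  # Collects every sample's maximum sigmoid score, its argmax class and the
  # ground truth, so per-class thresholds can be swept offline.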
  class sigmoidmax(metric.Metric):
    def __init__(self):
      super(sigmoidmax, self).__init__()
      self.values = torch.tensor([], dtype=torch.float)
      self.inds = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
    
    def reset(self):
      self.values = torch.tensor([])
      self.inds = torch.tensor([])
      self.y = torch.tensor([])
      super(sigmoidmax, self).reset()
    
    def update(self, output):
      y_pred, y = output
      sigmoid = 1 / (1 + torch.exp(-y_pred))
      values, inds = sigmoid.max(1)
      # threshold = self.threshold[inds]
      # prediction = torch.where(values>threshold, inds, torch.tensor([-1]).to(device=device))
      # self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
      self.values = torch.cat((self.values.to(device=device), values)).to(device=device)
      self.inds = torch.cat((self.inds.type(torch.LongTensor).to(device=device), inds)).to(device=device)
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))

    def compute(self):
      return self.values, self.inds, self.y

  val_metrics = {
    'result': sigmoidmax()
  }

  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
  val_evaluator.run(val_loader)
  metrics = val_evaluator.state.metrics
  values, inds, y = metrics['result']
  
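  # Per-class sigmoid thresholds: accept the argmax class only when its score
  # exceeds that class's threshold, otherwise predict UNKNOWN (-1 remapped).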
  def log_validation_results(threshold, checking_ind):

    t = threshold[inds]

    prediction = torch.where(values>t, inds, torch.tensor([-1]).to(device=device))
    prediction = torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)

    avg_accuracy = Labels2Acc((prediction, y))
    precision_recall = Labels2PrecisionRecall((prediction, y), val_loader.dataset.classes)
    cmatrix = Labels2CMatrix((prediction, y), val_loader.dataset.classes)
    unknown = precision_recall['pretty']['UNKNOWN']
    unknown_f1 = 2/((1/unknown['Precision'])+(1/unknown['Recall']))
    prompt = """
      Threshold: \n{}

      Avg accuracy: {:.4f}

      Unknown precision: {:.4f}
      Unknown recall: {:.4f}
      Unknown F1 score: {:.4f}
      
      precision_recall: \n{}

      confusion matrix: \n{}
      """.format(threshold.cpu().numpy(),avg_accuracy,unknown['Precision'],unknown['Recall'],unknown_f1,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    return {
      'unknown_precision': unknown['Precision'],
      'unknown_recall': unknown['Recall'],
      'unknown_f1': unknown_f1
    }

  scores = {}

  key = 0
  threshold = torch.tensor([0.13, 0.13, 0.69, 0.19, 0.19, 0.27]).to(device=device)
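  # Sweep threshold[key] over (0, 1) while the other per-class thresholds stay
  # fixed, recording UNKNOWN precision/recall/F1 at each step for the plot below.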

  for i in np.arange(0.01, 1.0, 0.01):
    # threshold = torch.tensor([0.5]).repeat(len(train_loader.dataset.classes)).to(device=device)
    threshold[key] = i
    score = log_validation_results(threshold, key)
    scores[i] = score
    print('Finished threshold {:.2f}!'.format(i))

  import matplotlib.pyplot as plt
  x = list(scores.keys())
  precision = [scores[i]['unknown_precision'] for i in scores]
  recall = [scores[i]['unknown_recall'] for i in scores]
  f1 = [scores[i]['unknown_f1'] for i in scores]

  plt.plot(x, precision, color='red', label='precision')
  plt.plot(x, recall, color='green', label='recall')
  plt.plot(x, f1, color='blue', label='f1')

  plt.xlabel('Threshold[{}]'.format(key))
  plt.grid(linestyle='-.')
  plt.legend()
  
  plt.show()
Example #9
def evaluate(tb, vb, modelpath):
  device = os.environ['main-device']
  logging.info('Evaluating program start!')
  threshold = np.arange(0.5, 1.0, 0.02)
  iterations = 1
  dist = modelpath+'/dist'
  if not os.path.exists(dist):
    os.mkdir(dist)
  savepath = '{}/{}.csv'.format(dist, 'test')
  # rates = [0.7, 0.3]
  
  # Get dataloader
  

  # Get Model
  b0_model_paths = glob.glob(modelpath+'/b0/*')
  b1_model_paths = glob.glob(modelpath+'/b1/*')
  b5_model_paths = glob.glob(modelpath+'/b5/*')
  models = []
  model_weights = []
  for modelpath in b5_model_paths:
    # model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    # model = carrier(model)
    # model.load_state_dict(torch.load(modelpath, map_location=device))
    model = EfficientNet.from_pretrained('efficientnet-b5', num_classes=8)
    model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu'))['model'].module.state_dict())
    model = model.to(device=device)
    models.append(model)
    # model_weights.append(rates[0]/len(b3_model_paths))
  
  for modelpath in b1_model_paths:
    model = EfficientNet.from_pretrained('efficientnet-b1', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)
    model.load_state_dict(torch.load(modelpath, map_location=device))
    # model = torch.load(modelpath, map_location=device)['model']
    models.append(model)
    # model_weights.append(rates[1]/len(b0_model_paths))
  
  for modelpath in b0_model_paths:
    model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)
    model.load_state_dict(torch.load(modelpath, map_location=device))
    # model = torch.load(modelpath, map_location=device)['model']
    models.append(model)
    # model_weights.append(rates[1]/len(b0_model_paths))

  model_paths = b5_model_paths
  model_paths.extend(b1_model_paths)
  model_paths.extend(b0_model_paths)

  class entropy(metric.Metric):
    def __init__(self):
      super(entropy, self).__init__()
      # self.values = torch.tensor([], dtype=torch.float)
      self.entropy_rate = torch.tensor([], dtype=torch.float)
      self.inds = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
      self.softmax = torch.tensor([], dtype=torch.float)
    
    def reset(self):
      # self.values = torch.tensor([])
      self.entropy_rate = torch.tensor([], dtype=torch.float)
      self.inds = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
      self.softmax = torch.tensor([], dtype=torch.float)
      super(entropy, self).reset()
    
    def update(self, output):
      y_pred, y = output
      softmax = torch.softmax(y_pred, dim=1)
      # print(softmax)
      entropy_base = math.log(y_pred.shape[1])
      entropy_rate = (-softmax * torch.log(softmax)).sum(1)/entropy_base
      _, inds = softmax.max(1)
      # prediction = torch.where(entropy>self.threshold, inds, torch.tensor([-1]).to(device=device))
      # self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
      self.softmax = torch.cat((self.softmax.to(device=device), softmax)).to(device=device)
      self.entropy_rate = torch.cat((self.entropy_rate.to(device=device), entropy_rate)).to(device=device)
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))
      self.inds = torch.cat((self.inds.type(torch.LongTensor).to(device=device), inds.to(device=device)))

    def compute(self):
      return self.softmax, self.entropy_rate, self.inds, self.y
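
  # Intuition for the normalized entropy above (a sketch, not executed): with C = 3 classes,
  #   p = (1/3, 1/3, 1/3) -> H(p)/log(3) = 1.0   (maximally uncertain)
  #   p = (1, 0, 0)       -> H(p)/log(3) -> 0.0  (fully confident, in the limit)
  # so entropy_rate always lies in [0, 1] regardless of the number of classes.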

  val_metrics = {
    'result': Entropy()
  }

  metric_list = []
  train_loader, train4val_loader, val_loader, num_of_images, mapping, imgs = get_dataloaders(tb, vb)
  for i in range(iterations):
    # Resample the validation loader for each iteration.
    _, _, val_loader, _, _, _ = get_dataloaders(tb, vb)
    print('Iteration {}'.format(i))
    for ind, model in enumerate(models):
      val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
      val_evaluator.run(val_loader)
      metrics = val_evaluator.state.metrics['result']
      softmax, er, inds, y_true = metrics
      m = {
        'model': model_paths[ind],
        'softmax': softmax,
        'er': er,
        'inds': inds,
        'y': y_true
      }
      metric_list.append(m)
      print('\tFinished model {}!'.format(ind))

  def log_validation_results(threshold, result):
    name = result['model']
    entropy_rate = result['er']
    inds = result['inds']
    y = result['y']
    # Reject samples whose normalized entropy exceeds the threshold; -1 is then
    # mapped to the UNKNOWN class.
    prediction = torch.where(entropy_rate < threshold, inds, torch.tensor([-1]).to(device=device))
    prediction = torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)

    avg_accuracy = Labels2Acc((prediction, y))
    precision_recall = Labels2PrecisionRecall((prediction, y), val_loader.dataset.classes)
    cmatrix = Labels2CMatrix((prediction, y), val_loader.dataset.classes)
    prompt = """
      Model: {}
      Threshold: {}

      Avg accuracy: {:.4f}

      precision_recall: \n{}

      confusion matrix: \n{}
      """.format(name, threshold,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    return {
      'mean_recall': precision_recall['pretty']['mean']['Recall']
    }

  def get_mean_softmax(metric_list):
    mean_softmax = None
    for metrics in metric_list:
      softmax = metrics['softmax']
      if mean_softmax is None:
        mean_softmax = softmax
      else:
        mean_softmax = mean_softmax + softmax
    return mean_softmax / len(metric_list)

  def get_min_er_softmax(metric_list):
    # Per-sample ensemble: for each sample, keep the softmax of whichever model
    # is most confident (lowest normalized entropy).
    res_softmax = None
    res_er = None
    for metrics in metric_list:
      if res_softmax is None:
        res_softmax = metrics['softmax']
        res_er = metrics['er']
      else:
        mask = (metrics['er'] <= res_er).nonzero()[:, 0]
        res_softmax[mask] = metrics['softmax'][mask]
        res_er[mask] = metrics['er'][mask]
    return res_softmax
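
  # Example of the selection rule (a sketch, not executed): with two models and
  #   er_a = [0.2, 0.9], er_b = [0.5, 0.1]
  # sample 0 keeps model a's softmax row and sample 1 keeps model b's.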

  def save_softmax(softmax):
    np_softmax = softmax.cpu().numpy()
    np.savetxt(savepath, np_softmax, delimiter=",")

  def log_mean_results(threshold, softmax, y_true):
    entropy_base = math.log(softmax.shape[1])
    entropy_rate = (-softmax * torch.log(softmax)).sum(1) / entropy_base
    _, inds = softmax.max(1)
    prediction = torch.where(entropy_rate < threshold, inds, torch.tensor([-1]).to(device=device))
    prediction = torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)

    # Split indices into very confident predictions (entropy < 0.1) and
    # rejected ones (entropy above the threshold).
    high_confidence_inds = (entropy_rate < 1e-1).nonzero()[:, 0]
    low_confidence_inds = (entropy_rate > threshold).nonzero()[:, 0]
    high_confidence = np.array([{
      'from': int(imgs[x][1]),
      'to': inds[x].item(),
      'img': imgs[x][0],
      'er': entropy_rate[x].item()
    } for x in high_confidence_inds])
    low_confidence = np.array([{
      'from': int(imgs[x][1]),
      'to': -1,
      'img': imgs[x][0],
      'er': entropy_rate[x].item()
    } for x in low_confidence_inds])

    avg_accuracy = Labels2Acc((prediction, y_true))
    precision_recall = Labels2PrecisionRecall((prediction, y_true), val_loader.dataset.classes)
    cmatrix = Labels2CMatrix((prediction, y_true), val_loader.dataset.classes)

    prompt = """
      Threshold: {}

      Avg accuracy: {:.4f}

      precision_recall: \n{}

      confusion matrix: \n{}
      """.format(threshold,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    logging.info('\n'+prompt)
    print (prompt)
    return high_confidence, low_confidence

  # Per-model results at threshold 1.0 (i.e. nothing rejected).
  for result in metric_list:
    log_validation_results(1.0, result)

  mean_softmax = get_mean_softmax(metric_list)
  save_softmax(mean_softmax)
  for t in threshold:
    high, low = log_mean_results(t, mean_softmax, metric_list[0]['y'])
    # high, low = log_mean_results(t, get_min_er_softmax(metric_list), metric_list[0]['y'])

  def transduct(num_datasets, img_pack, rate=0.8):
    import shutil  # used below instead of shelling out to `cp`
    for dset_ind in range(num_datasets):
      class_to_idx = val_loader.dataset.class_to_idx
      classes = val_loader.dataset.classes
      idx_to_classes = {class_to_idx[c]: c for c in classes}

      train_base = '{}/{}/Train'.format(os.environ['datadir-base'], INFO['dataset'])
      # source_base = '{}/{}/Val'.format(os.environ['datadir-base'], INFO['dataset'])
      dist_base = '{}/{}-transduct{}'.format(os.environ['datadir-base'], INFO['dataset'], dset_ind)
      if not os.path.exists(dist_base):
        os.mkdir(dist_base)
      dist_base = '{}/Train'.format(dist_base)
      if not os.path.exists(dist_base):
        os.mkdir(dist_base)

      for c in classes:
        if not os.path.exists('{}/{}'.format(dist_base, c)):
          os.mkdir('{}/{}'.format(dist_base, c))

      # Sample without replacement so each image is copied at most once.
      to_move = np.random.choice(img_pack, int(rate * img_pack.shape[0]), replace=False)
      for img in to_move:
        source = img['img']
        img_name = source[source.rfind('/')+1:]
        class_name = idx_to_classes[img['to']] if img['to'] != -1 else 'UNKNOWN'
        dist = '{}/{}/{}'.format(dist_base, class_name, img_name)
        print('Copy "{}" to "{}"!'.format(source, dist))
        shutil.copy(source, dist)  # synchronous copy; os.popen('cp ...') returned before the copy was guaranteed to finish
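
  # Hypothetical usage (transduct is never called in this script): build one
  # transductive training set from the confident/rejected samples gathered above, e.g.
  #   transduct(1, np.concatenate((high, low)), rate=0.8)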
Example #10
def run(tb, vb, lr, epochs, writer):
  device = os.environ['main-device']
  logging.info('Training program start!')
  logging.info('Configuration:')
  logging.info('\n'+json.dumps(INFO, indent=2))

  # ------------------------------------
  # 1. Define dataloader
  train_loader, train4val_loader, val_loader, num_of_images, mapping, _ = get_dataloaders(tb, vb)
  weights = (1/num_of_images)/((1/num_of_images).sum().item())
  weights = weights.to(device=device)
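  # Inverse-frequency class weights, normalized to sum to 1; e.g. counts
  # (10, 30, 60) give weights of about (0.667, 0.222, 0.111), up-weighting rare classes.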
  
  # ------------------------------------
  # 2. Define model
  model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
  model = carrier(model)
  
  # ------------------------------------
  # 3. Define optimizer
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
  ignite_scheduler = LRScheduler(scheduler)
  
  # ------------------------------------
  # 4. Define metrics

  class EntropyPrediction(metric.Metric):
    def __init__(self, threshold=1.0):
      super(EntropyPrediction, self).__init__()
      self.threshold = threshold
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
    
    def reset(self):
      self.prediction = torch.tensor([], dtype=torch.int)
      self.y = torch.tensor([], dtype=torch.int)
      super(EntropyPrediction, self).reset()
    
    def update(self, output):
      y_pred, y = output
      # torch.softmax is the numerically stable equivalent of exp(x)/exp(x).sum().
      softmax = torch.softmax(y_pred, dim=1)
      entropy_base = math.log(y_pred.shape[1])
      entropy = (-softmax * torch.log(softmax)).sum(1) / entropy_base
      _, inds = softmax.max(1)
      prediction = torch.where(entropy < self.threshold, inds, torch.tensor([-1]).to(device=device))
      self.prediction = torch.cat((self.prediction.type(torch.LongTensor).to(device=device), torch.tensor([mapping[x.item()] for x in prediction]).to(device=device)))
      self.y = torch.cat((self.y.type(torch.LongTensor).to(device=device), y.to(device=device)))

    def compute(self):
      return self.prediction, self.y
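
  # EntropyPrediction turns the closed-set classifier into a simple open-set one:
  # samples whose normalized entropy exceeds the threshold are predicted as -1,
  # which `mapping` presumably folds into the dataloader's UNKNOWN class index.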

  train_metrics = {
    'accuracy': Accuracy(),
    'loss': Loss(CrossEntropywithLS(weight=weights)),
    'precision_recall': MetricsLambda(PrecisionRecallTable, Precision(), Recall(), train_loader.dataset.classes),
    'cmatrix': MetricsLambda(CMatrixTable, ConfusionMatrix(INFO['dataset-info']['num-of-classes']), train_loader.dataset.classes)
  }

  val_metrics = {
    'accuracy': MetricsLambda(Labels2Acc, EntropyPrediction()),
    'precision_recall': MetricsLambda(Labels2PrecisionRecall, EntropyPrediction(), val_loader.dataset.classes),
    'cmatrix': MetricsLambda(Labels2CMatrix, EntropyPrediction(), val_loader.dataset.classes)
  }
  
  # ------------------------------------
  # 5. Create trainer
  trainer = create_supervised_trainer(model, optimizer, CrossEntropywithLS(weight=weights), device=device)
  
  # ------------------------------------
  # 6. Create evaluator
  train_evaluator = create_supervised_evaluator(model, metrics=train_metrics, device=device)
  val_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

  desc = 'Epoch {} - loss: {:.4f}'
  pbar = tqdm(
    initial=0, leave=False, total=len(train_loader),
    desc=desc.format(0, 0)
  )


  # ------------------------------------
  # 7. Create event hooks

  # Basic events on showing training procedure.
  @trainer.on(Events.ITERATION_COMPLETED)
  def log_training_loss(engine):
    log_interval = 1
    it = (engine.state.iteration - 1) % len(train_loader) + 1
    if it % log_interval == 0:
      pbar.desc = desc.format(engine.state.epoch, engine.state.output)
      pbar.update(log_interval)

  @trainer.on(Events.EPOCH_COMPLETED)
  def refresh_pbar(engine):
    pbar.refresh()
    print('Finish epoch {}!'.format(engine.state.epoch))
    pbar.n = pbar.last_print_n = 0

  # Trigger: Compute metrics on training data.
  def log_training_results(engine):
    print('Checking on training set.')
    train_evaluator.run(train4val_loader)
    metrics = train_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_loss = metrics['loss']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,avg_loss,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss}, engine.state.epoch)
  
  # Trigger: Compute metrics on validating data.
  def log_validation_results(engine):
    pbar.clear()
    print('* - * - * - * - * - * - * - * - * - * - * - * - *')
    print('Checking on validation set.')
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    precision_recall = metrics['precision_recall']
    cmatrix = metrics['cmatrix']
    prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch,avg_accuracy,precision_recall['pretty'],cmatrix['pretty'])
    tqdm.write(prompt)
    logging.info('\n'+prompt)
    writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
    writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy}, engine.state.epoch)
    writer.add_scalars('Aggregate/Score', {'Val avg precision': precision_recall['data'][0, -1], 'Val avg recall': precision_recall['data'][1, -1]}, engine.state.epoch)

  # ------------------------------------
  # Trainer triggers settings
  
  # Save the model every N epochs.
  save_model_handler = ModelCheckpoint(os.environ['savedir'], '', save_interval=10, n_saved=2)
  trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler, {'model': model})
  
  # Evaluate.
  evaluate_interval = epochs
  cpe = CustomPeriodicEvent(n_epochs=epochs)
  cpe.attach(trainer)
  on_evaluate_event = getattr(cpe.Events, 'EPOCHS_{}_COMPLETED'.format(evaluate_interval))
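  # With evaluate_interval == epochs this resolves to cpe.Events.EPOCHS_<epochs>_COMPLETED
  # (the naming scheme CustomPeriodicEvent uses for its generated events), so the
  # logging handlers below fire once, at the end of training.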
  trainer.add_event_handler(on_evaluate_event, log_training_results)
  trainer.add_event_handler(on_evaluate_event, log_validation_results)

  # Update learning rate.
  trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

  # ------------------------------------
  # Run
  trainer.run(train_loader, max_epochs=epochs)
  pbar.close()
def run(tb, vb, lr, epochs, writer):
    device = os.environ['main-device']
    logging.info('Training program start!')
    logging.info('Configuration:')
    logging.info('\n' + json.dumps(INFO, indent=2))

    # ------------------------------------
    # 1. Define dataloader
    train_loader, train4val_loader, val_loader, num_of_images, mapping = get_dataloaders(
        tb, vb)
    # train_loader, train4val_loader, val_loader, num_of_images = get_dataloaders(tb, vb)
    weights = (1 / num_of_images) / ((1 / num_of_images).sum().item())
    # weights = (1/num_of_images)/(1/num_of_images + 1/(num_of_images.sum().item()-num_of_images))
    weights = weights.to(device=device)

    # ------------------------------------
    # 2. Define model
    model = EfficientNet.from_pretrained(
        'efficientnet-b3', num_classes=INFO['dataset-info']['num-of-classes'])
    model = carrier(model)

    # ------------------------------------
    # 3. Define optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
    ignite_scheduler = LRScheduler(scheduler)

    # ------------------------------------
    # 4. Define metrics

    class CenterLoss(nn.Module):
        """Center loss.

        Reference:
            Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

        Args:
            num_classes (int): number of classes.
            feat_dim (int): feature dimension.
        """
        def __init__(self, num_classes=10, feat_dim=2):
            super(CenterLoss, self).__init__()
            self.num_classes = num_classes
            self.feat_dim = feat_dim
            self.centers = torch.randn(self.num_classes,
                                       self.feat_dim).to(device=device)

        def forward(self, x, labels):
            """
            Args:
                x: feature matrix with shape (batch_size, feat_dim).
                labels: ground truth labels with shape (batch_size,).
            """
            batch_size = x.size(0)
            # Squared Euclidean distances via the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c
            distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                      torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
            distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)  # keyword form; the positional (1, -2, ...) signature is deprecated
            classes = torch.arange(self.num_classes).long()
            classes = classes.to(device=device)
            labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
            mask = labels.eq(classes.expand(batch_size, self.num_classes))
            dist = distmat * mask.float()
            loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
            return loss
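
    # Shape sanity check for CenterLoss (a sketch, not executed):
    #   feats = torch.randn(4, 2).to(device); labels = torch.tensor([1, 0, 3, 1]).to(device)
    #   CenterLoss(num_classes=10, feat_dim=2)(feats, labels)  # -> scalar loss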

    class MixLoss(nn.Module):
        def __init__(self):
            super(MixLoss, self).__init__()

        def forward(self, x, y):
            # Recreating CenterLoss on every forward means its centers are
            # resampled per batch rather than learned; kept for parity with the
            # original, but num_classes must match the logit dimension (the
            # original passed the batch size here).
            _lambda = 0.5
            center = CenterLoss(num_classes=x.shape[1], feat_dim=x.shape[1])
            ce = nn.CrossEntropyLoss(weight=weights)
            # Combined objective: lambda * center loss + cross entropy.
            return _lambda * center(x, y) + ce(x, y)

    class EntropyPrediction(metric.Metric):
        def __init__(self, threshold=1.0):
            super(EntropyPrediction, self).__init__()
            self.threshold = threshold
            self.prediction = torch.tensor([], dtype=torch.int)
            self.y = torch.tensor([], dtype=torch.int)

        def reset(self):
            self.prediction = torch.tensor([], dtype=torch.int)
            self.y = torch.tensor([], dtype=torch.int)
            super(EntropyPrediction, self).reset()

        def update(self, output):
            y_pred, y = output
            # torch.softmax is the numerically stable equivalent of exp(x)/exp(x).sum().
            softmax = torch.softmax(y_pred, dim=1)
            entropy_base = math.log(y_pred.shape[1])
            entropy = (-softmax * torch.log(softmax)).sum(1) / entropy_base
            _, inds = softmax.max(1)
            prediction = torch.where(entropy < self.threshold, inds,
                                     torch.tensor([-1]).to(device=device))
            self.prediction = torch.cat(
                (self.prediction.type(torch.LongTensor).to(device=device),
                 torch.tensor([mapping[x.item()]
                               for x in prediction]).to(device=device)))
            self.y = torch.cat(
                (self.y.type(torch.LongTensor).to(device=device),
                 y.to(device=device)))

        def compute(self):
            return self.prediction, self.y

    train_metrics = {
        'accuracy':
        Accuracy(),
        'loss':
        Loss(MixLoss()),
        'precision_recall':
        MetricsLambda(PrecisionRecallTable, Precision(), Recall(),
                      train_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(CMatrixTable,
                      ConfusionMatrix(INFO['dataset-info']['num-of-classes']),
                      train_loader.dataset.classes)
    }

    val_metrics = {
        'accuracy':
        MetricsLambda(Labels2Acc, EntropyPrediction(1.0)),
        'precision_recall':
        MetricsLambda(Labels2PrecisionRecall, EntropyPrediction(1.0),
                      val_loader.dataset.classes),
        'cmatrix':
        MetricsLambda(Labels2CMatrix, EntropyPrediction(1.0),
                      val_loader.dataset.classes)
    }

    # ------------------------------------
    # 5. Create trainer
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        MixLoss(),
                                        device=device)

    # ------------------------------------
    # 6. Create evaluator
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=train_metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=val_metrics,
                                                device=device)

    desc = 'ITERATION - loss: {:.4f}'
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    # ------------------------------------
    # 7. Create event hooks

    # Update process bar on each iteration completed.
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_interval = 1
        it = (engine.state.iteration - 1) % len(train_loader) + 1
        if it % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

    @trainer.on(Events.EPOCH_STARTED)
    def refresh_pbar(engine):
        torch.cuda.empty_cache()
        print('Starting epoch {}'.format(engine.state.epoch))
        pbar.refresh()
        pbar.n = pbar.last_print_n = 0

    # Compute metrics on training data (wired to periodic events below).
    def log_training_results(engine):
        print('Checking on training set.')
        train_evaluator.run(train4val_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['loss']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Training Results - Epoch: {}
      Avg accuracy: {:.4f}
      Avg loss: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, avg_loss,
                 precision_recall['pretty'], cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Train Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars('Aggregate/Loss', {'Train Loss': avg_loss},
                           engine.state.epoch)

    # Compute metrics on validation data (wired to periodic events below).
    def log_validation_results(engine):
        pbar.clear()
        print('* - * - * - * - * - * - * - * - * - * - * - * - *')
        print('Checking on validation set.')
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        precision_recall = metrics['precision_recall']
        cmatrix = metrics['cmatrix']
        prompt = """
      Validating Results - Epoch: {}
      Avg accuracy: {:.4f}
      precision_recall: \n{}
      confusion matrix: \n{}
      """.format(engine.state.epoch, avg_accuracy, precision_recall['pretty'],
                 cmatrix['pretty'])
        tqdm.write(prompt)
        logging.info('\n' + prompt)
        writer.add_text(os.environ['run-id'], prompt, engine.state.epoch)
        writer.add_scalars('Aggregate/Acc', {'Val Acc': avg_accuracy},
                           engine.state.epoch)
        writer.add_scalars(
            'Aggregate/Score', {
                'Val avg precision': precision_recall['data'][0, -1],
                'Val avg recall': precision_recall['data'][1, -1]
            }, engine.state.epoch)

    cpe = CustomPeriodicEvent(n_epochs=50)
    cpe.attach(trainer)
    trainer.add_event_handler(cpe.Events.EPOCHS_50_COMPLETED,
                              log_training_results)
    trainer.add_event_handler(cpe.Events.EPOCHS_50_COMPLETED,
                              log_validation_results)
    trainer.add_event_handler(Events.STARTED, log_training_results)
    trainer.add_event_handler(Events.STARTED, log_validation_results)
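
    # Events.STARTED fires before the first epoch, so the two handlers above also
    # log a baseline evaluation of the pretrained model prior to any fine-tuning.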

    # Save the model every N epochs.
    save_model_handler = ModelCheckpoint(os.environ['savedir'],
                                         '',
                                         save_interval=10,
                                         n_saved=2)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, save_model_handler,
                              {'model': model})

    # Update the learning rate via the scheduler.
    trainer.add_event_handler(Events.EPOCH_STARTED, ignite_scheduler)

    # ------------------------------------
    # Run
    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()