Example #1
def run_training(model, train, valid, optimizer, loss, lr_find=False):
    print_file(f'Experiment: {rcp.experiment}\nDescription:{rcp.description}',
               f'{rcp.base_path}description.txt')
    print_file(model, f'{rcp.models_path}model.txt')
    print_file(get_transforms(), f'{rcp.models_path}transform_{rcp.stage}.txt')
    # Data
    train.transform = get_transforms()
    valid.transform = get_transforms()
    train.save_csv(f'{rcp.base_path}train_df_{rcp.stage}.csv')
    valid.save_csv(f'{rcp.base_path}valid_df_{rcp.stage}.csv')
    train_loader = DataLoader(train,
                              batch_size=rcp.bs,
                              num_workers=8,
                              shuffle=rcp.shuffle_batch)
    valid_loader = DataLoader(valid,
                              batch_size=rcp.bs,
                              num_workers=8,
                              shuffle=rcp.shuffle_batch)

    if lr_find: lr_finder(model, optimizer, loss, train_loader, valid_loader)

    one_batch = next(iter(train_loader))
    dot = make_dot(model(one_batch[0].to(cfg.device)),
                   params=dict(model.named_parameters()))
    dot.render(f'{rcp.models_path}graph', './', format='png', cleanup=True)
    summary(model,
            one_batch[0].shape[-3:],
            batch_size=rcp.bs,
            device=cfg.device,
            to_file=f'{rcp.models_path}summary_{rcp.stage}.txt')

    # Engines
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss,
                                        device=cfg.device)
    t_evaluator = create_supervised_evaluator(model,
                                              metrics={
                                                  'accuracy': Accuracy(),
                                                  'nll': Loss(loss),
                                                  'precision': Precision(average=True),
                                                  'recall': Recall(average=True),
                                                  'topK': TopKCategoricalAccuracy()
                                              },
                                              device=cfg.device)
    v_evaluator = create_supervised_evaluator(
        model,
        metrics={
            'accuracy': Accuracy(),
            'nll': Loss(loss),
            'precision_avg': Precision(average=True),
            'recall_avg': Recall(average=True),
            'topK': TopKCategoricalAccuracy(),
            'conf_mat': ConfusionMatrix(num_classes=len(valid.classes), average=None),
        },
        device=cfg.device)

    # Tensorboard
    tb_logger = TensorboardLogger(log_dir=f'{rcp.tb_log_path}{rcp.stage}')
    tb_writer = tb_logger.writer
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer, "lr"),
                     event_name=Events.EPOCH_STARTED)
    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_and_log_training_stats(engine):
        t_evaluator.run(train_loader)
        v_evaluator.run(valid_loader)
        tb_and_log_train_valid_stats(engine, t_evaluator, v_evaluator,
                                     tb_writer)

    @trainer.on(
        Events.ITERATION_COMPLETED(every=int(1 + len(train_loader) / 100)))
    def print_dash(engine):
        print('-', sep='', end='', flush=True)

    if cfg.show_batch_images:

        @trainer.on(Events.STARTED)
        def show_batch_images(engine):
            imgs, lbls = next(iter(train_loader))
            denormalize = DeNormalize(**rcp.transforms.normalize)
            for i in range(len(imgs)):
                imgs[i] = denormalize(imgs[i])
            imgs = imgs.to(cfg.device)
            grid = thv.utils.make_grid(imgs)
            tb_writer.add_image('images', grid, 0)
            tb_writer.add_graph(model, imgs)
            tb_writer.flush()

    if cfg.show_top_losses:

        @trainer.on(Events.COMPLETED)
        def show_top_losses(engine, k=6):
            nll_loss = nn.NLLLoss(reduction='none')
            df = predict_dataset(model,
                                 valid,
                                 nll_loss,
                                 transform=None,
                                 bs=rcp.bs,
                                 device=cfg.device)
            df.sort_values('loss', ascending=False, inplace=True)
            df.reset_index(drop=True, inplace=True)
            for i, row in df.iterrows():
                img = cv2.imread(str(row['fname']))
                img = th.as_tensor(img.transpose(2, 0, 1))  # HWC -> CHW
                tag = f'TopLoss_{engine.state.epoch}/{row.loss:.4f}/{row.target}/{row.pred}/{row.pred2}'
                tb_writer.add_image(tag, img, 0)
                if i >= k - 1: break
            tb_writer.flush()

    if cfg.tb_projector:
        images, labels = train.select_n_random(250)
        # get the class labels for each image
        class_labels = [train.classes[lab] for lab in labels]
        # log embeddings
        features = images.view(-1, images.shape[-1] * images.shape[-2])
        tb_writer.add_embedding(features,
                                metadata=class_labels,
                                label_img=images)

    if cfg.log_pr_curve:

        @trainer.on(Events.COMPLETED)
        def log_pr_curve(engine):
            """
            1. gets the probability predictions in a test_size x num_classes Tensor
            2. gets the preds in a test_size Tensor
            takes ~10 seconds to run
            """
            class_probs = []
            class_preds = []
            with th.no_grad():
                for data in valid_loader:
                    imgs, lbls = data
                    imgs, lbls = imgs.to(cfg.device), lbls.to(cfg.device)
                    output = model(imgs)
                    class_probs_batch = [
                        th.softmax(el, dim=0) for el in output
                    ]
                    _, class_preds_batch = th.max(output, 1)
                    class_probs.append(class_probs_batch)
                    class_preds.append(class_preds_batch)
            test_probs = th.cat([th.stack(batch) for batch in class_probs])
            test_preds = th.cat(class_preds)

            for i in range(len(valid.classes)):
                # plot the precision-recall curve for this class
                tensorboard_preds = test_preds == i
                tensorboard_probs = test_probs[:, i]

                tb_writer.add_pr_curve(f'{rcp.stage}/{valid.classes[i]}',
                                       tensorboard_preds,
                                       tensorboard_probs,
                                       global_step=engine.state.epoch,
                                       num_thresholds=127)
                tb_writer.flush()

    print()

    if cfg.lr_scheduler:
        # lr_scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=.5, min_lr=1e-7, verbose=True)
        # v_evaluator.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: lr_scheduler.step(v_evaluator.state.metrics['nll']))
        lr_scheduler = DelayedCosineAnnealingLR(optimizer, 10, 5)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED,
            lambda engine: lr_scheduler.step(trainer.state.epoch))

    if cfg.early_stopping:

        def score_function(engine):
            score = -1 * round(engine.state.metrics['nll'], 5)
            # score = engine.state.metrics['accuracy']
            return score

        es_handler = EarlyStopping(patience=10,
                                   score_function=score_function,
                                   trainer=trainer)
        v_evaluator.add_event_handler(Events.COMPLETED, es_handler)

    if cfg.save_last_checkpoint:

        @trainer.on(Events.EPOCH_COMPLETED(every=1))
        def save_last_checkpoint(engine):
            checkpoint = {}
            objects = {'model': model, 'optimizer': optimizer}
            if cfg.lr_scheduler: objects['lr_scheduler'] = lr_scheduler
            for k, obj in objects.items():
                checkpoint[k] = obj.state_dict()
            th.save(checkpoint,
                    f'{rcp.models_path}last_{rcp.stage}_checkpoint.pth')

    if cfg.save_best_checkpoint:

        def score_function(engine):
            score = -1 * round(engine.state.metrics['nll'], 5)
            # score = engine.state.metrics['accuracy']
            return score

        objects = {'model': model, 'optimizer': optimizer}
        if cfg.lr_scheduler: objects['lr_scheduler'] = lr_scheduler

        save_best = Checkpoint(
            objects,
            DiskSaver(f'{rcp.models_path}',
                      require_empty=False,
                      create_dir=True),
            n_saved=4,
            filename_prefix=f'best_{rcp.stage}',
            score_function=score_function,
            score_name='val_loss',
            global_step_transform=global_step_from_engine(trainer))
        v_evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                      save_best)
        load_checkpoint = False

        if load_checkpoint:
            resume_epoch = 6
            cp = f'{rcp.models_path}last_{rcp.stage}_checkpoint.pth'
            obj = th.load(f'{cp}')
            Checkpoint.load_objects(objects, obj)

            @trainer.on(Events.STARTED)
            def resume_training(engine):
                engine.state.iteration = (resume_epoch - 1) * len(
                    engine.state.dataloader)
                engine.state.epoch = resume_epoch - 1

    if cfg.save_confusion_matrix:

        @trainer.on(Events.STARTED)
        def init_best_loss(engine):
            engine.state.metrics['best_loss'] = 1e99

        @trainer.on(Events.EPOCH_COMPLETED)
    def confusion_matrix(engine):
        if engine.state.metrics['best_loss'] > v_evaluator.state.metrics['nll']:
            engine.state.metrics['best_loss'] = v_evaluator.state.metrics['nll']
                cm = v_evaluator.state.metrics['conf_mat']
                cm_df = pd.DataFrame(cm.numpy(),
                                     index=valid.classes,
                                     columns=valid.classes)
                pretty_plot_confusion_matrix(
                    cm_df,
                    f'{rcp.results_path}cm_{rcp.stage}_{trainer.state.epoch}.png',
                    False)

    if cfg.log_stats:

        class Hook:
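            # records the mean and std of a module's forward output so activation
            # statistics can be streamed to TensorBoard at each iteration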
            def __init__(self, module):
                self.name = module[0]
                self.hook = module[1].register_forward_hook(self.hook_fn)
                self.stats_mean = 0
                self.stats_std = 0

            def hook_fn(self, module, input, output):
                self.stats_mean = output.mean()
                self.stats_std = output.std()

            def close(self):
                self.hook.remove()

        hookF = [Hook(layer) for layer in list(model.cnn.named_children())]

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_stats(engine):
            for hook in hookF:
                tb_writer.add_scalar(f'std/{hook.name}', hook.stats_std,
                                     engine.state.iteration)
                tb_writer.add_scalar(f'mean/{hook.name}', hook.stats_mean,
                                     engine.state.iteration)

    cfg.save_yaml()
    rcp.save_yaml()
    print(f'# batches: train: {len(train_loader)}, valid: {len(valid_loader)}')
    trainer.run(data=train_loader, max_epochs=rcp.max_epochs)
    tb_writer.close()
    tb_logger.close()
    return model
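
A hypothetical call site for run_training (the model, dataset and loss choices below are illustrative assumptions, not taken from the source):

# hypothetical wiring; MyCnn and ImageCsvDataset stand in for the project's own classes
model = MyCnn().to(cfg.device)
optimizer = th.optim.AdamW(model.parameters(), lr=1e-3)
train_ds = ImageCsvDataset(f'{rcp.base_path}train_df.csv')
valid_ds = ImageCsvDataset(f'{rcp.base_path}valid_df.csv')
model = run_training(model, train_ds, valid_ds, optimizer, nn.NLLLoss(), lr_find=False)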
Example #2
net = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
loss = monai.losses.DiceLoss(do_sigmoid=True)
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr)
device = torch.device('cuda:0')

# the Ignite trainer expects batch=(img, seg) and returns output=loss at every iteration;
# users can add an output_transform to return other values, such as y_pred, y, etc.
trainer = create_supervised_trainer(net, opt, loss, device, False)
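# as a hedged illustration (not part of the original snippet), an output_transform
# could be supplied so downstream handlers also receive predictions and targets:
#     trainer = create_supervised_trainer(
#         net, opt, loss, device, False,
#         output_transform=lambda x, y, y_pred, loss: (y_pred, y, loss.item()))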

# adding checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/',
                                     'net',
                                     n_saved=10,
                                     require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={
                              'net': net,
                              'opt': opt
                          })

# StatsHandler prints the loss at every iteration and prints metrics at every epoch;
# we don't set metrics for the trainer here, so it only prints the loss. Users can also customize the print functions.
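# the actual StatsHandler attachment is not included in this excerpt; a minimal
# sketch of what it typically looks like (hedged; assumes the trainer output is the loss value):
from monai.handlers import StatsHandler

train_stats_handler = StatsHandler(name='trainer', output_transform=lambda x: x)
train_stats_handler.attach(trainer)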
Example #3
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval,
        log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = create_summary_writer(model, train_loader, log_dir)
    device = 'cpu'

    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.nll_loss,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'nll': Loss(F.nll_loss)
                                            },
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              "".format(engine.state.epoch, engine.state.iteration,
                        len(train_loader), engine.state.output))
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
Example #4
	cross_entropy_loss = nn.CrossEntropyLoss()
	adam_optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	home = os.environ['HOME']
	writer = SummaryWriter(home + '/log.json')
	lowest_loss = np.Inf
	train_loader, val_loader, test_loader = load_data(args.batch_size)

	model.to(device)
	
	if args.summary:
		summary(model, (1, 28, 28))
		print('Batch size: ',  args.batch_size)
		print('Epochs:', args.num_epochs)
	
	trainer = create_supervised_trainer(model, adam_optimizer, cross_entropy_loss, device=device)
	evaluator = create_supervised_evaluator(model, metrics={"accuracy": Accuracy(), "cross": Loss(cross_entropy_loss),
	                                                        "prec": Precision(), "recall": Recall()},
	                                        device=device)
	
	desc = "ITERATION - loss: {:.2f}"
	pbar = tqdm(
		initial=0, leave=False, total=len(train_loader),
		desc=desc.format(0)
	)
	
	
	@trainer.on(Events.ITERATION_COMPLETED)
	def log_training_loss(engine):
		iter = (engine.state.iteration - 1) % len(train_loader) + 1
		# the original excerpt is cut off here; a plausible completion (hedged) updates
		# the progress bar with the current loss, mirroring the later examples
		pbar.desc = desc.format(engine.state.output)
		pbar.update(1)

Example #5
def do_train(model, train_loader, val_loader, optimizer, scheduler,
             checkpointer, loss_fn, device, checkpoint_period, log_period,
             epochs):
    logger = logging.getLogger("template_model.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'ce_loss': Loss(loss_fn)
                                            },
                                            device=device)

    desc = "ITERATION - loss: {:.3f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_period)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['ce_loss']
        # tqdm.write("Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
        #            .format(engine.state.epoch, avg_accuracy, avg_loss)
        #            )
        logger.info(
            "Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
            .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_loss = metrics['ce_loss']
            # tqdm.write("Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
            #            .format(engine.state.epoch, avg_accuracy, avg_loss)
            #            )
            logger.info(
                "Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_loss))
            pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
Example #6
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=-1,
    deterministic=False,
):
    # Set the seed so every run starts from the same model initialization:
    manual_seed(75)

    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Setup trainer and evaluator
    if deterministic:
        tqdm.write("Setup deterministic trainer")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device,
                                        deterministic=deterministic)

    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "nll": Loss(criterion)
                                            },
                                            device=device)

    # Apply learning rate scheduling
    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    desc = "Epoch {} - loss: {:.4f} - lr: {:.4f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0, 0, lr))

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = desc.format(engine.state.epoch, engine.state.output, lr)
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output,
                          engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if crash_iteration > 0:

        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception(f"STOP at {engine.state.iteration}")

    if resume_from is not None:

        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration % engine.state.epoch_length

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Training Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Compute and log validation metrics
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Validation Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy,
                          engine.state.epoch)

    # Setup object to checkpoint
    objects_to_checkpoint = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_every),
                              training_checkpoint)

    # Set up a logger to print and dump model weights, model grads and data stats into a file:
    # - for the first 3 iterations
    # - for 4 iterations after each checkpoint
    # This helps to compare a resumed training run with the original checkpointed run
    def log_event_filter(e, event):
        if event in [1, 2, 3]:
            return True
        elif 0 <= (event % (checkpoint_every * e.state.epoch_length)) < 5:
            return True
        return False

    fp = Path(log_dir) / ("run.log"
                          if resume_from is None else "resume_run.log")
    fp = fp.as_posix()
    for h in [log_data_stats, log_model_weights, log_model_grads]:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED(event_filter=log_event_filter),
            h,
            model=model,
            fp=fp)

    if resume_from is not None:
        tqdm.write(f"Resume from the checkpoint: {resume_from}")
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint,
                                checkpoint=checkpoint)

    try:
        # Synchronize random states
        manual_seed(15)
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
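
A hypothetical invocation of run (all argument values below are illustrative, not from the source):

run(train_batch_size=64, val_batch_size=256, epochs=10, lr=0.01, momentum=0.5,
    log_interval=10, log_dir="/tmp/mnist_tb_logs", checkpoint_every=2,
    resume_from=None, crash_iteration=-1, deterministic=True)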