Code example #1
File: main.py  Project: jbdel/medical_imaging_toolbox
import numpy as np
import torch


def main():
    # parse command-line arguments; parse_agrs and the other project-specific names
    # (Tokenizer, R2DataLoader, R2GenModel, compute_loss, compute_scores,
    # build_optimizer, build_lr_scheduler, Trainer) are assumed to come from the
    # project's own modules, whose imports are omitted in this excerpt
    args = parse_agrs()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loader
    train_dataloader = R2DataLoader(args,
                                    tokenizer,
                                    split='train',
                                    shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args,
                                   tokenizer,
                                   split='test',
                                   shuffle=False)

    # build model architecture
    model = R2GenModel(args, tokenizer)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    # build optimizer, learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start to train
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler,
                      train_dataloader, val_dataloader, test_dataloader)
    trainer.train()
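
For completeness, the usual entry-point guard (assumed here; it is not shown in the excerpt) would call main() when the script is run directly:

if __name__ == '__main__':
    main()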
Code example #2
import math
import os
import time
from shutil import copyfile

import torch


def train(hparams, distributed_run=False, rank=0, n_gpus=None):
    """Run training and validation, logging results to TensorBoard and stdout.

    Project-specific helpers (load_model, OverallLoss, build_optimizer, build_scheduler,
    prepare_dataloaders, SchedulerTypes, gradient_adaptive_factor, etc.) are assumed to
    be imported from the project's own modules; those imports are omitted in this excerpt.
    """
    if distributed_run:
        assert n_gpus is not None

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams, distributed_run)
    optimizer = build_optimizer(model, hparams)
    lr_scheduler = build_scheduler(optimizer, hparams)
    criterion = OverallLoss(hparams)

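    # optional mixed-precision training via NVIDIA apex (opt_level "O2": FP16 model with FP32 master weights)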
    if hparams.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    if distributed_run:
        model = apply_gradient_allreduce(model)

    logger = prepare_directories_and_logger(hparams.output_dir,
                                            hparams.log_dir, rank)
    copyfile(hparams.path, os.path.join(hparams.output_dir, 'hparams.yaml'))
    train_loader, valset, collate_fn = prepare_dataloaders(
        hparams, distributed_run)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if hparams.checkpoint is not None:
        if hparams.warm_start:
            model = warm_start_model(hparams.checkpoint, model,
                                     hparams.ignore_layers)
        else:
            model, optimizer, lr_scheduler, mmi_criterion, iteration = load_checkpoint(
                hparams.checkpoint, model, optimizer, lr_scheduler, criterion,
                hparams.restore_scheduler_state)

            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    is_overflow = False
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            torch.cuda.empty_cache()

            start = time.perf_counter()

            model.zero_grad()
            inputs, alignments, inputs_ctc = model.parse_batch(batch)

            outputs, decoder_outputs = model(inputs)

            losses = criterion(outputs,
                               inputs,
                               alignments=alignments,
                               inputs_ctc=inputs_ctc,
                               decoder_outputs=decoder_outputs)

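            # every UPDATE_GAF_EVERY_N_STEP iterations, rescale the MI loss term by a
            # gradient adaptive factor (GAF) bounded by hparams.max_gaf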
            if (hparams.use_mmi and hparams.use_gaf
                    and i % gradient_adaptive_factor.UPDATE_GAF_EVERY_N_STEP == 0):
                mi_loss = losses["mi/loss"]
                overall_loss = losses["overall/loss"]

                gaf = calc_gaf(model, optimizer, overall_loss, mi_loss,
                               hparams.max_gaf)

                losses["mi/loss"] = gaf * mi_loss
                losses["overall/loss"] = overall_loss - mi_loss * (1 - gaf)

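            # reduce each loss across GPUs for logging in the distributed case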
            reduced_losses = {
                key: reduce_loss(value, distributed_run, n_gpus)
                for key, value in losses.items()
            }
            loss = losses["overall/loss"]

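            # backward pass: with apex, scale the loss and clip gradients on the FP32
            # master params; otherwise clip the model's gradients directly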
            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                learning_rate = lr_scheduler.get_last_lr()[0]
                duration = time.perf_counter() - start
                print(
                    "Iteration {}: overall loss {:.6f} Grad Norm {:.6f} {:.2f}s/it LR {:.3E}"
                    .format(iteration, reduced_losses["overall/loss"],
                            grad_norm, duration, learning_rate))

                logger.log_training(reduced_losses, grad_norm, learning_rate,
                                    duration, iteration)

            if not is_overflow and iteration % hparams.iters_per_checkpoint == 0:
                val_loss = validate(model, criterion, valset, iteration,
                                    hparams.batch_size, collate_fn, logger,
                                    distributed_run, rank, n_gpus)
                if rank == 0:
                    checkpoint = os.path.join(
                        hparams.output_dir, "checkpoint_{}".format(iteration))

                    save_checkpoint(model, optimizer, lr_scheduler, criterion,
                                    iteration, hparams, checkpoint)

            iteration += 1
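            # cyclic schedulers step once per iteration; all other schedulers step once per epoch below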
            if hparams.lr_scheduler == SchedulerTypes.cyclic:
                lr_scheduler.step()

        if hparams.lr_scheduler != SchedulerTypes.cyclic:
            # TODO: for the plateau scheduler, the validation loss should really be
            # computed at the end of every epoch (val_loss may even be undefined here
            # if no checkpoint-interval validation ran during the epoch)
            scheduler_args = (val_loss, ) if hparams.lr_scheduler == SchedulerTypes.plateau else ()
            lr_scheduler.step(*scheduler_args)
Code example #3
import os
import time
from itertools import chain
from shutil import copyfile

import torch
from torch.cuda import amp


def train(hparams, distributed_run=False, rank=0, n_gpus=None):
    """Run training and validation, logging results to TensorBoard and stdout.

    Unlike the previous example, mixed precision here uses the native torch.cuda.amp API.
    Project-specific helpers (load_model, OverallLoss, prepare_dataloaders, SchedulerTypes,
    gradient_adaptive_factor, etc.) are assumed to come from the project's own modules;
    those imports are omitted in this excerpt.
    """
    if distributed_run:
        assert n_gpus is not None

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams, distributed_run)
    criterion = OverallLoss(hparams)
    if criterion.mmi_criterion is not None:
        parameters = chain(model.parameters(),
                           criterion.mmi_criterion.parameters())
    else:
        parameters = model.parameters()
    optimizer = build_optimizer(parameters, hparams)
    lr_scheduler = build_scheduler(optimizer, hparams)

    if distributed_run:
        model = apply_gradient_allreduce(model)
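    # native torch.cuda.amp gradient scaler; effectively a pass-through when fp16_run is False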
    scaler = amp.GradScaler(enabled=hparams.fp16_run)

    logger = prepare_directories_and_logger(hparams.output_dir,
                                            hparams.log_dir, rank)
    copyfile(hparams.path, os.path.join(hparams.output_dir, 'hparams.yaml'))
    train_loader, valset, collate_fn = prepare_dataloaders(
        hparams, distributed_run)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if hparams.checkpoint is not None:
        if hparams.warm_start:
            model = warm_start_model(hparams.checkpoint, model,
                                     hparams.ignore_layers,
                                     hparams.ignore_mismatched_layers)
        else:
            model, optimizer, lr_scheduler, mmi_criterion, iteration = load_checkpoint(
                hparams.checkpoint, model, optimizer, lr_scheduler, criterion,
                hparams.restore_scheduler_state)

            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()

            model.zero_grad()
            inputs, alignments, inputs_ctc = model.parse_batch(batch)

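            # forward pass and loss computation under autocast so eligible ops run in FP16 when enabled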
            with amp.autocast(enabled=hparams.fp16_run):
                outputs, decoder_outputs = model(inputs)

                losses = criterion(outputs,
                                   inputs,
                                   alignments=alignments,
                                   inputs_ctc=inputs_ctc,
                                   decoder_outputs=decoder_outputs)

            if (hparams.use_mmi and hparams.use_gaf
                    and i % gradient_adaptive_factor.UPDATE_GAF_EVERY_N_STEP == 0):
                mi_loss = losses["mi/loss"]
                overall_loss = losses["overall/loss"]

                gaf = calc_gaf(model, optimizer, overall_loss, mi_loss,
                               hparams.max_gaf)

                losses["mi/loss"] = gaf * mi_loss
                losses["overall/loss"] = overall_loss - mi_loss * (1 - gaf)

            reduced_losses = {
                key: reduce_loss(value, distributed_run, n_gpus)
                for key, value in losses.items()
            }
            loss = losses["overall/loss"]

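            # scale the loss for the backward pass, then unscale gradients so clipping
            # operates on their true magnitudes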
            scaler.scale(loss).backward()

            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), hparams.grad_clip_thresh)

            scaler.step(optimizer)
            scaler.update()

            if rank == 0:
                learning_rate = lr_scheduler.get_last_lr()[0]
                duration = time.perf_counter() - start
                print(
                    "Iteration {} ({} epoch): overall loss {:.6f} Grad Norm {:.6f} {:.2f}s/it LR {:.3E}"
                    .format(iteration, epoch, reduced_losses["overall/loss"],
                            grad_norm, duration, learning_rate))

                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                    grad_norm = None
                logger.log_training(reduced_losses, grad_norm, learning_rate,
                                    duration, iteration)

            if iteration % hparams.iters_per_checkpoint == 0:
                validate(model, criterion, valset, iteration,
                         hparams.batch_size, collate_fn, logger,
                         distributed_run, rank, n_gpus)
                if rank == 0:
                    checkpoint = os.path.join(
                        hparams.output_dir, "checkpoint_{}".format(iteration))

                    save_checkpoint(model, optimizer, lr_scheduler, criterion,
                                    iteration, hparams, checkpoint)

            iteration += 1
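            # as above, cyclic schedulers step every iteration; the others (including plateau) step per epoch below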
            if hparams.lr_scheduler == SchedulerTypes.cyclic:
                lr_scheduler.step()

        if hparams.lr_scheduler != SchedulerTypes.cyclic:
            if hparams.lr_scheduler == SchedulerTypes.plateau:
                lr_scheduler.step(
                    validate(model, criterion, valset, iteration,
                             hparams.batch_size, collate_fn, logger,
                             distributed_run, rank, n_gpus))
            else:
                lr_scheduler.step()
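
For orientation, a minimal invocation sketch; load_hparams is a hypothetical stand-in for however the project builds its hparams object and is not part of the excerpt:

if __name__ == '__main__':
    hparams = load_hparams('hparams.yaml')  # hypothetical loader; the project's real entry point may differ
    train(hparams)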