Beispiel #1
0
def main():
    """Train a segmentation model with a Dice loss and log per-epoch scores.

    Parses command-line arguments, builds the data generators, model,
    optimizer and loss, then for each of ``args.nEpochs`` epochs runs one
    training pass and one validation pass, writes both scores to a
    TensorBoard ``SummaryWriter`` and saves a checkpoint through
    ``model.save_checkpoint``. The writer is closed when the function exits,
    including when a step raises, so buffered events are written out.
    """
    args = get_arguments()
    utils.make_dirs(args.save)
    name_model = args.model + "_" + args.dataset_name + "_" + utils.datestr()

    # TODO visual3D_temp.Basewriter package
    writer = SummaryWriter(log_dir='../runs/' + name_model, comment=name_model)
    try:
        # The loader also returns a full volume and its affine; this training
        # loop does not use them.
        (training_generator, val_generator,
         _full_volume, _affine) = medical_loaders.generate_datasets(
            args, path='.././datasets')
        model, optimizer = medzoo.create_model(args)

        # NOTE(review): the loss is built for a hard-coded 11 classes while
        # skip_index_after=args.classes presumably limits it to the first
        # args.classes labels — confirm against DiceLoss.
        criterion = DiceLoss(classes=11, skip_index_after=args.classes)

        if args.cuda:
            # `seed` is not defined in this function; it is expected to be a
            # module-level constant elsewhere in the file — NOTE(review): confirm.
            torch.cuda.manual_seed(seed)
            model = model.cuda()
            print("Model transferred in GPU.....")

        print("START TRAINING...")
        for epoch in range(1, args.nEpochs + 1):
            train_stats = train.train_dice(args, epoch, model, training_generator,
                                           optimizer, criterion)
            val_stats = train.test_dice(args, epoch, model, val_generator,
                                        criterion)

            utils.write_train_val_score(writer, epoch, train_stats, val_stats)

            # val_stats[0] is handed to the checkpoint as its score (presumably
            # the validation Dice loss — confirm against train.test_dice).
            model.save_checkpoint(args.save,
                                  epoch,
                                  val_stats[0],
                                  optimizer=optimizer)
    finally:
        # Flush and release the event file even if training aborted.
        writer.close()
def main():
    """Train a segmentation model, logging stats to files and TensorBoard.

    Creates the save directory and the per-split stats files, builds the data
    generators, model, optimizer and Dice loss, then for each of
    ``args.nEpochs`` epochs runs a training and a validation pass (both of
    which receive their stats file and the writer), writes the scores to
    TensorBoard and saves a checkpoint via ``model.save_checkpoint``.

    The stats files and the ``SummaryWriter`` are closed on every exit path,
    including when dataset creation, an epoch or checkpointing raises.
    """
    import contextlib

    args = get_arguments()
    utils.make_dirs(args.save)

    # ExitStack runs the registered close callbacks in reverse registration
    # order when the block exits, whether normally or via an exception.
    with contextlib.ExitStack() as stack:
        train_f, val_f = utils.create_stats_files(args.save)
        # Registered so that train_f is closed before val_f, as before.
        stack.callback(val_f.close)
        stack.callback(train_f.close)

        name_model = args.model + "_" + args.dataset_name + "_" + utils.datestr()
        writer = SummaryWriter(log_dir='../runs/' + name_model, comment=name_model)
        stack.callback(writer.close)

        # The loader also returns a full volume and its affine; this training
        # loop does not use them.
        (training_generator, val_generator,
         _full_volume, _affine) = medical_loaders.generate_datasets(
            args, path='.././datasets')
        model, optimizer = medzoo.create_model(args)

        criterion = DiceLoss(classes=args.classes)

        if args.cuda:
            # `seed` is not defined in this function; it is expected to be a
            # module-level constant elsewhere in the file — NOTE(review): confirm.
            torch.cuda.manual_seed(seed)
            model = model.cuda()
            print("Model transferred in GPU.....")

        print("START TRAINING...")
        for epoch in range(1, args.nEpochs + 1):
            train_stats = train.train_dice(args, epoch, model, training_generator,
                                           optimizer, criterion, train_f, writer)
            val_stats = train.test_dice(args, epoch, model, val_generator,
                                        criterion, val_f, writer)

            utils.write_train_val_score(writer, epoch, train_stats, val_stats)

            # val_stats[0] is handed to the checkpoint as its score (presumably
            # the validation Dice loss — confirm against train.test_dice).
            model.save_checkpoint(args.save,
                                  epoch,
                                  val_stats[0],
                                  optimizer=optimizer)
def main():
    """Train a 2D segmentation model and keep the best checkpoint.

    Creates the save directory and the per-split stats files, builds data
    generators limited to ``samples_train``/``samples_val`` samples, the
    model, optimizer and ``medzoo.DiceLoss2D`` loss, then for each of
    ``args.nEpochs`` epochs runs a training and a validation pass, writes the
    scores to TensorBoard and calls ``utils.save_model``, whose return value
    becomes the new best validation score.

    The stats files and the ``SummaryWriter`` are closed on every exit path,
    including when dataset creation, an epoch or saving raises.
    """
    import contextlib

    args = get_arguments()
    utils.make_dirs(args.save)

    # ExitStack runs the registered close callbacks in reverse registration
    # order when the block exits, whether normally or via an exception.
    with contextlib.ExitStack() as stack:
        train_f, val_f = utils.create_stats_files(args.save)
        # Registered so that train_f is closed before val_f, as before.
        stack.callback(val_f.close)
        stack.callback(train_f.close)

        name_model = args.model + "_" + args.dataset_name + "_" + utils.datestr()
        writer = SummaryWriter(log_dir='../runs/' + name_model, comment=name_model)
        stack.callback(writer.close)

        # Initial "best" score; presumably chosen above any attainable Dice
        # loss so the first epoch counts as an improvement — confirm against
        # utils.save_model.
        best_pred = 1.01
        samples_train = 200
        samples_val = 200
        # The loader also returns a full volume and its affine; this training
        # loop does not use them.
        (training_generator, val_generator,
         _full_volume, _affine) = medical_loaders.generate_datasets(
            args,
            path='.././datasets',
            samples_train=samples_train,
            samples_val=samples_val)

        model, optimizer = medzoo.create_model(args)
        criterion = medzoo.DiceLoss2D(args.classes)

        if args.cuda:
            # `seed` is not defined in this function; it is expected to be a
            # module-level constant elsewhere in the file — NOTE(review): confirm.
            torch.cuda.manual_seed(seed)
            model = model.cuda()

        for epoch in range(1, args.nEpochs + 1):
            train_stats = train.train_dice(args, epoch, model, training_generator,
                                           optimizer, criterion, train_f, writer)
            val_stats = train.test_dice(args, epoch, model, val_generator,
                                        criterion, val_f, writer)

            utils.write_train_val_score(writer, epoch, train_stats, val_stats)
            # save_model returns the best score seen so far, which is carried
            # into the next epoch's comparison.
            best_pred = utils.save_model(model=model,
                                         args=args,
                                         dice_loss=val_stats[0],
                                         epoch=epoch,
                                         best_pred_loss=best_pred)