コード例 #1
0
val_dataset = ImageToImage2D(args.val_dataset, tf_val)
# The same validation folder is also wrapped as a plain Image2D and handed to
# fit_dataset below as ``predict_dataset``.
predict_dataset = Image2D(args.val_dataset)

# Doubling width sequence, one entry per depth level
# (e.g. width=32, depth=4 -> [32, 64, 128, 256]); passed to UNet2D as its
# conv depths.
conv_depths = [int(args.width * (2**k)) for k in range(args.depth)]
print(conv_depths)
unet = UNet2D(args.in_channels, args.out_channels, conv_depths)
loss = LogNLLLoss()
optimizer = optim.Adam(unet.parameters(), lr=args.learning_rate)

# Output folder for this run: <checkpoint_path>/<model_name>.
results_folder = os.path.join(args.checkpoint_path, args.model_name)
# exist_ok=True replaces the racy "if not exists: makedirs" pattern, which
# raises FileExistsError if the folder appears between the check and the call.
os.makedirs(results_folder, exist_ok=True)

metric_list = MetricList({
    'jaccard': partial(jaccard_index),
    'f1': partial(f1_score),
})

model = Model(unet, loss, optimizer, results_folder, device=args.device)

model.fit_dataset(train_dataset,
                  n_epochs=args.epochs,
                  n_batch=args.batch_size,
                  shuffle=True,
                  val_dataset=val_dataset,
                  save_freq=args.save_freq,
                  save_model=args.save_model,
                  predict_dataset=predict_dataset,
                  metric_list=metric_list,
                  verbose=True)
コード例 #2
0
ファイル: predict.py プロジェクト: zhould1990/pytorch-UNet
# Load a trained network from --model_path and run Model.predict_dataset over
# the images in --dataset, using --results_path both as the Model's checkpoint
# folder and as the prediction target folder.
import os

from argparse import ArgumentParser

import torch

from unet.model import Model
from unet.dataset import Image2D

parser = ArgumentParser()
parser.add_argument('--dataset', required=True, type=str)
parser.add_argument('--results_path', required=True, type=str)
parser.add_argument('--model_path', required=True, type=str)
parser.add_argument('--device', default='cpu', type=str)
args = parser.parse_args()

predict_dataset = Image2D(args.dataset)

# Bind the loaded network to ``unet`` so it can be wrapped by Model below
# (previously it was bound to ``model`` and ``unet`` was never defined).
# map_location remaps stored tensors onto --device, so a checkpoint saved on a
# GPU can still be loaded with the default ``--device cpu``.
# NOTE(review): assumes the file holds a whole pickled network rather than a
# state_dict -- confirm against how training saves it. torch.load unpickles
# arbitrary objects, so only load trusted checkpoint files; on recent PyTorch
# versions (weights_only defaults to True) loading a full nn.Module may also
# require weights_only=False.
unet = torch.load(args.model_path, map_location=args.device)

# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(args.results_path, exist_ok=True)

model = Model(unet, checkpoint_folder=args.results_path, device=args.device)

# The parser defines ``results_path``; ``args.result_path`` raised AttributeError.
model.predict_dataset(predict_dataset, args.results_path)
コード例 #3
0
                              batch_size=args.batch_size,
                              shuffle=True,
                              collate_fn=my_collate,
                              num_workers=args.num_workers)

    # Validation split, loaded with exactly the same loader settings as the
    # training split (batch size, shuffling, collate function, worker count).
    valset = SynGrapeDataset(transform=transform, data=val_df)
    val_loader = DataLoader(
        valset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=my_collate,
        shuffle=True,
    )

    # Passed to train_model as a [train, validation] pair.
    data_loader = [train_loader, val_loader]

    # Build the single-class network and move it to ``device`` (which may be
    # a CPU or a GPU, depending on how ``device`` was configured).
    model = Model(nb_classes=1, device=device, experiment=experiment_name)
    model.to(device)

    # Train for args.nb_epoch epochs at learning rate args.learning_rate.
    model.train_model(data_loader, args.nb_epoch, args.learning_rate)

##############################################################################
# EVALUATION
##############################################################################
if args.evaluate:
    # Used when debugging because args.exp is not defined
    if args.debug:
        args.exp = os.path.join('unet', 'output', 'debug')

    testset = SynGrapeDataset(data=test_df.copy(), transform=transform)
    test_loader = DataLoader(testset,