Example #1
def train_resume():
    args = get_args()
    args.epochs = 100
    args.lr = 0.0075
    save_cp = True
    args.load = '/home/zhaojin/data/TacomaBridge/segdata/train/checkpointCP30.pth'

    args.gpu = True
    net = UNet(n_channels=1, n_classes=4)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
        # cudnn.benchmark = True # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  save_cp=save_cp,
                  img_scale=args.scale)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #2
def main(raw_args=None):
    """example:  python predict_batch.py --model '/home/zhaojin/data/TacomaBridge/segdata/train/checkpoint/weight_logloss_softmax/CP30.pth' --input '/home/zhaojin/data/TacomaBridge/capture/high-reso-clip2_rename' --output '/home/zhaojin/data/TacomaBridge/segdata/predict/high-reso-clip2_rename'"""
    args = get_args(raw_args)
    # args.model = '/home/zhaojin/data/TacomaBridge/segdata/train/checkpoint/logloss_softmax/CP12.pth'
    # in_files = ['/home/zhaojin/data/TacomaBridge/segdata/train/img/00034.png' ]
    # out_files = ['/home/zhaojin/my_path/dir/segdata/predict/00025.png']
    imgpath = args.input
    lblpath = args.output
    net = UNet(n_channels=1, n_classes=4)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")


    predict_img_batch(net=net,
                      imgpath=imgpath,
                      lblpath=lblpath,
                      scale_factor=args.scale,
                      out_threshold=args.mask_threshold,
                      use_dense_crf=not args.no_crf,
                      use_gpu=not args.cpu)
Example #3
def gpu_prediction_sample(args):
    in_files = args.input

    net = UNet(n_channels=3, n_classes=1)

    print("Loading model {}".format(args.model))

    print("Using CUDA version of the net, prepare your GPU !")
    net.cuda()
    net.load_state_dict(torch.load(args.model))

    print("Model loaded !")

    for i, fn in enumerate(in_files):
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if img.size[0] < img.size[1]:
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_dense_crf=not args.no_crf,
                           use_gpu=not args.cpu)

        print("Visualizing results for image {}, close to continue ...".format(
            fn))
        plot_img_and_mask(img, mask)
Example #4
def main():
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    dataset = CPDataset(opt)

    # create dataloader
    loader = CPDataLoader(opt, dataset)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=opt.workers,
                                              pin_memory=True,
                                              sampler=None)

    # visualization
    if not os.path.exists(opt.tensorboard_dir):
        os.makedirs(opt.tensorboard_dir)

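    # Load the pretrained sub-networks from fixed checkpoint paths: the GMM
    # warping module, the try-on generator, and an identity embedder (of which
    # only the embedder_b branch is kept).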
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model,
                    "checkpoints/tom_train_new_2/step_040000.pth")
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    embedder_model = embedder_model.embedder_b.cuda()

    model = UNet(n_channels=4, n_classes=3)
    model.cuda()

    if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
        load_checkpoint(model, opt.checkpoint)

    test_residual(opt, data_loader, model, gmm_model, generator_model)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
Example #5
def train(opt):
    # seed
    setup_seed(1)
    # dataset
    train_dataset = DataSet(opt.train_data_root)
    logging.info('train dataset sample num:%d' % len(train_dataset))
    trainloader = DataLoader(train_dataset,
                             shuffle=True,
                             batch_size=opt.batch_size,
                             num_workers=opt.num_workers)
    # network
    net = UNet(22, [32, 64, 128, 256])
    if opt.checkpoint_model:
        net.load_state_dict(torch.load(os.path.join(opt.load_model_path, opt.checkpoint_model)))
    if opt.use_gpu:
        net.cuda()

    # loss function
    loss_func = MyLoss()
    # optimizer = optim.SGD(net.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    
    # optimizer
    optimizer = optim.Adam(net.parameters())
    iteration = 1
    if opt.optimizer:
        optimizer.load_state_dict(torch.load(os.path.join(opt.load_optimizer_path, opt.optimizer)))
        iteration = int(opt.optimizer.split('.')[0].split('_')[-1]) + 1

    # training
    while iteration <= opt.max_iter:
        for ii, (input, label) in enumerate(trainloader):
            if opt.use_gpu:
                input = input.cuda()
                label = label.cuda()
            optimizer.zero_grad()
            output = net(input)
            loss = loss_func(output, label)
            loss.backward()
            optimizer.step()
            # print loss
            if (ii+1) % opt.display == 0:
                for param_group in optimizer.param_groups:
                    lr = param_group['lr']
                logging.info('iteration: %d, lr: %f, loss: %.6f' % (iteration, lr, loss))
                print('iteration: %d, lr: %f, loss: %.6f' % (iteration, lr, loss))
            # save model and optimizer
            if iteration % opt.snapshot == 0:
                torch.save(net.state_dict(),
                           os.path.join(opt.load_model_path,
                                        opt.model + '_' + str(iteration) + '.pth'))
                torch.save(optimizer.state_dict(),
                           os.path.join(opt.load_optimizer_path,
                                        'optim_' + str(iteration) + '.pth'))
                logging.info(opt.model + '_' + str(iteration) + '.pth saved')
            iteration += 1
            if iteration > opt.max_iter:
                break
Example #6
def main():

    model = UNet(in_channels=21, out_channels=4, init_features=128)
    model.cuda()
    # print(summary(model, input_size=(21, 256, 256)))
    loader_test, test_pairs = datasets()

    load_model_path = '../checkpoints/baseline-128-copy.pk'
    model.load_state_dict(torch.load(load_model_path))
    infer(model, loader_test, test_pairs)
Example #7
def main():

    # init conv net
    print("init net")

    unet = UNet(3, 1)
    if os.path.exists("./unet.pkl"):
        unet.load_state_dict(torch.load("./unet.pkl"))
        print("load unet")
    unet.cuda()

    cnn = CNNEncoder()
    if os.path.exists("./cnn.pkl"):
        cnn.load_state_dict(torch.load("./cnn.pkl"))
        print("load cnn")
    cnn.cuda()

    # init dataset
    print("init dataset")
    data_loader = dataset84.jump_data_loader()

    # init optimizer
    unet_optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)
    cnn_optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    # train
    print("training...")
    for epoch in range(1000):
        for i, (images, press_times) in enumerate(data_loader):
            images = Variable(images).cuda()
            press_times = Variable(press_times.float()).cuda()

            masks = unet(images)

            segmentations = images * masks
            predict_press_times = cnn(segmentations)

            loss = criterion(predict_press_times, press_times)

            unet_optimizer.zero_grad()
            cnn_optimizer.zero_grad()
            loss.backward()
            unet_optimizer.step()
            cnn_optimizer.step()

            if (i + 1) % 10 == 0:
                print("epoch:", epoch, "step:", i, "loss:", loss.item())  # .item() reads the scalar loss; indexing a 0-dim tensor errors on recent PyTorch
            if (epoch + 1) % 5 == 0 and i == 0:
                torch.save(unet.state_dict(), "./unet.pkl")
                torch.save(cnn.state_dict(), "./cnn.pkl")
                print("save model")
Example #8
def main(args):
    in_files = args.input
    out_files = get_output_filenames(args)

    n_channels = np.load(in_files[0]).shape[2]
    ## NPY 1 channel Uint16 medical images
    net = UNet(n_channels=n_channels, n_classes=2)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files):
        print("\nPredicting image {} ...".format(fn))

        img = np.load(fn)

        mask = predict_img(net=net,
                           img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_gpu=not args.cpu)

        if args.viz:
            print("Visualizing results for image {}, close to continue ...".
                  format(fn))
            ## save this plt
            fn = out_files[i][:-3] + 'jpg'

            # only save the figure when args.vizsave is set (otherwise `save` was undefined)
            save = bool(args.vizsave)

            plot_img_and_mask(img, mask, save=save, fn=fn)

        if args.save:
            # Save NPY file
            out_fn = out_files[i]

            np.save(out_fn, mask)

            print("Mask saved to {}".format(out_files[i]))
Example #9
def main(raw_args):
    args = get_args(raw_args)
    args.model = '/home/zhaojin/data/TacomaBridge/segdata/train/checkpoint/logloss_softmax/CP12.pth'
    in_files = ['/home/zhaojin/data/TacomaBridge/segdata/train/img/00034.png']
    out_files = ['/home/zhaojin/my_path/dir/segdata/predict/00025.png']

    net = UNet(n_channels=1, n_classes=4)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files):
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if img.size[0] < img.size[1]:
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_dense_crf=not args.no_crf,
                           use_gpu=not args.cpu)

        if args.viz:
            print("Visualizing results for image {}, close to continue ...".
                  format(fn))
            plot_img_and_mask(img, mask)
        # if not args.no_save:
        #     out_fn = out_files[i]
        #     print('mask', mask)
        #     result = mask_to_image(mask)
        #
        #     result.save(out_files[i])
        #
        #     print("Mask saved to {}".format(out_files[i]))

    return mask
Example #10
def learn_msra():
    """
    Learn from MSRA10K dataset
    Step 2/3 of 1.3
    *** This uses ~5GB of memory (CatDataset could be changed to use lazy loading if this is a problem) ***
    """
    loader = get_cat_loader(os.path.join('MSRA10K', 'Train'),
                            batch_size=MSRA_BATCH_SIZE,
                            augment=False)
    test_loader = get_cat_loader(os.path.join('MSRA10K', 'Test'),
                                 batch_size=1,
                                 augment=False)
    model = UNet()
    if torch.cuda.is_available():
        model.cuda()
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=MSRA_LEARNING_RATE)
    criterion = cross_entropy_loss
    for e in range(0, MSRA_EPOCHS):
        train_loss = 0
        train_accuracy = 0
        test_loss = 0
        test_accuracy = 0
        # Training
        for images, masks, colors in tqdm(loader):
            optimizer.zero_grad()
            preds = model(images)
            loss = criterion(preds, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_accuracy += sdc(preds, masks)
        else:  # for/else: runs once after the training loop finishes without a break
            # Get test loss/accuracy
            with torch.no_grad():
                for images, masks, colors in test_loader:
                    preds = model(images)
                    test_loss += criterion(preds, masks).item()
                    test_accuracy += sdc(preds, masks)
            print_epoch(e,
                        train_loss / len(loader),
                        train_accuracy / len(loader),
                        test_loss / len(test_loader),
                        test_accuracy / len(test_loader),
                        total_epochs=MSRA_EPOCHS)
    torch.save(model, 'model_msra.pth')
    # Save the weights in a more portable format to be uploaded with the report
    torch.save(model.state_dict(), 'msra_weights.pth')
Example #11
def val(opt):
    # dataset
    val_dataset = DataSet(opt.val_data_root)
    valloader = DataLoader(val_dataset,
                            batch_size=opt.batch_size,
                            num_workers=opt.num_workers)
    # network
    net = UNet(22, [32, 64, 128, 256])
    net.eval()
    models = natsorted(os.listdir(opt.load_model_path))

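    # CSI table: column 0 holds the iteration number of each checkpoint,
    # columns 1-4 are filled below with the values returned by find_best_CSI.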
    CSI = np.zeros((len(models), 5), dtype=float)
    CSI[:, 0] = np.arange(opt.snapshot, opt.max_iter+1, opt.snapshot)

    with torch.no_grad():
        for iteration, model in enumerate(models):
            print(model)
            logging.info(model)
            dec_value = []
            labels = []
            net.load_state_dict(torch.load(os.path.join(opt.load_model_path, model)))
            if opt.use_gpu:
                net.cuda()
            #  softmax output
            for input, target in valloader:
                if opt.use_gpu:
                    input = input.cuda()
                output = net(input).permute(0, 2, 3, 1).contiguous().view(-1, 2)
                target = target.view(-1)
                output = torch.nn.functional.softmax(output, dim=1).detach().cpu().numpy()
                dec_value.append(output)
                labels.append(target.numpy())

            dec_value = np.concatenate(dec_value)
            labels = np.concatenate(labels).squeeze()
            # save dec_value
            np.savetxt(os.path.join(opt.result_file,
                                    'iteration_' + str((iteration+1)*opt.snapshot) + '.txt'), dec_value, fmt='%10.6f')
            # find best CSI
            CSI[iteration, 1:] = find_best_CSI(dec_value, labels)
            # save CSI to file every epoch
            np.savetxt(opt.result_file + '/CSI.txt', CSI, fmt='%8d'+'%8.4f'*4)

    best_iteration = np.arange(opt.snapshot, opt.max_iter+1, opt.snapshot)[np.argmax(CSI[:,1])]
    confidence = CSI[int(best_iteration/opt.snapshot)-1, 4]
    logging.info('best_iteration: %d,confidence: %.6f' % (best_iteration, confidence))

    return best_iteration, confidence
Example #12
def main(batch_size, data_root):
    train_data = MyDataset(
        mode='train',
        txt=data_root + 'train_label_balance.txt',
        transform=transforms.Compose([
            # transforms.RandomResizedCrop(224),
            # transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.3301, 0.3301, 0.3301],
                                 std=[0.1938, 0.1938, 0.1938])
        ]))
    test_data = MyDataset(
        mode='test',
        txt=data_root + 'test_label_balance.txt',
        transform=transforms.Compose([
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.3301, 0.3301, 0.3301],
                                 std=[0.1938, 0.1938, 0.1938])
        ]))

    train_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)
    test_loader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=8)
    model = UNet(n_channels=3, n_classes=3)
    print(model)
    model = nn.DataParallel(model.cuda(), device_ids=[0])
    optimizer = optim.SGD(params=model.parameters(),
                          lr=0.01, momentum=0.9, weight_decay=1e-5)
    # optimizer = optim.Adam(params=model.parameters(), lr=0.01)
    scheduler = StepLR(optimizer, 10, gamma=0.1)
    trainer = Trainer(model, optimizer, F.cross_entropy, save_dir=".")
    trainer.loop(50, train_loader, test_loader, scheduler)
Example #13
def main(args):
    # If there is more than one channel, adapt the UNet to the channel depth
    try:
        n_channels = np.load(glob(os.path.join(args.img, '*.npy'))[0]).shape[2]
    except IndexError:  # no .npy files found, or the array is 2-D (single channel)
        n_channels = 1

    # Change number of classes
    max0 = np.max(np.load(glob(os.path.join(args.mask, '*.npy'))[0]))
    max1 = np.max(np.load(glob(os.path.join(args.mask, '*.npy'))[1]))
    max2 = np.max(np.load(glob(os.path.join(args.mask, '*.npy'))[2]))
    max3 = np.max(np.load(glob(os.path.join(args.mask, '*.npy'))[3]))
    max4 = np.max(np.load(glob(os.path.join(args.mask, '*.npy'))[4]))

    n_classes = int(np.max((max0, max1, max2, max3, max4)) + 1)

    net = UNet(n_channels=n_channels, n_classes=n_classes)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        torch.cuda.set_device(args.gpunum)
        net.cuda()
        cudnn.benchmark = True  # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batch_size,
                  lr=args.learning_rate,
                  gpu=args.gpu,
                  img_scale=args.scale,
                  dir_img=args.img,
                  dir_mask=args.mask,
                  dir_checkpoint=args.checkpoint,
                  channels=n_channels,
                  classes=n_classes)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #14
def domain():
    net = UNet(n_channels=1, n_classes=6)
    net.load_state_dict(torch.load('weight/labeled.pth'))
    net.cuda()
    i = 1000
    means = 0
    while i < 1321:
        x, y = get_test(i)
        imgs = torch.from_numpy(x)
        imgs = imgs.cuda()
        masks_preds = net(imgs)
        masks_preds = F.softmax(masks_preds, 1)
        mean = compared(masks_preds, y, i)
        print('accuracy of {} ='.format(i), mean)
        means += mean
        i += 1
    print('accuracy=', means / 321)
Example #15
def unet(in_files='./predictions/images', output='./predictions/images'):
    net = UNet(n_channels=3, n_classes=1)

    net.cuda()
    net.load_state_dict(torch.load('./MODEL.pth'))

    for i, fn in enumerate(in_files):  # in_files is expected to be an iterable of image file paths

        img = Image.open(fn)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=0.5,
                           out_threshold=0.5,
                           use_dense_crf=False,
                           use_gpu=True)
        result = mask_to_image(mask)
        result.save(output + '/unet' + str(i) + '.png')  # add an extension so PIL can infer the output format
Example #16
def domain():
    net = UNet(n_channels=1, n_classes=1)
    net.load_state_dict(torch.load('weight/' + name + '.pth'))
    # net.load_state_dict(torch.load('weight/teacher.pth'))
    net.cuda()
    o = 200
    i = 200
    while i < 578:
        x = get_data(i)
        imgs = torch.from_numpy(x)
        imgs = imgs.cuda()
        masks_preds = net(imgs)
        # saveResult('data/result_labeled', masks_preds)
        saveResult('data/result_' + name, masks_preds, o)
        # advance both the input and output indices, otherwise the loop never terminates
        i += 1
        o += 1
    mean = 0
    o = 200
    mean = compare('data/unlabel_masks', 'data/result_' + name, o, mean)
    print('accuracy=', mean / 378)
Example #17
def prediction(args):
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=1)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files):
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if img.size[0] < img.size[1]:
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_dense_crf=not args.no_crf,
                           use_gpu=not args.cpu)

        if args.viz:
            print("Visualizing results for image {}, close to continue ...".
                  format(fn))
            plot_img_and_mask(img, mask)

        if not args.no_save:
            out_fn = out_files[i]
            result = mask_to_image(mask)
            result.save(out_files[i])

            print("Mask saved to {}".format(out_files[i]))
Example #18
def main(raw_args=None):
    """example:  python predict.py --model '/home/zhaojin/data/TacomaBridge/segdata/train/checkpoint/weight_logloss_softmax/CP30.pth' --input '/home/zhaojin/data/TacomaBridge/segdata/train/img/00034.png' --viz"""
    args = get_args(raw_args)
    print('args', args)
    # args.model = '/home/zhaojin/data/TacomaBridge/segdata/train/checkpoint/logloss_softmax/CP12.pth'
    # in_files = ['/home/zhaojin/data/TacomaBridge/segdata/train/img/00034.png' ]
    # out_files = ['/home/zhaojin/my_path/dir/segdata/predict/00025.png']
    in_files = args.input
    net = UNet(n_channels=1, n_classes=4)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files):
        print("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if img.size[0] < img.size[1]:
            print("Error: image height larger than the width")

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           use_dense_crf=not args.no_crf,
                           use_gpu=not args.cpu)

        if args.viz:
            print("Visualizing results for image {}, close to continue ...".
                  format(fn))
            mask = np.transpose(mask, axes=[1, 2, 0])
            plot_img_and_mask(img, mask)
Example #19
def eval_on_dir(image_dir, weights_file='q1.pth', out_dir='output'):
    # Load model from disk
    model = UNet()
    model.load_state_dict(torch.load(weights_file))
    cuda = torch.cuda.is_available()
    if cuda:
        model.cuda()
    # Run inferences
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    for filename in os.listdir(image_dir):
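        # Read the image as grayscale, resize it to the 128x128 network input,
        # then add batch and channel dimensions before running the model.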
        image = cv2.resize(
            cv2.cvtColor(cv2.imread(os.path.join(image_dir, filename)),
                         cv2.COLOR_BGR2GRAY), (128, 128))
        tensor = torch.from_numpy(
            image.reshape(1, 1, image.shape[0], image.shape[1])).float()
        if cuda:
            tensor = tensor.cuda()
        result = mask_argmax(model(tensor).detach()).cpu().numpy()[0] * 255
        cv2.imwrite(os.path.join(out_dir, filename), result)
Example #20
def train_1st(raw_args=None):
    args = get_args(raw_args)
    # args.epochs = 30
    # args.lr = 0.1
    save_cp = True
    args.load = ''
    # args.batchsize=20
    # args.scale=0.5
    imagepath = args.imagepath
    maskpath = args.maskpath
    cpsavepath = args.savepath
    args.gpu = True
    net = UNet(n_channels=1, n_classes=4)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
        # cudnn.benchmark = True # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  lrd=args.lrd,
                  gpu=args.gpu,
                  save_cp=save_cp,
                  img_scale=args.scale,
                  imagepath=imagepath,
                  maskpath=maskpath,
                  cpsavepath=cpsavepath)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #21
def main():

    # init conv net
    
    unet = UNet(3,1)
    if os.path.exists("./unet.pkl"):
        unet.load_state_dict(torch.load("./unet.pkl"))
        print("load unet")
    unet.cuda()

    cnn = CNNEncoder()
    if os.path.exists("./cnn.pkl"):
        cnn.load_state_dict(torch.load("./cnn.pkl"))
        print("load cnn")
    cnn.cuda()

    unet.eval()
    cnn.eval()
    
    print("load ok")

    while True:
        pull_screenshot("autojump.png") # obtain screen and save it to autojump.png
        image = Image.open('./autojump.png')
        set_button_position(image)
        image = preprocess(image)
        
        image = Variable(image.unsqueeze(0)).cuda()
        mask = unet(image)

        plt.imshow(mask.squeeze(0).squeeze(0).cpu().data.numpy(), cmap='hot', interpolation='nearest')
        plt.show()
        
        segmentation = image * mask

        press_time = cnn(segmentation)
        press_time = press_time.cpu().data[0].numpy()
        print(press_time)
        jump(press_time)
        
        time.sleep(random.uniform(0.6, 1.1))
Example #22
def test_hdr_net(model_path,
                 dir_checkpoints,
                 experiment_name,
                 dataloader,
                 criterion,
                 gpu=False,
                 expositions_num=15,
                 tb=False):

    print('{}{}{}'.format('+', '=' * 78, '+'))
    print('| {0:} Testing {1:}|'.format(' ' * 30, ' ' * 30))
    print('{}{}{}'.format('+', '=' * 78, '+'))
    tot_psnrm = 0
    tot_psnrhvs = 0
    steps = 0

    # Build and load the network once, rather than re-loading the checkpoint for every batch
    net = UNet(n_channels=3, n_classes=3)
    net.load_state_dict(torch.load(model_path))
    if gpu:
        net.cuda()
    else:
        print(' GPU not available')

    for i, b in enumerate(dataloader):
        steps += 1
        imgs, true_masks, imgs_ids = b['input'], b['target'], b['id']

        if gpu:
            imgs = imgs.cuda()
            true_masks = true_masks.cuda()

        pred = net(imgs)

        batch_hvsm, batch_hvs = get_psnrhs(true_masks, pred, 1)
        tot_psnrm += batch_hvsm
        tot_psnrhvs += batch_hvs

    avg_psnr_m = tot_psnrm / steps
    avg_psnr_hvs = tot_psnrhvs / steps
    print('| AVG PSNR-HVS-M: {0:0.04} | AVG PSNR-hvs: {1:0.04} '.format(
        avg_psnr_m, avg_psnr_hvs))
    print('{}{}{}'.format('+', '-' * 78, '+'))
    return avg_psnr_m, avg_psnr_hvs
Example #23
def test(opt):
    # dataset
    test_dataset = DataSet(opt.test_data_root, test=True)
    testloader = DataLoader(test_dataset,
                            batch_size=opt.batch_size,
                            num_workers=opt.num_workers)

    # network
    net = UNet(22, [32, 64, 128, 256])
    net.eval()
    net.load_state_dict(torch.load(os.path.join(opt.load_model_path, opt.checkpoint_model)))
    if opt.use_gpu:
        net.cuda()

    dec_value = []
    labels = []
    
    with torch.no_grad():
        #  softmax output
        for input, target in testloader:
            if opt.use_gpu:
                input = input.cuda()
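            # flatten the (N, 2, H, W) logits into a (N*H*W, 2) table of per-pixel class scores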
            output = net(input).permute(0, 2, 3, 1).contiguous().view(-1, 2)
            output = torch.nn.functional.softmax(output, dim=1).detach().cpu().numpy()
            target = target.view(-1)
            dec_value.append(output)
            labels.append(target.numpy())

    dec_value = np.concatenate(dec_value)
    labels = np.concatenate(labels).squeeze()
    # save dec_value
    np.savetxt(os.path.join(opt.result_file,
                            'best_iteration_' + str(opt.best_iteration) + '.txt'), dec_value, fmt='%10.6f')
    # find best CSI
    res = find_best_CSI(dec_value, labels, opt.confidence)
    print(res)
    np.savetxt(os.path.join(opt.result_file, 'test_result.txt'), [res],
               fmt='CSI:%.6f\nPOD:%.6f\nFAR:%.6f\nconfidence:%.6f')
Example #24
def main():

    dsc_loss = DiceLoss()

    model = UNet(in_channels=21, out_channels=4)
    model.cuda()
    print(summary(model, input_size=(21, 256, 256)))
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                            base_lr=0.0001,
                                            max_lr=0.01)
    loader_train, loader_eval = datasets()

    save_model_path = '../checkpoints/baseline.pk'
    load_model_path = '../checkpoints/baseline-056.pk'
    model.load_state_dict(torch.load(load_model_path))

    for e in range(1000):
        model = train_epoch(model, loader_train, optimizer, dsc_loss)
        model = eval_epoch(model, loader_eval, dsc_loss)
        scheduler.step()
        torch.save(model.state_dict(), save_model_path)
        print('finished epoch', e, '- checkpoint saved to', save_model_path)
Example #25
def validate(params, data_loader, _log, flag):
    with th.no_grad():
        sig = nn.Sigmoid()
        criterion = nn.BCEWithLogitsLoss()
        running_loss = 0
        shape = params['shape']
        res = th.zeros(shape)
        prune = params['prune']

        if params['model'] == 'FCNwBottleneck':
            trained_model = model.FCNwBottleneck(params['feature_num'], params['pix_res'])
        elif params['model'] == 'UNet':
            trained_model = UNet(params['feature_num'], 1)
        elif params['model'] == 'SimplerFCNwBottleneck':
            trained_model = model.SimplerFCNwBottleneck(params['feature_num'])
        elif params['model'] == 'Logistic':
            trained_model = model.Logistic(params['feature_num'])
        elif params['model'] == 'PolyLogistic':
            trained_model = model.PolyLogistic(params['feature_num'])
        
        trained_model = trained_model.cuda()
        if th.cuda.device_count() > 1:
            trained_model = nn.DataParallel(trained_model)
        trained_model.load_state_dict(th.load(params['load_model']))
        _log.info('[{}] model is successfully loaded.'.format(ctime()))

        data_iter = iter(data_loader)
        for iter_ in range(len(data_iter)):
            sample = next(data_iter)  # use the builtin next() for Python 3 iterators
            data, gt = sample['data'].cuda(), sample['gt'].cuda()
            ignore = gt < 0
            prds = trained_model.forward(data)[:, :, prune:-prune, prune:-prune]
            loss = criterion(prds[~ignore], gt[~ignore])  # select only the non-ignored (non-negative) labels
            running_loss += loss.item()
            
            prds = sig(prds)
            prds[ignore] = 0

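            # write each ws x ws prediction patch back into its position in the full-size result map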
            for idx in range(prds.shape[0]):
                row, col = sample['index'][0][idx], sample['index'][1][idx]
                res[
                    row*params['ws']:(row+1)*params['ws'],
                    col*params['ws']:(col+1)*params['ws']
                ] = prds[idx, 0, :, :]
            _log.info('[{}]: writing [{}/{}]'.format(ctime(), iter_, len(data_iter)))
        _log.info('all image patches are written!')
        save_image(res, '{}{}_{}_predictions.tif'.format(params['save_to'], params['region'], flag))
        np.save('{}{}_{}_predictions.npy'.format(params['save_to'],params['region'], flag), res.data.numpy())
        return running_loss/len(data_iter)
Example #26
def run_main():
    import argparse
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataroot',
                        default="/root/chujiajia/Dataset/COVID19/")
    parser.add_argument('--batchsize', type=int, default=12)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--gpu',
                        type=int,
                        default=1,
                        help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--initial_lr', type=float, default=0.001)
    parser.add_argument('--resultpath',
                        type=str,
                        default="/root/chujiajia/Results/")
    parser.add_argument('--experiment_name',
                        type=str,
                        default='bz12_bzdice_baseline_covid_epoch100')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--model_path', type=str, default="None")
    parser.add_argument('--contour', type=str, default="None")  # None contour

    args = parser.parse_args()
    # torch.cuda.set_device(int(args.gpu))
    global_step = 0
    savepath = os.path.join(args.resultpath, args.experiment_name)
    if not os.path.exists(savepath):
        os.mkdir(savepath)
    imgsavepath = os.path.join(savepath, "imgresult")
    if not os.path.exists(imgsavepath):
        os.mkdir(imgsavepath)
    logpath = os.path.join(savepath, "log.txt")

    traindataset = Covid19DataSet(args.dataroot,
                                  set="train",
                                  lesion_phase=True,
                                  augmentations=None)
    trainloader = DataLoader(traindataset,
                             batch_size=args.batchsize,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    testdataset = Covid19DataSet(args.dataroot,
                                 set="test",
                                 lesion_phase=True,
                                 augmentations=None)
    testloader = DataLoader(testdataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True)
    print("training dataset:{}".format(len(traindataset)))
    print("testing dataset:{}".format(len(testdataset)))
    mask_ids = [
        'XH2_10', 'XH2_70', 'XH3_26', 'XH2_89', 'XH2_34', 'XH3_7', 'XH3_12',
        'XH3_123', 'XH5_36', 'XH5_49', '82', '83', '84', '85', '86', '87',
        '88', '89', '95', '96'
    ]

    model = UNet(n_channels=1, n_classes=1)
    # model = torch.nn.DataParallel(model, device_ids=[0,3])#.cuda() #0,
    model.cuda()

    if args.model_path != "None":
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'))
        print("load model from " + args.model_path)

    model.train()
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr)

    lr = args.initial_lr
    # Training loop
    save_mean_dice = 0.0
    for epoch in range(1, args.num_epochs + 1):
        lr = adjust_learning_rate(optimizer, epoch, lr)
        model.train()
        model.cuda()
        composite_loss_total = 0.0
        loss_dice_total = 0.0
        loss_contour_total = 0.0
        num_steps = 0
        for i, train_batch in enumerate(trainloader):
            # Supervised component --------------------------------------------
            src_img, labels, edge, train_name = train_batch
            src_img = torch.Tensor(src_img).cuda()
            # print(edge)
            label = torch.Tensor(labels).cuda()
            pred_seg = F.sigmoid(model(src_img))

            # loss_dice = dice_loss(pred_seg,label)
            loss_dice = bz_dice_loss(args.batchsize, pred_seg, label)

            if args.contour == "active":
                label = torch.Tensor(labels).cuda()
                loss_contour = active_contour_loss(pred_seg,
                                                   label)  #* args.w_contour
            elif args.contour == "boundary":
                label = torch.Tensor(labels).cuda()
                loss_contour = boundary_loss(pred_seg, labels)
            elif args.contour == "contour":
                label = torch.Tensor(labels)
                edge = torch.Tensor(edge)
                loss_contour = contour_loss(pred_seg, labels, edge)
            elif args.contour == "None":
                loss_contour = 0

            if epoch == 1 and i == 1:
                print(train_name)
                line = str(train_name)
                write_log(logpath, line)

            if args.contour == "None":
                composite_loss = loss_dice
                composite_loss_total += composite_loss.item()
            else:
                composite_loss = loss_contour + loss_dice
                composite_loss_total += composite_loss.item()

            optimizer.zero_grad()
            composite_loss.backward()
            optimizer.step()

            # composite_loss_total += composite_loss.item()
            loss_dice_total += loss_dice.item()

            if args.contour == "None":
                loss_contour_total += 0
            else:
                loss_contour_total += loss_contour.item()

            num_steps += 1
            global_step += 1

        dice_loss_avg = loss_dice_total / num_steps
        contour_loss_avg = loss_contour_total / num_steps
        composite_loss_avg = composite_loss_total / num_steps

        torch.save(model.state_dict(), os.path.join(savepath, 'CP_latest.pth'))

        line = "Epoch: {},Composite Loss: {:.6f},Dice Loss:{:.6f},Boundary Loss:{:.6f},lr:{:.6f},best_dice:{:.6f}".format(
            epoch, composite_loss_avg, dice_loss_avg, contour_loss_avg, lr,
            save_mean_dice)

        print(line)
        write_log(logpath, line)

        if epoch > 0 and epoch % 5 == 0:
            print("valiation_structure")
            # mean_dice = valiation_structure(model,args,"mr_test",imgsavepath,logpath)
            mean_dice = valiation_structure(model, args, logpath, testloader,
                                            mask_ids)
            if mean_dice > save_mean_dice:
                save_mean_dice = mean_dice
                torch.save(model.state_dict(),
                           os.path.join(savepath, 'CP_Best.pth'))
                line = 'Best Checkpoint {} saved !Mean_dice:{}'.format(
                    epoch, save_mean_dice)
                write_log(logpath, line)
Example #27
if __name__ == '__main__':
    __spec__ = None

    opt = Option()

    dataset = CarvanaDataset(opt.dir_img, opt.dir_mask, scale=opt.scale)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batchsize,
                            shuffle=True,
                            num_workers=opt.workers)

    unet = UNet(in_dim=opt.in_dim)
    loss_func = nn.BCEWithLogitsLoss()
    if opt.cuda:
        unet = unet.cuda()
        loss_func = loss_func.cuda()
    optimizer = torch.optim.Adam(unet.parameters(),
                                 lr=opt.lr,
                                 weight_decay=0.0005)
    # Load the pretrained parameters
    if opt.pretrained:
        state = torch.load(opt.net_path)
        unet.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optimizer'])
    unet.train()

    loss_list = []
    loss_list_big = []
    plt.ion()
    plt.show()
Example #28
    (options, args) = parser.parse_args()
    return options


if __name__ == '__main__':
    args = get_args()
    print("args:{}".format(args))

    net = UNet(n_channels=3, n_classes=1)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
        # cudnn.benchmark = True # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  img_scale=args.scale)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #29
def train_net(options):
    dir_img = options.data + '/images/'
    dir_mask = options.data + '/masks/'
    dir_edge = options.data + '/edges/'
    dir_save_model = options.save_model
    dir_save_state = options.save_state
    ids = load.get_ids(dir_img)

    # Split into train and val  # note: the sample order is also fixed here
    iddataset = {}
    iddataset["train"] = list(
        map(
            lambda x: x.split(".png")[0],
            os.listdir(
                "/data/unagi0/kanayama/dataset/nuclei_images/stage1_train_splited/train_default/"
            )))
    iddataset["val"] = list(
        map(
            lambda x: x.split(".png")[0],
            os.listdir(
                "/data/unagi0/kanayama/dataset/nuclei_images/stage1_train_splited/val_default/"
            )))
    N_train = len(iddataset['train'])
    N_val = len(iddataset['val'])
    N_batch_per_epoch_train = int(N_train / options.batchsize)
    N_batch_per_epoch_val = int(N_val / options.val_batchsize)

    # Display the experiment settings
    option_manager.display_info(options, N_train, N_val)

    # Logger instance for recording results
    logger = Logger(options, iddataset)

    # Define the model
    net = UNet(3, 1, options.drop_rate1, options.drop_rate2,
               options.drop_rate3)

    # Load a trained model
    if options.load_model:
        net.load_state_dict(torch.load(options.load_model))
        print('Model loaded from {}'.format(options.load_model))

    # Move the model to the GPU
    if options.gpu:
        net.cuda()

    # Define the optimizer
    optimizer = optim.Adam(net.parameters())

    # Load the optimizer state
    if options.load_state:
        optimizer.load_state_dict(torch.load(options.load_state))
        print('State loaded from {}'.format(options.load_state))

    # Start training
    for epoch in range(options.epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, options.epochs))
        train = load.get_imgs_and_masks(iddataset['train'], dir_img, dir_mask,
                                        dir_edge, options.resize_shape)
        original_sizes = load.get_original_sizes(iddataset['val'], dir_img,
                                                 '.png')
        val = load.get_imgs_and_masks(iddataset['val'],
                                      dir_img,
                                      dir_mask,
                                      dir_edge,
                                      options.resize_shape,
                                      train=False)
        train_loss = 0
        validation_loss = 0
        validation_score = 0
        validation_scores = np.zeros(10)

        # training phase
        if not options.skip_train:
            net.train()
            for i, b in enumerate(utils.batch(train, options.batchsize)):
                X = np.array([j[0] for j in b])
                y = np.array([j[1] for j in b])
                w = np.array([j[2] for j in b])

                if X.shape[0] != options.batchsize:  # keep the batch size consistent (otherwise it errors for some reason)
                    continue

                X, y, w = utils.data_augmentation(X, y, w)

                X = torch.FloatTensor(X)
                y = torch.ByteTensor(y)
                w = torch.ByteTensor(w)

                if options.gpu:
                    X = X.cuda()
                    y = y.cuda()
                    w = w.cuda()

                X = Variable(X)
                y = Variable(y)
                w = Variable(w)

                y_pred = net(X)
                probs = F.sigmoid(y_pred)
                probs_flat = probs.view(-1)
                y_flat = y.view(-1)
                w_flat = w.view(-1)
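                # map the 0-255 edge map to per-pixel loss weights in [1, options.weight]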
                weight = (w_flat.float() / 255.) * (options.weight - 1) + 1.
                loss = weighted_binary_cross_entropy(probs_flat,
                                                     y_flat.float() / 255.,
                                                     weight)
                train_loss += loss.data[0]

                print('{0:.4f} --- loss: {1:.6f}'.format(
                    i * options.batchsize / N_train, loss.data[0]))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print('Epoch finished ! Loss: {}'.format(train_loss /
                                                     N_batch_per_epoch_train))
            logger.save_loss(train_loss / N_batch_per_epoch_train,
                             phase="train")

        # validation phase
        net.eval()
        probs_array = np.zeros(
            (N_val, options.resize_shape[0], options.resize_shape[1]))
        for i, b in enumerate(utils.batch(val, options.val_batchsize)):
            X = np.array([j[0] for j in b])[:, :3, :, :]  # drop the alpha channel
            y = np.array([j[1] for j in b])
            w = np.array([j[2] for j in b])
            X = torch.FloatTensor(X)
            y = torch.ByteTensor(y)
            w = torch.ByteTensor(w)

            if options.gpu:
                X = X.cuda()
                y = y.cuda()
                w = w.cuda()

            X = Variable(X, volatile=True)
            y = Variable(y, volatile=True)
            w = Variable(w, volatile=True)

            y_pred = net(X)
            probs = F.sigmoid(y_pred)

            probs_flat = probs.view(-1)
            y_flat = y.view(-1)
            w_flat = w.view(-1)

            # weight the edge pixels more heavily
            weight = (w_flat.float() / 255.) * (options.weight - 1) + 1.
            loss = weighted_binary_cross_entropy(probs_flat,
                                                 y_flat.float() / 255., weight)
            validation_loss += loss.data[0]

            # Post-processing
            y_hat = np.asarray((probs > 0.5).data)
            y_hat = y_hat.reshape((y_hat.shape[2], y_hat.shape[3]))
            y_truth = np.asarray(y.data)

            # Denoising & binarization
            #dst_img = remove_noise(probs_resized, (original_height, original_width))
            #dst_img = (dst_img * 255).astype(np.uint8)

            # calculate validatation score
            if (options.calc_score_step !=
                    0) and (epoch + 1) % options.calc_score_step == 0:
                score, scores, _ = validate(y_hat, y_truth)
                validation_score += score
                validation_scores += scores
                print("Image No.", i, ": score ", score)

            logger.save_output_mask(y_hat, original_sizes[i],
                                    iddataset['val'][i])
            if options.save_probs is not None:
                logger.save_output_prob(np.asarray(probs.data[0][0]),
                                        original_sizes[i], iddataset['val'][i])

        print('Val Loss: {}'.format(validation_loss / N_batch_per_epoch_val))
        logger.save_loss(validation_loss / N_batch_per_epoch_val, phase="val")

        # Save the score
        if (options.calc_score_step !=
                0) and (epoch + 1) % options.calc_score_step == 0:
            print('score: {}'.format(validation_score / i))
            logger.save_score(validation_scores, validation_score,
                              N_batch_per_epoch_val, epoch)

        # Save the model and optimizer state.
        if (epoch + 1) % 10 == 0:
            torch.save(
                net.state_dict(), dir_save_model + str(options.id) +
                '_CP{}.model'.format(epoch + 1))
            torch.save(
                optimizer.state_dict(), dir_save_state + str(options.id) +
                '_CP{}.state'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))

        # draw loss graph
        logger.draw_loss_graph("./results/loss")
        # draw score graph
        if (options.calc_score_step !=
                0) and (epoch + 1) % options.calc_score_step == 0:
            logger.draw_score_graph("./results/score" + str(options.id) +
                                    ".png")
Example #30
if cuda:
    torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
train_set = get_training_set(opt.size + opt.remsize, target_mode=opt.target_mode, colordim=opt.colordim)
test_set = get_test_set(opt.size, target_mode=opt.target_mode, colordim=opt.colordim)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building unet')
unet = UNet(opt.colordim)


criterion = nn.MSELoss()
if cuda:
    unet = unet.cuda()
    criterion = criterion.cuda()

pretrained = True
if pretrained:
    unet.load_state_dict(torch.load(opt.pretrain_net))

optimizer = optim.SGD(unet.parameters(), lr=opt.lr)
print('===> Training unet')

def train(epoch):
    epoch_loss = 0
    for iteration, batch in enumerate(training_data_loader, 1):
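        # take a random opt.size x opt.size crop from the larger (opt.size + opt.remsize) training patch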
        randH = random.randint(0, opt.remsize)
        randW = random.randint(0, opt.remsize)
        input = Variable(batch[0][:, :, randH:randH + opt.size, randW:randW + opt.size])
Example #31
def run_main():
    import argparse
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--trainlist',
        default=
        "/root/chujiajia/Code/ACDC_CF/libs/datasets/jsonLists/train_50000.json"
    )
    parser.add_argument(
        '--testlist',
        default=
        "/root/chujiajia/Code/ACDC_CF/libs/datasets/jsonLists/test_50000.json")
    parser.add_argument('--data_root', default="/root/XieHe_DataSet/")

    parser.add_argument('--batchsize', type=int, default=32)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--w_contour', type=float, default=1)

    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    # parser.add_argument('--exp', type=str, default='liver')
    parser.add_argument('--loss_term', type=str, default='dice')

    parser.add_argument('--initial_lr', type=float, default=0.001)
    parser.add_argument('--resultpath',
                        type=str,
                        default="/root/chujiajia/Results/")
    parser.add_argument('--experiment_name',
                        type=str,
                        default='dice_loss_acdc_1')
    parser.add_argument('--num_workers', default=16, type=int)
    args = parser.parse_args()

    # torch.cuda.set_device(int(args.gpu))
    global_step = 0
    savepath = os.path.join(args.resultpath, args.experiment_name)
    if not os.path.exists(savepath):
        os.mkdir(savepath)
    imgsavepath = os.path.join(savepath, "imgresult")
    if not os.path.exists(imgsavepath):
        os.mkdir(imgsavepath)
    logpath = os.path.join(savepath, "log.txt")

    # traindataset = ACDCDataset(args.trainlist,augmentations=data_aug)
    # trainloader = DataLoader(traindataset, batch_size=args.batchsize,shuffle=True, num_workers=args.num_workers, pin_memory=True)
    # testdataset = ACDCDataset(args.testlist)
    # testloader = DataLoader(testdataset, batch_size=1,shuffle=False, num_workers=args.num_workers, pin_memory=True)

    traindataset = XieHeDataset(args.trainlist, args.data_root)
    trainloader = DataLoader(traindataset,
                             batch_size=args.batchsize,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    testdataset = XieHeDataset(args.testlist, args.data_root)

    testloader = DataLoader(testdataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True)

    print("training dataset:{}".format(len(traindataset)))
    print("testing dataset:{}".format(len(testdataset)))
    # mask_ids = ["ProstateDx-03-0001","ProstateDx-03-0002","ProstateDx-03-0003","ProstateDx-03-0004","ProstateDx-03-0005"]

    model = UNet(n_channels=1, n_classes=4)
    model = torch.nn.DataParallel(model, device_ids=[0, 1])

    optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr)
    model.train()
    model.cuda()

    # Training loop
    save_mean_dice = 0.0
    for epoch in range(1, args.num_epochs + 1):
        lr = args.initial_lr
        # lr = adjust_learning_rate(optimizer, epoch-1, args.num_epochs, args.initial_lr)
        model.train()
        model.cuda()
        composite_loss_total = 0.0
        loss_dice_total = 0.0
        loss_contour_total = 0.0
        num_steps = 0
        for i, train_batch in enumerate(trainloader):
            # Supervised component --------------------------------------------
            src_img, labels, edge, contour, train_name = train_batch
            src_img = torch.Tensor(src_img).cuda()
            # print(edge)
            # edge = torch.Tensor(edge).cuda()
            # contour = torch.Tensor(contour).cuda()
            # labels = torch.Tensor(labels).cuda()
            pred_seg = F.softmax(model(src_img), dim=1)
            if args.loss_term == "edge":
                loss_contour = multiContourLoss(pred_seg, labels,
                                                edge) * args.w_contour
            elif args.loss_term == "contour":
                loss_contour = multiContourLoss(pred_seg, labels,
                                                contour) * args.w_contour
            elif args.loss_term == "boundary":
                loss_contour = multiBoundaryLoss(pred_seg,
                                                 labels) * args.w_contour
            else:
                loss_contour = 0

            loss_dice = multiDiceLoss(pred_seg, labels)

            if epoch == 1 and i == 1:
                print(train_name)
                line = str(train_name)
                write_log(logpath, line)

            composite_loss = loss_contour + loss_dice

            optimizer.zero_grad()
            composite_loss.backward()
            optimizer.step()

            composite_loss_total += composite_loss.item()
            loss_dice_total += loss_dice.item()
            if args.loss_term == "edge" or args.loss_term == "contour" or args.loss_term == "boundary":
                loss_contour_total += loss_contour.item()
            else:
                loss_contour_total = 0

            num_steps += 1
            global_step += 1

        dice_loss_avg = loss_dice_total / num_steps
        contour_loss_avg = loss_contour_total / num_steps
        composite_loss_avg = composite_loss_total / num_steps

        torch.save(model.state_dict(), os.path.join(savepath, 'CP_latest.pth'))
        torch.save(model.state_dict(),
                   os.path.join(savepath, 'CP' + str(epoch) + '.pth'))

        line = "Epoch: {},Composite Loss: {:.6f},Dice Loss:{:.6f},Contour Loss:{:.6f},lr:{:.6f},best_dice:{:.6f}".format(
            epoch, composite_loss_avg, dice_loss_avg, contour_loss_avg, lr,
            save_mean_dice)
        print(line)
        write_log(logpath, line)

        if epoch > 0 and epoch % 10 == 0:
            print("valiation_structure")
            # mean_dice = valiation_structure(model,args,"mr_test",imgsavepath,logpath)
            mean_dice = valiation_structure(model, args, logpath, testloader,
                                            imgsavepath)
            if mean_dice > save_mean_dice:
                save_mean_dice = mean_dice
                torch.save(model.state_dict(),
                           os.path.join(savepath, 'CP_Best.pth'))
                line = 'Best Checkpoint {} saved !Mean_dice:{}'.format(
                    epoch, save_mean_dice)
                write_log(logpath, line)