Exemplo n.º 1
0
def infer(args, unlabeled, ckpt_file):
    """Run the checkpointed UNet over the unlabeled subset of the training set.

    Relies on the module-level ``img_scale``, ``batch_size`` and ``device``.

    Args:
        args: Mapping with the keys "TRAINIMAGEDATA_DIR",
            "TRAINLABEL_DIRECTORY" and "EXPT_DIR".
        unlabeled: Indices into the training dataset to run inference on.
        ckpt_file: Checkpoint name; it is appended to ``args["EXPT_DIR"]``.

    Returns:
        ``{"outputs": predictions}`` where ``predictions`` maps the running
        position of each sample (0, 1, 2, ...) to the raw network output for
        that sample as a NumPy array.
    """
    traindataset = BasicDataset(args["TRAINIMAGEDATA_DIR"],
                                args["TRAINLABEL_DIRECTORY"], img_scale)
    unlabeled_dataset = Subset(traindataset, unlabeled)
    # No shuffling: the running position used as the prediction key must
    # follow the order of ``unlabeled`` for the keys to be meaningful.
    unlabeled_loader = DataLoader(unlabeled_dataset,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True)

    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    # map_location lets a checkpoint saved on another device load here.
    net.load_state_dict(
        torch.load(os.path.join(args["EXPT_DIR"] + ckpt_file),
                   map_location=device))
    net.eval()

    predictions = {}
    with tqdm(total=len(unlabeled_loader), desc='Validation round',
              unit='batch', leave=False) as pbar:
        with torch.no_grad():
            for batch in unlabeled_loader:
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                mask_pred = net(imgs)
                for logit in mask_pred:
                    # Keys are consecutive integers across all batches.
                    predictions[len(predictions)] = logit.cpu().numpy()
                pbar.update()

    return {"outputs": predictions}
Exemplo n.º 2
0
def main():
    """Export the UNet stored in ``MODEL.pth`` to the ONNX file ``Unet.onnx``."""
    model = UNet(n_channels=3, n_classes=1)
    weights = torch.load('MODEL.pth')
    model.load_state_dict(weights)
    model.eval()

    # Random example input for the export; 640x959 is half of the original
    # resolution.
    dummy_input = torch.rand(1, 3, 640, 959)
    torch.onnx.export(model,
                      dummy_input,
                      'Unet.onnx',
                      export_params=True,
                      verbose=True)
def predict(img_path, ori_seg_path, loc_path, ckpt_path, train_phrase, display_path, mask_path):
    """Run the UNet checkpoint that sorts last in ``ckpt_path`` over a dataset.

    For every sample the prediction is passed to ``display_jpg`` and
    ``save_mask`` together with the source image read from
    ``img_path/train_phrase/<file_name>``.

    Args:
        img_path: Root directory of the input images.
        ori_seg_path: Forwarded to ``Dataset``.
        loc_path: Forwarded to ``Dataset``.
        ckpt_path: Directory containing checkpoint files.
        train_phrase: Sub-directory / phase name, also forwarded to helpers.
        display_path: Forwarded to ``display_jpg``.
        mask_path: Forwarded to ``save_mask``.

    Raises:
        IndexError: If ``ckpt_path`` contains no files.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The file name that sorts last (descending sort, first entry) is used.
    ckpt = sorted(os.listdir(ckpt_path), reverse=True)[0]

    model = UNet(channels_in=3, channels_out=1)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(
        torch.load(os.path.join(ckpt_path, ckpt), map_location=device))
    model = model.float().to(device)
    model.eval()

    dataset = Dataset(img_path, ori_seg_path, loc_path, img_path, train_phrase)
    dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for batch in dataloader:
            features = batch['features'].to(device)
            preds = model(features).cpu().numpy()
            for name, pred in zip(batch['file_name'], preds):
                file_path = os.path.join(img_path, train_phrase, name)
                img = imageio.imread(file_path)
                mask = pred.squeeze()
                display_jpg(img, mask, name, train_phrase, display_path)
                save_mask(mask, name, train_phrase, mask_path)
Exemplo n.º 4
0
def evaluate_model(net: UNet, valid_loader: DataLoader,
                   device: torch.device) -> float:
    """Return the mean per-batch validation score of ``net``.

    The score is the cross-entropy loss when ``net.n_classes > 1`` and the
    Dice coefficient of the 0.5-thresholded sigmoid output otherwise.

    The network is switched to eval mode for the evaluation and left in
    train mode afterwards.

    Args:
        net: Model to evaluate; must expose ``n_classes``.
        valid_loader: Loader yielding dicts with "image" and "mask" tensors.
        device: Device the batches are moved to.

    Returns:
        The score averaged over the batches, or ``0.0`` when the loader
        yields no batches (instead of dividing by zero).
    """
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    nb_of_batches = len(valid_loader)
    total = 0
    print()
    with tqdm(total=nb_of_batches,
              desc="Validation",
              unit="batch",
              leave=False) as bar:
        for batch in valid_loader:
            images = batch["image"].to(device=device, dtype=torch.float32)
            masks = batch["mask"].to(device=device, dtype=mask_type)

            with torch.no_grad():
                mask_predictions = net(images)

            if net.n_classes > 1:
                total += F.cross_entropy(mask_predictions, masks).item()
            else:
                pred = (torch.sigmoid(mask_predictions) > 0.5).float()
                total += dice_coeff(pred, masks).item()

            bar.update()

    net.train()
    if nb_of_batches == 0:
        return 0.0
    return total / nb_of_batches
Exemplo n.º 5
0
def initial_models(path_to_ckpt):
    """Build a single-channel UNet and load a checkpoint into it.

    Args:
        path_to_ckpt: Either a path containing ".pth" (used as is) or a
            directory; for a directory, the entry containing ".pth" that
            sorts last is used.

    Returns:
        ``(net, device)``: the network in eval mode and the device it lives
        on. ``net`` is wrapped in ``torch.nn.DataParallel`` when the
        checkpoint does not load into the bare model.

    Raises:
        IndexError: If a directory is given and it holds no ".pth" entry.
        AssertionError: If the resolved checkpoint path does not exist.
    """
    # Resolve a directory to the checkpoint that sorts last.
    if ".pth" not in path_to_ckpt:
        ckpt_list = sorted(c for c in os.listdir(path_to_ckpt) if ".pth" in c)
        path_to_ckpt = join(path_to_ckpt, ckpt_list[-1])

    assert exists(path_to_ckpt)

    # init model
    net = UNet(in_channels=1, out_channels=1, bilinear=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device=device)

    # load model (read the file once, reuse it for the fallback)
    print("Log:\tload %s"%path_to_ckpt)
    state_dict = torch.load(path_to_ckpt, map_location=device)
    try:
        net.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on mismatching keys; retry with
        # a DataParallel wrapper (presumably the checkpoint was saved from
        # one, which prefixes the keys -- verify against the training code).
        net = torch.nn.DataParallel(net)
        net.load_state_dict(state_dict)
    net.eval()

    return net, device
Exemplo n.º 6
0
def test(args, ckpt_file):
    """Run the checkpointed UNet over the test set and collect its outputs.

    Relies on the module-level ``img_scale``, ``batch_size`` and ``device``.

    Args:
        args: Mapping with the keys "TESTIMAGEDATA_DIR",
            "TESTLABEL_DIRECTORY" and "EXPT_DIR".
        ckpt_file: Checkpoint name; it is appended to ``args["EXPT_DIR"]``.

    Returns:
        ``{"predictions": preds, "labels": labels}`` with CPU tensors holding
        the thresholded (0/1) predictions and the ground-truth masks of all
        evaluated batches, concatenated along the batch dimension. Both are
        empty tensors when the loader yields no batch.
    """
    testdataset = BasicDataset(args["TESTIMAGEDATA_DIR"],
                               args["TESTLABEL_DIRECTORY"], img_scale)
    val_loader = DataLoader(testdataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)
    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    net.load_state_dict(
        torch.load(os.path.join(args["EXPT_DIR"] + ckpt_file),
                   map_location=device))
    net.eval()

    all_preds = []
    all_labels = []
    with tqdm(total=len(val_loader), desc='Validation round', unit='batch',
              leave=False) as pbar, torch.no_grad():
        for batch in val_loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            # The network is built with a single output class above, so the
            # masks are kept as float tensors.
            true_masks = batch['mask'].to(dtype=torch.float32)

            # Binary mask: sigmoid probability thresholded at 0.5.
            pred = (torch.sigmoid(net(imgs)) > 0.5).float()

            # Accumulate every batch (on the CPU) rather than keeping only
            # the last one.
            all_preds.append(pred.cpu())
            all_labels.append(true_masks.cpu())
            pbar.update()

    if not all_preds:
        return {"predictions": torch.empty(0), "labels": torch.empty(0)}
    return {"predictions": torch.cat(all_preds),
            "labels": torch.cat(all_labels)}
Exemplo n.º 7
0
def main(args):
    """Run UNet inference, plot the DSC distribution and save outlined slices.

    Steps: predict every batch of ``data_loader(args)``, regroup the slices
    with ``postprocess_per_volume``, save the plot returned by ``plot_dsc``
    to ``args.figure`` and write one PNG per slice into ``args.predictions``
    with the prediction outlined in [255, 0, 0] and the ground truth in
    [0, 255, 0].
    """
    makedirs(args)
    device = torch.device(
        args.device if torch.cuda.is_available() else "cpu")

    loader = data_loader(args)

    inputs, predictions, targets = [], [], []

    with torch.set_grad_enabled(False):
        unet = UNet(in_channels=Dataset.in_channels,
                    out_channels=Dataset.out_channels)
        unet.load_state_dict(torch.load(args.weights, map_location=device))
        unet.eval()
        unet.to(device)

        for _, (x, y_true) in tqdm(enumerate(loader)):
            x = x.to(device)
            y_true = y_true.to(device)

            y_pred = unet(x)

            # Extending with an array appends one entry per sample of the
            # batch (iteration runs over the first axis).
            predictions.extend(y_pred.detach().cpu().numpy())
            targets.extend(y_true.detach().cpu().numpy())
            inputs.extend(x.detach().cpu().numpy())

    dataset = loader.dataset
    volumes = postprocess_per_volume(
        inputs,
        predictions,
        targets,
        dataset.patient_slice_index,
        dataset.patients,
    )

    imsave(args.figure, plot_dsc(dsc_distribution(volumes)))

    for p in volumes:
        volume = volumes[p]
        x, y_pred, y_true = volume[0], volume[1], volume[2]
        for s in range(x.shape[0]):
            image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
            image = outline(image, y_pred[s, 0], color=[255, 0, 0])
            image = outline(image, y_true[s, 0], color=[0, 255, 0])
            filename = "{}-{}.png".format(p, str(s).zfill(2))
            imsave(os.path.join(args.predictions, filename), image)
Exemplo n.º 8
0
def draw_bbx_test(test_dir, model_name, gpu=3):
    """
    draw the bbx for whole scenes of the testdataset

    One network per class id (1..6) is loaded from
    ``ckpt/{model_name}{i}.pt``. For every ``scene_*/*_color.png`` image in
    ``test_dir`` each cropped object is run through the network of its
    class, a pose is estimated with ``ransac_pnp`` and passed, together with
    the ground-truth pose, to ``create_bounding_box``. The image is then
    saved to ``predicts/bbx/{scene_id}_{img_id}.png``.

    Relies on the module-level ``resize_input`` and ``K``.

    Args:
        test_dir: Root directory of the test scenes.
        model_name: Checkpoint name prefix.
        gpu: CUDA device index used when CUDA is available.
    """
    device_name = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print("device: ", device_name)
    img_ls = glob.glob(os.path.join(test_dir, 'scene_*/*_color.png'))
    random.shuffle(img_ls)
    transform = transforms.Resize(resize_input)
    to_tensor = transforms.ToTensor()  # built once, reused for every crop
    nets = {}
    for i in range(1, 7):
        correspondence_block = UNet()
        correspondence_block.load_state_dict(
            torch.load(f'ckpt/{model_name}{i}.pt', map_location=device))
        correspondence_block = correspondence_block.to(device)
        correspondence_block.eval()
        nets[i] = correspondence_block
    for img_path in img_ls:
        scene_id = img_path.split('/')[-2][-1]
        img_id = img_path.split('/')[-1][:4]
        obj_ls = glob.glob(str(Path(img_path).parents[0] / 'cropped' / img_id) + '*')
        rgb_uncropped = cv2.imread(img_path)  # the original uncropped image
        rgb_uncropped = cv2.cvtColor(rgb_uncropped, cv2.COLOR_BGR2RGB)
        for adr_rgb in obj_ls:
            single_data = SingleTestdata(adr_rgb)
            net = nets[single_data.cls_id]

            rgb = Image.open(adr_rgb)
            rgb = _make_square(rgb)
            rgb = to_tensor(transform(rgb))
            rgb = torch.unsqueeze(rgb, 0).to(device)
            # Inference only: no autograd graph needed.
            with torch.no_grad():
                xmask_pred, ymask_pred, zmask_pred = net(rgb)
            xmask_pred = torch.argmax(xmask_pred[0], 0).cpu().numpy()
            ymask_pred = torch.argmax(ymask_pred[0], 0).cpu().numpy()
            zmask_pred = torch.argmax(zmask_pred[0], 0).cpu().numpy()
            xyzmask_pred = np.stack([xmask_pred, ymask_pred, zmask_pred], axis=-1) / 255
            xyzmask_pred = np.clip(xyzmask_pred, 0, 1)

            R, t = ransac_pnp(xyzmask_pred, K,
                              single_data, resize_input, rgb_uncropped)
            pose = Rt_to_H(R, t)
            pose_gt = Rt_to_H(single_data.R, single_data.t)
            # draw the bbx with poses
            # (presumably drawn onto rgb_uncropped in place -- TODO confirm)
            create_bounding_box(rgb_uncropped, pose, pose_gt, single_data, K)

        # Save once per image. Doing this after the loop wrote only the last
        # image and failed with a NameError when no image was found.
        out_path = f'predicts/bbx/{scene_id}_{img_id}.png'
        print('saving to ', out_path)
        plt.imsave(out_path, rgb_uncropped)
def load_model(no_class, model_file_name):
    """Build a 3-channel UNet and load ``checkpoints/<model_file_name>``.

    The network is created with ``bilinear=False`` when the file name
    contains 'upLearned' and with ``bilinear=True`` otherwise. It is passed
    through ``parallelize_model`` before the weights are loaded and is
    returned in eval mode.
    """
    use_bilinear = 'upLearned' not in model_file_name
    model = parallelize_model(
        UNet(n_channels=3, n_classes=no_class, bilinear=use_bilinear))

    state = torch.load('checkpoints/' + model_file_name)
    model.load_state_dict(state)
    print("Model loaded !")
    model.eval()
    return model
    def load_model(self):
        """Build the UNet, load the weights at ``self.model_path`` and return it.

        The network is constructed with ``bilinear=False``; if that
        construction fails it is constructed with ``bilinear=True`` instead.
        The model is passed through ``parallelize_model`` before the weights
        are loaded and is returned in eval mode.
        """
        print('Loading model from ', self.model_path)
        try:
            net = UNet(n_channels=3, n_classes=self.no_classes, bilinear=False)
        except Exception:
            # Narrowed from a bare ``except`` so that KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            # NOTE(review): only the constructor is guarded here; if the
            # intent was to fall back when the checkpoint does not match the
            # architecture, the load below would need guarding too -- confirm.
            net = UNet(n_channels=3, n_classes=self.no_classes, bilinear=True)

        net = parallelize_model(net)
        net.load_state_dict(torch.load(self.model_path))
        print("Model loaded!")
        net.eval()
        return net
Exemplo n.º 11
0
class Segmentation:
    """UNet-based segmentation of base64-encoded images.

    The weights are read from ``weights/unet.pt`` next to this module and the
    model always runs on the CPU.
    """

    def __init__(self):
        """Load the UNet weights and put the model in eval mode."""
        # Inference is pinned to the CPU.
        self.device = torch.device("cpu")
        weights = os.path.join(
            os.path.abspath(os.path.dirname(__file__)), "weights", "unet.pt")

        self.model = UNet(in_channels=3, out_channels=1)
        state_dict = torch.load(weights, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)

    def transform(self, img):
        """Turn an (H, W, C) array into a float32 tensor of shape (1, C, H, W)."""
        channels_first = img.transpose(2, 0, 1)
        tensor = torch.from_numpy(channels_first.astype(np.float32))
        return tensor.unsqueeze(0)

    def ask(self, image_b64):
        """Segment a base64-encoded image and return the outlined result.

        The image is decoded, converted to RGB, resized to 256x256, passed
        through ``normalize_volume`` and fed to the model. The rounded
        prediction is reduced with ``largest_connected_component`` and
        outlined in [255, 0, 0] on channel 1 of the model input.

        Args:
            image_b64: Base64-encoded image data.

        Returns:
            The outlined image encoded as PNG bytes.

        Raises:
            BaseException: With message "nothing found" when the rounded
                prediction contains no non-zero value.
        """
        decoded = base64.b64decode(image_b64)
        img = Image.open(BytesIO(decoded)).convert("RGB")
        img = img.resize((256, 256))

        img = normalize_volume(np.array(img))

        img = self.transform(img).to(self.device)
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            pred = self.model(img)

        seg_mask = pred.squeeze(0).squeeze(0).cpu().numpy()
        seg_mask = np.round(seg_mask).astype(int)

        if not seg_mask.any():
            # NOTE(review): BaseException is kept so existing callers that
            # catch it keep working; an Exception subclass would be the
            # conventional choice.
            raise BaseException("nothing found")
        seg_mask = largest_connected_component(seg_mask)

        # Channel 1 of the network input is used as the background image.
        initial_image = img.reshape((3, 256, 256))[1]
        initial_image = initial_image.reshape((256, 256)).cpu().numpy()

        initial_image = gray2rgb(initial_image)
        outlined_img = outline(initial_image, seg_mask, color=[255, 0, 0])

        out_img = Image.fromarray(outlined_img)

        png_buffer = BytesIO()
        out_img.save(png_buffer, format="PNG")
        return png_buffer.getvalue()
def load_net(cfg, net_dir: Path, net_name: str):
    """Build a UNet from ``cfg`` and load ``<net_dir>/<net_name>.pkl`` into it.

    The model is moved to CUDA when available (CPU otherwise) and returned
    in eval mode.
    """
    model = UNet(cfg)

    checkpoint_path = str(net_dir / f'{net_name}.pkl')
    # The identity map_location keeps every storage where it was
    # deserialized instead of restoring it to its saved device.
    weights = torch.load(checkpoint_path,
                         map_location=lambda storage, loc: storage)
    model.load_state_dict(weights)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    return model
Exemplo n.º 13
0
    def build_model(self, weights_path=None):
        '''Build the UNet model and load its weights.

        Args:
            weights_path: Path of the state dict to load. When omitted, the
                weights as trained with 120 epochs are used.

        Returns:
            The network on ``self.device`` in eval mode.
        '''
        if weights_path is None:
            # Default: the weights as trained with 120 epochs
            weights_path = ("/home/aaron/CMU_Lidar_Navigation/" +
                            "experiments/default/state_dicts/120epoch")

        net = UNet()
        net.load_state_dict(torch.load(weights_path,
                                       map_location=self.device))

        net = net.to(self.device)
        net.eval()

        return net
Exemplo n.º 14
0
def plot_imgs_pred():
    """
    Plot an input image and the per-class prediction masks of a trained model.

    The options come from ``get_plot_args()``. The first channel of the
    preprocessed image is saved as ``original_img.png`` and, for every class
    index, the binary mask of the pixels whose highest-scoring class is that
    index is saved as ``pred_cls_<index>.png``; all files are written with
    the ``args.dir_output`` prefix.

    :return: None
    :raises ValueError: if ``args.model_arch`` is neither 'unet' nor 'icnet'
    """
    args = get_plot_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img = tiff.imread(args.dir_img)
    img = BasicDataset.preprocess(img, scale=args.scale)
    img = torch.from_numpy(img).type(torch.FloatTensor)

    if args.model_arch == 'unet':
        net = UNet(n_channels=4, n_classes=4, bilinear=True)
    elif args.model_arch == 'icnet':
        net = ICNet(n_channels=4, n_classes=4, pretrained_base=False)
    else:
        # Fail with a clear message instead of an unbound ``net`` below.
        raise ValueError(
            f"Unsupported model architecture: {args.model_arch!r}")

    net.load_state_dict(torch.load(args.checkpoint_net, map_location=device))
    net.to(device=device)
    net.eval()

    img = img.to(device=device, dtype=torch.float32)
    img = img.unsqueeze(0)

    with torch.no_grad():
        if args.model_arch == 'icnet':
            # Only the first of the four ICNet outputs is plotted.
            mask_pred, _sub4, _sub8, _sub16 = net(img)
        else:
            mask_pred = net(img)

    # matplotlib needs host memory, so tensors are moved to the CPU first.
    plt.imshow(img[0][0].cpu())
    plt.colorbar()
    plt.savefig(args.dir_output + "original_img.png")
    plt.clf()

    for c in mask_pred:
        n_classes = c.size(0)
        c = torch.sigmoid(c)
        max_index = torch.max(c, 0).indices
        for class_index in range(n_classes):
            # Show the prediction for this class
            jaccard_input = (max_index == class_index).float()
            plt.imshow(jaccard_input.cpu())
            plt.colorbar()
            plt.savefig(args.dir_output + f"pred_cls_{class_index}.png")
            plt.clf()
Exemplo n.º 15
0
def val(opt):
    """Evaluate every checkpoint in ``opt.load_model_path`` on the validation set.

    For each checkpoint (in natural sort order) the softmax outputs are
    written to ``iteration_<n>.txt`` in ``opt.result_file``, the values
    returned by ``find_best_CSI`` are stored in a table, and that table is
    rewritten to ``CSI.txt`` after every checkpoint.

    Args:
        opt: Options object providing ``val_data_root``, ``batch_size``,
            ``num_workers``, ``load_model_path``, ``snapshot``, ``max_iter``,
            ``use_gpu`` and ``result_file``.

    Returns:
        ``(best_iteration, confidence)``: the iteration whose row has the
        largest value in column 1 of the table, and column 4 of that row.
    """
    # dataset
    val_dataset = DataSet(opt.val_data_root)
    valloader = DataLoader(val_dataset,
                           batch_size=opt.batch_size,
                           num_workers=opt.num_workers)
    # network
    net = UNet(22, [32, 64, 128, 256])
    net.eval()
    models = natsorted(os.listdir(opt.load_model_path))

    iterations = np.arange(opt.snapshot, opt.max_iter+1, opt.snapshot)
    CSI = np.zeros((len(models), 5), dtype=float)
    CSI[:, 0] = iterations

    # Moving the network once is enough; later loads copy into it in place.
    if opt.use_gpu:
        net.cuda()

    with torch.no_grad():
        for iteration, model in enumerate(models):
            print(model)
            logging.info(model)
            dec_value = []
            labels = []
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
            # machines; load_state_dict copies onto the network's device.
            net.load_state_dict(
                torch.load(os.path.join(opt.load_model_path, model),
                           map_location='cpu'))
            #  softmax output
            for inputs, target in valloader:
                if opt.use_gpu:
                    inputs = inputs.cuda()
                output = net(inputs).permute(0, 2, 3, 1).contiguous().view(-1, 2)
                output = torch.nn.functional.softmax(output, dim=1).cpu().numpy()
                dec_value.append(output)
                labels.append(target.view(-1).numpy())

            dec_value = np.concatenate(dec_value)
            labels = np.concatenate(labels).squeeze()
            # save dec_value
            np.savetxt(os.path.join(opt.result_file,
                                    'iteration_' + str((iteration+1)*opt.snapshot) + '.txt'), dec_value, fmt='%10.6f')
            # find best CSI
            CSI[iteration, 1:] = find_best_CSI(dec_value, labels)
            # save CSI to file after every checkpoint
            np.savetxt(os.path.join(opt.result_file, 'CSI.txt'), CSI,
                       fmt='%8d'+'%8.4f'*4)

    best_index = np.argmax(CSI[:, 1])
    best_iteration = iterations[best_index]
    confidence = CSI[best_index, 4]
    logging.info('best_iteration: %d,confidence: %.6f' % (best_iteration, confidence))

    return best_iteration, confidence
Exemplo n.º 16
0
def predict():
    """Show every input of the cell dataset next to its predicted mask.

    Relies on the module-level ``model_path`` and ``data_folder``. For each
    sample the input and the channel-wise argmax of the output (scaled by
    255) are opened with ``Image.show()``.
    """
    model = UNet()
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    cell_dataset = CellDataset(data_folder, eval=True)
    model.load_state_dict(checkpoint)
    model.eval()
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for i in range(len(cell_dataset)):
            image, _ = cell_dataset[i]
            # Channels are moved last so argmax(2) picks the class per pixel.
            output = model(image).permute(0, 2, 3, 1).squeeze().numpy()
            input_array = image.squeeze().numpy()
            output_array = output.argmax(2) * 255
            input_img = Image.fromarray(input_array)
            output_img = Image.fromarray(
                output_array.astype(dtype=np.uint16)).convert('L')
            input_img.show()
            output_img.show()
Exemplo n.º 17
0
def predict(img_path, ckpt_path, model_name, train_phrase, display_path,
            mask_path, model_res_path):
    """Run the checkpoint that sorts last for ``model_name`` over a dataset.

    The model input width and the dataset class depend on ``model_name``:
    names containing '3layers' use 4 channels with ``Dataset4Layers``, 'emb'
    uses 5 channels with ``EmbDataset`` and anything else uses 2 channels
    with ``Dataset2Layers``. Each prediction is passed to ``display_jpg``
    and ``save_mask``.

    Args:
        img_path: Root directory of the input images.
        ckpt_path: Directory holding one sub-directory per model name.
        model_name: Name of the model variant.
        train_phrase: Sub-directory / phase name, also forwarded to helpers.
        display_path: Forwarded to ``display_jpg``.
        mask_path: Forwarded to ``save_mask``.
        model_res_path: Forwarded to ``EmbDataset`` (only used for 'emb').

    Raises:
        IndexError: If the checkpoint directory contains no files.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ckpt_dir = os.path.join(ckpt_path, model_name)
    # The file name that sorts last (descending sort, first entry) is used.
    ckpt = sorted(os.listdir(ckpt_dir), reverse=True)[0]

    if '3layers' in model_name:
        channels_in = 4
    elif model_name == 'emb':
        channels_in = 5
    else:
        channels_in = 2

    model = UNet(channels_in=channels_in, channels_out=1)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(
        torch.load(os.path.join(ckpt_dir, ckpt), map_location=device))

    if model_name == 'emb':
        dataset = EmbDataset(img_path, img_path, model_res_path, [
            'unet', 'unet_3layers', 'unet_3layers_with_vgg_loss',
            'unet_with_vgg_loss'
        ], train_phrase)
    else:
        dataset_cls = (Dataset4Layers if '3layers' in model_name
                       else Dataset2Layers)
        # NOTE(review): ``ori_seg_path`` is not a parameter of this function;
        # it must exist at module level or this line raises NameError --
        # confirm where it is supposed to come from.
        dataset = dataset_cls(img_path, ori_seg_path, img_path, train_phrase)
    loader = DataLoader(dataset, batch_size=3, shuffle=False)

    model = model.float().to(device)
    model.eval()

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        for batch in loader:
            features = batch['features'].to(device)
            preds = model(features).cpu().numpy()
            for name, pred in zip(batch['file_name'], preds):
                file_path = os.path.join(img_path, train_phrase, name)
                img = imageio.imread(file_path)
                mask = pred.squeeze()
                display_jpg(img, mask, name, model_name,
                            train_phrase, display_path)
                save_mask(mask, name, model_name,
                          train_phrase, mask_path)
Exemplo n.º 18
0
def load_models(device):
    """Load the denoising UNet and the ResNet-18 classifier for inference.

    Args:
        device: Device (or device string) both models are loaded onto.

    Returns:
        ``(model_denoiser, model_classifier)``, both on ``device`` and in
        eval mode.
    """
    map_location = torch.device(device)

    model_denoiser = UNet(n_classes=1, depth=5, padding=True, batch_norm=True)
    model_denoiser.load_state_dict(
        torch.load("data/denoising_autoencoder.pth",
                   map_location=map_location))
    model_denoiser = model_denoiser.to(device)

    # The state dict loaded below (strict by default) overwrites every
    # weight, so fetching the pretrained ImageNet weights first would only
    # add a needless download.
    model_classifier = models.resnet18(pretrained=False)
    num_ftrs = model_classifier.fc.in_features
    # Replace the classification head with a single-output layer.
    model_classifier.fc = torch.nn.Linear(num_ftrs, 1)
    model_classifier.load_state_dict(
        torch.load("data/classifier_model.pth",
                   map_location=map_location))
    model_classifier = model_classifier.to(device)

    model_classifier.eval()
    model_denoiser.eval()
    return model_denoiser, model_classifier
Exemplo n.º 19
0
def _add_bgnd(fname, _model_path = model_path, _int_scale = (0, 255), cuda_id = 0):
    """Write the UNet output for every frame of '/full_data' into '/bgnd'.

    The file ``fname`` is opened in 'r+' mode. An existing '/bgnd' node is
    removed and recreated with the shape of '/full_data'. Each frame is
    rescaled from ``_int_scale`` to the unit range, passed through the
    model, scaled back, rounded and stored with the dtype of the source
    frame.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print("THIS IS CUDA!!!!")
    device = torch.device("cuda:" + str(cuda_id) if cuda_available else 'cpu')

    model = UNet(n_channels = 1, n_classes = 1)
    checkpoint = torch.load(_model_path, map_location = 'cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    scale_min = _int_scale[0]
    scale_range = _int_scale[1] - _int_scale[0]

    with tables.File(fname, 'r+') as fid:
        full_data = fid.get_node('/full_data')

        # Start from a clean output node.
        if '/bgnd' in fid:
            fid.remove_node('/bgnd')
        bgnd = createImgGroup(fid, "/bgnd", *full_data.shape, is_expandable = False)
        bgnd._v_attrs['save_interval'] = full_data._v_attrs['save_interval']

        n_frames = full_data.shape[0]
        for ii in tqdm.trange(n_frames):
            frame = full_data[ii]
            normalized = (frame.astype(np.float32) - scale_min)/scale_range

            with torch.no_grad():
                # [None, None] adds two leading singleton axes.
                batch = torch.from_numpy(normalized[None, None]).to(device)
                prediction = model(batch)

            restored = prediction.squeeze().detach().cpu().numpy()
            restored = restored*scale_range + scale_min
            bgnd[ii] = restored.round().astype(frame.dtype)
Exemplo n.º 20
0
def main(weights, dataset, output_dir, t, dropout):
    """Save dropout-sampled mean/variance and a plain prediction per image.

    For every item of ``dataset`` the model is run ``t`` times with dropout
    enabled; the mean and variance over these runs are saved next to a
    prediction made with dropout disabled and the input itself. ``weights``
    is the state dict loaded into the model; ``output_dir`` is created when
    missing.
    """
    if not os.path.isdir(output_dir):
        print(f"  ** {output_dir} not found, creating directory...")
        os.mkdir(output_dir)

    model = UNet(p_dropout=dropout).to(DEVICE)
    model.load_state_dict(weights)
    model.eval()

    with torch.no_grad():
        for signal_image, _, signal_filename in dataset:
            # Dropout has to be switched back on for every image because the
            # final prediction of the previous image switched it off.
            model.enable_dropout()
            # add a batch dimension to single image input
            signal_image = signal_image.unsqueeze(0).to(DEVICE)

            samples = torch.stack(
                [model(signal_image).squeeze() for _ in range(t)])
            output_mean = samples.mean(0).data.cpu().numpy()
            output_var = samples.var(0).data.cpu().numpy()

            # prediction from the complete model, without dropout
            model.disable_dropout()
            output_prediction = model(signal_image).squeeze().data.cpu().numpy()

            print(f"  ** saving prediction from {signal_filename}")
            outputs = (
                ("prediction", output_prediction),
                ("mean", output_mean),
                ("variance", output_var),
                ("signal", signal_image.data.cpu().numpy()),
            )
            for tag, array in outputs:
                skimage.io.imsave(
                    make_save_filename(signal_filename, output_dir, tag),
                    array,
                    check_contrast=False)
Exemplo n.º 21
0
def test(args):
    """Display the model output for every validation sample and print the mean IoU.

    For each sample a figure with the input image, the ground-truth label
    and the prediction is shown for 20 seconds. The IoU values returned by
    ``get_iou`` are averaged over the number of samples and printed.

    Args:
        args: Options object providing ``ckpt``, the checkpoint path.
    """
    # Show the model's output
    model = UNet(1).to(device)
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))

    liver_dataset = LiverDataset(r"E:\360Downloads\dataset\fingerprint\val",
                                 transform=x_transforms,
                                 target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()

    import matplotlib.pyplot as plt

    plt.ion()  # turn on interactive mode

    with torch.no_grad():
        miou_total = 0
        num = len(dataloaders)  # number of validation samples (batch size 1)
        # i is the index of the current validation sample
        for i, (x, _) in enumerate(dataloaders):
            x = x.to(device)
            y = model(x)

            # Drop the singleton dimensions and convert the prediction to a
            # NumPy array for the IoU computation and for plotting.
            img_y = torch.squeeze(y).cpu().numpy()
            sample = get_data(i)
            image_path = sample[0]
            mask = sample[1]  # path of the current mask
            # add the IoU of the current prediction to the running total
            miou_total += get_iou(mask, img_y)
            plt.subplot(1, 3, 1)
            plt.imshow(Image.open(image_path))
            plt.title('Input fingerprint image')
            plt.subplot(1, 3, 2)
            plt.imshow(Image.open(mask))
            plt.title('Ground Truth label')
            plt.subplot(1, 3, 3)
            plt.imshow(img_y)
            plt.title('Estimated fingerprint pose')
            plt.pause(20)
        plt.show()
        # Average over the real number of samples instead of a fixed 100.
        mean_miou = miou_total / num if num else float('nan')
        print('Miou=%f' % mean_miou)
Exemplo n.º 22
0
def main():
    """Play in an endless loop: screenshot, segment, predict press time, jump.

    The UNet produces a mask for the preprocessed screenshot, the masked
    image is fed to the CNN encoder and its output is passed to ``jump``.
    Weights are loaded from ./unet.pkl and ./cnn.pkl when those files exist.
    """

    def _restore(network, weights_file, message):
        # Load the weights only when the file is present, then move to GPU.
        if os.path.exists(weights_file):
            network.load_state_dict(torch.load(weights_file))
            print(message)
        network.cuda()

    # init conv net
    unet = UNet(3,1)
    _restore(unet, "./unet.pkl", "load unet")

    cnn = CNNEncoder()
    _restore(cnn, "./cnn.pkl", "load cnn")

    unet.eval()
    cnn.eval()

    print("load ok")

    while True:
        # obtain screen and save it to autojump.png
        pull_screenshot("autojump.png")
        screenshot = Image.open('./autojump.png')
        set_button_position(screenshot)

        frame = Variable(preprocess(screenshot).unsqueeze(0)).cuda()
        mask = unet(frame)

        heatmap = mask.squeeze(0).squeeze(0).cpu().data.numpy()
        plt.imshow(heatmap, cmap='hot', interpolation='nearest')
        plt.show()

        press_time = cnn(frame * mask).cpu().data[0].numpy()
        print(press_time)
        jump(press_time)

        time.sleep(random.uniform(0.6, 1.1))
Exemplo n.º 23
0
def test(opt):
    """Evaluate one checkpoint on the test set and write the results to disk.

    The softmax outputs are saved to ``best_iteration_<n>.txt`` and the
    values returned by ``find_best_CSI`` to ``test_result.txt``, both inside
    ``opt.result_file``.

    Args:
        opt: Options object providing ``test_data_root``, ``batch_size``,
            ``num_workers``, ``load_model_path``, ``checkpoint_model``,
            ``use_gpu``, ``result_file``, ``best_iteration`` and
            ``confidence``.
    """
    # dataset
    test_dataset = DataSet(opt.test_data_root, test=True)
    testloader = DataLoader(test_dataset,
                            batch_size=opt.batch_size,
                            num_workers=opt.num_workers)

    # network
    net = UNet(22, [32, 64, 128, 256])
    net.eval()
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
    # machine; the network is moved to the GPU afterwards when requested.
    net.load_state_dict(
        torch.load(os.path.join(opt.load_model_path, opt.checkpoint_model),
                   map_location='cpu'))
    if opt.use_gpu:
        net.cuda()

    dec_value = []
    labels = []

    with torch.no_grad():
        #  softmax output
        for inputs, target in testloader:
            if opt.use_gpu:
                inputs = inputs.cuda()
            output = net(inputs).permute(0, 2, 3, 1).contiguous().view(-1, 2)
            output = torch.nn.functional.softmax(output, dim=1).cpu().numpy()
            dec_value.append(output)
            labels.append(target.view(-1).numpy())

    dec_value = np.concatenate(dec_value)
    labels = np.concatenate(labels).squeeze()
    # save dec_value
    np.savetxt(os.path.join(opt.result_file,
                            'best_iteration_' + str(opt.best_iteration) + '.txt'), dec_value, fmt='%10.6f')
    # find best CSI
    res = find_best_CSI(dec_value, labels, opt.confidence)
    print(res)
    np.savetxt(os.path.join(opt.result_file, 'test_result.txt'), [res],
               fmt='CSI:%.6f\nPOD:%.6f\nFAR:%.6f\nconfidence:%.6f')
Exemplo n.º 24
0
def unet_dm(img_file, model_file='MODEL.pth'):
    '''
    Prediction function.

    Loads the image at ``img_file``, normalizes it, runs it through the UNet
    stored in ``model_file`` on the CPU and returns the squeezed network
    output as a NumPy array.
    '''
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.114, 0.114, 0.114],
                             std=[0.237, 0.237, 0.237])
    ])
    # The context manager closes the image file once it has been converted.
    with Image.open(img_file) as img:
        img = tf(img)
    img = img.unsqueeze(0)

    net = UNet(n_channels=3, n_classes=1, bilinear=False, n_features=32)
    # The network stays on the CPU, so map the checkpoint there as well;
    # this also lets a GPU-saved checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(model_file, map_location='cpu'))
    net.eval()

    with torch.no_grad():
        output = net(img)

    dm = output.squeeze().cpu().numpy()
    return dm
Exemplo n.º 25
0
    device = torch.device(args.gpu) if args.gpu is not None else 'cpu'
    if args.gpu:
        torch.cuda.set_device(device)

    # load dataset and user groups
    train_dataset, _, user_groups = get_dataset(args)

    # 使用 UNET
    if args.model == 'unet':
        global_model = UNet(n_channels=1, n_classes=1, bilinear=True)
    else:
        exit('Error: unrecognized model')
    # 加载模型
    global_model = torch.load(args.model_path, map_location=device)
    global_model.eval()
    print(global_model)
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=8)
    # 输出数据转换
    dcs = []
    for i, (images, ground_truth) in enumerate(train_loader):
        images, ground_truth = images.to(device), ground_truth.to(
            device).squeeze(1).squeeze(0)  # ground truth 去掉 channel 维度
        logits = global_model(images)
        probs = torch.sigmoid(logits).squeeze(1).squeeze(
            0)  # 去掉 channel 和 batch 的维度
        mask = (probs > 0.5).float()
        dc = dice_coeff(ground_truth, mask).cpu().item()
        print(f'{i}, dc: {dc}')
        dcs.append(dc)
    dc_avg = np.average(dcs)
Exemplo n.º 26
0
    assert not args.two_bucket
    invNet = StandardConv2D(out_channels=num_features).cuda()
else:
    invNet = ShiftVarConv2D(out_channels=num_features,
                            block_size=args.blocksize,
                            two_bucket=args.two_bucket).cuda()
# UNet with num_features input channels and args.subframes output channels,
# placed on the GPU.
uNet = UNet(in_channel=num_features,
            out_channel=args.subframes,
            instance_norm=False).cuda()

## load checkpoint: restore both networks from models/<args.ckpt> and switch
## them to eval mode
ckpt = torch.load(os.path.join('models', args.ckpt))
invNet.load_state_dict(ckpt['invnet_state_dict'])
uNet.load_state_dict(ckpt['unet_state_dict'])
invNet.eval()
uNet.eval()

## load checkpoint
# NOTE(review): this reads the same file and reloads the same uNet weights as
# the block directly above; it looks redundant -- confirm and remove.
ckpt = torch.load(os.path.join('models', args.ckpt))
uNet.load_state_dict(ckpt['unet_state_dict'])
uNet.eval()

# Tile the code stored in the checkpoint once per block along the last two
# dimensions (height // blocksize by width // blocksize repetitions).
c2b_code = ckpt['c2b_state_dict']['code']
code_repeat = c2b_code.repeat(1, 1, input_params['height'] // args.blocksize,
                              input_params['width'] // args.blocksize)

logging.info('Starting inference')
# Accumulators for the inference that follows.
psnr_sum = 0.
ssim_sum = 0.
full_gt = []
full_pred = []
Exemplo n.º 27
0
class SubgoalServer(object):
    ''' ROS node that predicts candidate waypoints ("subgoals") from a
        time-synchronised panorama image and laser scan, and publishes them
        as a PoseArray together with per-candidate feature vectors. '''

    def load_subgoal_model(self):
        ''' Load the pretrained model to predict nearby waypoints from
            pano and laser scan data.

            Reads the weights path from the 'unet_weights_file' ROS param,
            builds the UNet on self.device and puts it in eval mode. '''
        self.unet_weights = rospy.get_param('unet_weights_file')
        self.unet = UNet(n_channels=2, n_classes=1).to(self.device)
        self.unet.load_state_dict(
            torch.load(self.unet_weights, map_location=self.device))
        self.unet.eval()
        rospy.loginfo('Subgoal model loaded weights')

    def __init__(self):
        ''' Set up networks, subscribers and publishers, then block in
            rospy.spin() -- the constructor does not return until the node
            shuts down. '''

        # Fire up some networks
        self.device = torch.device(rospy.get_param('device', 'cuda:0'))
        self.load_subgoal_model()
        self.cnn = CNN()

        # Make sure we can find out where we are
        self.tf_lis = tf.TransformListener()

        # Subscribe to panos and scans that have been processed by rotate_pano
        pano_sub = message_filters.Subscriber('theta/image/rotated', Image)
        scan_sub = message_filters.Subscriber('scan/rotated', LaserScan)
        ts = message_filters.TimeSynchronizer([pano_sub, scan_sub], 1)
        ts.registerCallback(self.predict_waypoints)

        # NMS param
        self.max_predictions = rospy.get_param('max_subgoal_predictions', 10)

        # Publisher
        self.pub_feat = rospy.Publisher('subgoal/features',
                                        Image,
                                        queue_size=1)
        self.pub_occ = rospy.Publisher('subgoal/occupancy',
                                       Image,
                                       queue_size=1)
        self.pub_prob = rospy.Publisher('subgoal/prob', Image, queue_size=1)
        self.pub_nms = rospy.Publisher('subgoal/nms_prob', Image, queue_size=1)
        self.pub_way = rospy.Publisher('subgoal/waypoints',
                                       PoseArray,
                                       queue_size=1)
        rospy.spin()

    def prep_scan_for_net(self, scan_img):
        ''' Pack a (rows, cols, 1) occupancy array into a float32 tensor of
            shape (1, 2, rows, cols) on self.device.

            Channel 1 holds the occupancy values; channel 0 holds a ramp from
            -0.5 to 0.5 along the row axis, repeated across columns. '''
        imgs = np.empty((1, 2, scan_img.shape[0], scan_img.shape[1]),
                        dtype=np.float32)
        # (rows, cols, 1) -> (1, rows, cols) so it fills the batch of size one
        imgs[:, 1, :, :] = scan_img.transpose((2, 0, 1))
        ran_ch = np.linspace(-0.5, 0.5, num=imgs.shape[2])
        # (rows,) -> (1, rows, 1); broadcasting repeats it over every column
        imgs[:, 0, :, :] = np.expand_dims(np.expand_dims(ran_ch, axis=0),
                                          axis=2)
        out = torch.from_numpy(imgs).to(device=self.device)
        return out

    def radial_occupancy(self, scan):
        ''' Convert an 1D numpy array of 360 degree range scans to a 2D numpy array representing
            a radial occupancy map. Values are 1: occupied, -1: free, 0: unknown
            Here we assume the scan is a full 360 degrees due to preprocessing by rotate_pano.

            Bin counts and widths come from the 'range_bins', 'heading_bins'
            and 'range_bin_width' ROS params.

            Returns a tuple (output, hr): output has shape
            (range_bins, heading_bins, 1); hr has shape
            (range_bins, heading_bins, 2) and stores, per cell, the heading
            in radians at index 0 and the range of the bin centre at index 1. '''
        n_range_bins = rospy.get_param('range_bins')
        n_heading_bins = rospy.get_param('heading_bins')
        range_bin_width = rospy.get_param('range_bin_width')
        # Build the n_range_bins + 1 edges as index * width. np.arange with a
        # non-integer step can return one element too many because of
        # floating point rounding, which would break the shape of hr below.
        range_bins = np.arange(n_range_bins + 1) * range_bin_width
        heading_bin_width = 360.0 / n_heading_bins

        # Record the heading, range of the center of each bin in ros coords. Heading increases as you turn left.
        hr = np.zeros((n_range_bins, n_heading_bins, 2), dtype=np.float32)
        range_centers = range_bins[:-1] + range_bin_width / 2
        hr[:, :, 1] = range_centers.reshape(-1, 1)
        assert n_heading_bins % 2 == 0
        heading_centers = -(np.arange(n_heading_bins) * heading_bin_width +
                            heading_bin_width / 2 - 180)
        hr[:, :, 0] = np.radians(heading_centers)

        output = np.zeros((n_range_bins, n_heading_bins, 1),
                          dtype=np.float32)  # rows, cols, channels
        # Add 'inf' as right edge of an extra bin to account for the case if the returned range exceeds
        # the maximum discretized range. In this case we still want to register these cells as free.
        # Built once outside the loop; np.inf is used because the np.Inf
        # alias no longer exists in NumPy 2.0.
        hist_edges = np.array(range_bins.tolist() + [np.inf])
        # chunk scan data to generate occupied (value 1)
        chunk_size = len(scan.ranges) // n_heading_bins
        args = [iter(scan.ranges[::-1])
                ] * chunk_size  # reverse scan since it's from right to left!
        n = 0
        # fillvalue=np.nan pads a ragged final chunk with values that are
        # discarded below (the default None padding would make np.isnan fail).
        for chunk in izip_longest(*args, fillvalue=np.nan):
            if n >= n_heading_bins:
                # Leftover readings when the scan length is not an exact
                # multiple of the heading bin count; there is no column left.
                break
            # occupied (value 1)
            chunk = np.array(chunk)
            chunk[np.isnan(
                chunk
            )] = -1  # Remove nan values, negatives will fall outside range_bins
            hist, _ = np.histogram(chunk, bins=hist_edges)
            output[:, n, 0] = np.clip(hist[:-1], 0, 1)
            # free (value -1)
            # Reverse cumulative sum, shifted by one: a cell is free when at
            # least one reading landed in a bin farther out than it.
            free_ix = np.flip(np.cumsum(np.flip(hist, axis=0), axis=0),
                              axis=0)[1:] > 0
            output[:, n, 0][free_ix] = -1
            n += 1
        return output, hr

    def predict_waypoints(self, image, scan):
        ''' Synchronised callback for one pano image and one laser scan.

            Extracts CNN features, builds the radial occupancy map, runs the
            UNet and non-maximum suppression, then publishes the combined
            image/candidate features on 'subgoal/features' and the candidate
            poses (with the current robot pose appended last) on
            'subgoal/waypoints'. Optional debug images are published when the
            corresponding 'subgoal_publish_*' params are set. '''
        rospy.loginfo('Subgoal model predicting waypoints')
        stamp = rospy.Time.now()

        # Extract cnn features plus their heading and elevation
        feats, imgs_he = self.cnn.extract_features(image)
        aug_feats = np.concatenate((feats.cpu().numpy(), imgs_he), axis=0)
        feat_dim = feats.shape[0]

        # Prepare the scan data
        scan_img, scan_hr = self.radial_occupancy(scan)
        if rospy.get_param('subgoal_publish_occupancy', False):
            # Map occupancy values {-1, 0, 1} to {0, 127, 254} for a mono8 image
            np_scan_img = ros_numpy.msgify(
                Image, (127 * (scan_img + 1)).astype(np.uint8),
                encoding='mono8')
            np_scan_img.header.stamp = stamp
            self.pub_occ.publish(np_scan_img)
        # Roll the scans so to match the image features
        roll_ix = -scan_img.shape[1] // 4 + 2  # -90 degrees plus 2 bins
        rolled_scan_img = np.roll(scan_img, roll_ix, axis=1)
        rolled_scan_hr = np.roll(scan_hr, roll_ix, axis=1)

        # Predict subgoals
        scans = self.prep_scan_for_net(rolled_scan_img)
        feats = feats.reshape((1, feat_dim, 3, 12))
        with torch.no_grad():
            logits = self.unet(scans, feats)
            # Softmax over all cells of each sample, restored to the map shape
            pred = F.softmax(logits.flatten(1), dim=1).reshape(logits.shape)

        if rospy.get_param('subgoal_publish_prob', False):
            # Undo the roll so the visualisation lines up with the raw scan
            viz_pred = np.roll(255 * pred.squeeze().cpu().numpy(),
                               -roll_ix,
                               axis=1)
            np_viz_pred = ros_numpy.msgify(Image, viz_pred, encoding="32FC1")
            np_viz_pred.header.stamp = stamp
            self.pub_prob.publish(np_viz_pred)
        sigma = rospy.get_param('subgoal_nms_sigma')
        thresh = rospy.get_param('subgoal_nms_thresh')
        nms_pred = nms(pred, sigma, thresh, self.max_predictions)
        if rospy.get_param('subgoal_publish_nms_prob', False):
            viz_nms_pred = np.roll(255 * nms_pred.squeeze().cpu().numpy(),
                                   -roll_ix,
                                   axis=1)
            np_viz_nms_pred = ros_numpy.msgify(Image,
                                               viz_nms_pred,
                                               encoding="32FC1")
            np_viz_nms_pred.header.stamp = stamp
            self.pub_nms.publish(np_viz_nms_pred)

        # Get agent pose
        trans, rot = self.tf_lis.lookupTransform('/map', '/base_footprint',
                                                 rospy.Time(0))
        r, p, agent_heading_rad = quaternion_to_euler(rot)  # ros heading

        # Extract waypoint candidates
        nms_pred = nms_pred.squeeze()
        waypoint_ix = (nms_pred > 0).nonzero()

        # Extract candidate features - each should be the closest to that viewpoint
        # Note: The candidate_feat at last position is the feature for the END stop signal (zeros)
        num_candidates = waypoint_ix.shape[0] + 1
        candidates = np.zeros((feat_dim + 2, num_candidates), dtype=np.float32)
        im_features = aug_feats[:-2].reshape(feat_dim, 3, 12)
        imgs_he = imgs_he.reshape(2, 3, 12)

        # Publish pose array of possible goals
        pa = PoseArray()
        pa.header.stamp = stamp
        pa.header.frame_id = 'map'

        for i, (range_bin,
                heading_bin) in enumerate(waypoint_ix.cpu().numpy()):
            hr = rolled_scan_hr[range_bin, heading_bin]
            # Calculate elevation to the candidate pose is 0 for the robot (stays the same height, doesn't go on stairs)
            # So candidate is always from the centre row of images 3 * 12 images
            img_heading_bin = heading_bin // 4
            candidates[:-2,
                       i] = im_features[:, 1,
                                        img_heading_bin]  # 1 is for elevation 0
            candidates[-2:, i] = [hr[0], 0]  # heading, elevation

            # Construct pose output as well
            pose = Pose()
            pose.position.x = trans[0] + math.cos(hr[0]) * hr[1]
            pose.position.y = trans[1] + math.sin(hr[0]) * hr[1]
            pose.position.z = 0
            # Which way should the robot face when it arrives? Away from here I guess.
            candidate_heading = math.atan2(pose.position.y - trans[1],
                                           pose.position.x - trans[0])
            pose.orientation = euler_to_quaternion(0, 0, candidate_heading)
            pa.poses.append(pose)

        # Publish image and candidate features
        combined_feats = np.concatenate([aug_feats, candidates], axis=1)
        np_feats = ros_numpy.msgify(Image, combined_feats, encoding="32FC1")
        np_feats.header.stamp = stamp
        self.pub_feat.publish(np_feats)
        rospy.logdebug('Subgoal server published features')

        # Put current pose in last position to feed the agent_server
        pose = Pose()
        pose.position.x = trans[0]
        pose.position.y = trans[1]
        pose.position.z = 0
        pose.orientation = euler_to_quaternion(0, 0, agent_heading_rad)
        pa.poses.append(pose)
        self.pub_way.publish(pa)
        rospy.loginfo('Subgoal server published waypoints')
Exemplo n.º 28
0
def train_net(options):
    """Train and validate a UNet on the nuclei dataset, saving checkpoints.

    Each epoch runs an optional training pass, then a validation pass that
    writes predicted masks (and optionally probabilities) through the logger.
    Model and optimizer state are saved every 10 epochs.

    Args:
        options: settings object. The attributes read here are ``data``,
            ``save_model``, ``save_state``, ``batchsize``, ``val_batchsize``,
            ``drop_rate1``/``drop_rate2``/``drop_rate3``, ``load_model``,
            ``load_state``, ``gpu``, ``epochs``, ``resize_shape``,
            ``skip_train``, ``weight``, ``calc_score_step``, ``save_probs``
            and ``id``.

    NOTE(review): this function uses the legacy pre-0.4 PyTorch API
    (``Variable``, ``volatile=True``, ``loss.data[0]``); it is left as is so
    the snippet keeps running on the PyTorch version it was written for.
    """
    dir_img = options.data + '/images/'
    dir_mask = options.data + '/masks/'
    dir_edge = options.data + '/edges/'
    dir_save_model = options.save_model
    dir_save_state = options.save_state
    ids = load.get_ids(dir_img)

    # Split into train and val  # the ordering is also fixed here
    # NOTE: the split directories are hard-coded absolute paths.
    iddataset = {}
    iddataset["train"] = list(
        map(
            lambda x: x.split(".png")[0],
            os.listdir(
                "/data/unagi0/kanayama/dataset/nuclei_images/stage1_train_splited/train_default/"
            )))
    iddataset["val"] = list(
        map(
            lambda x: x.split(".png")[0],
            os.listdir(
                "/data/unagi0/kanayama/dataset/nuclei_images/stage1_train_splited/val_default/"
            )))
    N_train = len(iddataset['train'])
    N_val = len(iddataset['val'])
    N_batch_per_epoch_train = int(N_train / options.batchsize)
    N_batch_per_epoch_val = int(N_val / options.val_batchsize)

    # Display the experiment settings
    option_manager.display_info(options, N_train, N_val)

    # Instance used to record results
    logger = Logger(options, iddataset)

    # Define the model
    net = UNet(3, 1, options.drop_rate1, options.drop_rate2,
               options.drop_rate3)

    # Load a pretrained model
    if options.load_model:
        net.load_state_dict(torch.load(options.load_model))
        print('Model loaded from {}'.format(options.load_model))

    # Move the model to the GPU
    if options.gpu:
        net.cuda()

    # Define the optimizer
    optimizer = optim.Adam(net.parameters())

    # Load the optimizer state
    if options.load_state:
        optimizer.load_state_dict(torch.load(options.load_state))
        print('State loaded from {}'.format(options.load_state))

    # Start training
    for epoch in range(options.epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, options.epochs))
        # The data is fetched again at the start of every epoch.
        train = load.get_imgs_and_masks(iddataset['train'], dir_img, dir_mask,
                                        dir_edge, options.resize_shape)
        original_sizes = load.get_original_sizes(iddataset['val'], dir_img,
                                                 '.png')
        val = load.get_imgs_and_masks(iddataset['val'],
                                      dir_img,
                                      dir_mask,
                                      dir_edge,
                                      options.resize_shape,
                                      train=False)
        train_loss = 0
        validation_loss = 0
        validation_score = 0
        validation_scores = np.zeros(10)
        # Number of validation batches that contributed to validation_score.
        n_scored = 0

        # training phase
        if not options.skip_train:
            net.train()
            for i, b in enumerate(utils.batch(train, options.batchsize)):
                X = np.array([j[0] for j in b])
                y = np.array([j[1] for j in b])
                w = np.array([j[2] for j in b])

                if X.shape[
                        0] != options.batchsize:  # keep the batch size uniform (a short batch errors out for some reason)
                    continue

                X, y, w = utils.data_augmentation(X, y, w)

                X = torch.FloatTensor(X)
                y = torch.ByteTensor(y)
                w = torch.ByteTensor(w)

                if options.gpu:
                    X = X.cuda()
                    y = y.cuda()
                    w = w.cuda()

                X = Variable(X)
                y = Variable(y)
                w = Variable(w)

                y_pred = net(X)
                probs = F.sigmoid(y_pred)
                probs_flat = probs.view(-1)
                y_flat = y.view(-1)
                w_flat = w.view(-1)
                # Per-pixel weight: 1 where w is 0, options.weight where w is 255
                weight = (w_flat.float() / 255.) * (options.weight - 1) + 1.
                loss = weighted_binary_cross_entropy(probs_flat,
                                                     y_flat.float() / 255.,
                                                     weight)
                train_loss += loss.data[0]

                print('{0:.4f} --- loss: {1:.6f}'.format(
                    i * options.batchsize / N_train, loss.data[0]))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print('Epoch finished ! Loss: {}'.format(train_loss /
                                                     N_batch_per_epoch_train))
            logger.save_loss(train_loss / N_batch_per_epoch_train,
                             phase="train")

        # validation phase
        net.eval()
        probs_array = np.zeros(
            (N_val, options.resize_shape[0], options.resize_shape[1]))
        for i, b in enumerate(utils.batch(val, options.val_batchsize)):
            X = np.array([j[0] for j in b])[:, :3, :, :]  # drop the alpha channel
            y = np.array([j[1] for j in b])
            w = np.array([j[2] for j in b])
            X = torch.FloatTensor(X)
            y = torch.ByteTensor(y)
            w = torch.ByteTensor(w)

            if options.gpu:
                X = X.cuda()
                y = y.cuda()
                w = w.cuda()

            X = Variable(X, volatile=True)
            y = Variable(y, volatile=True)
            w = Variable(w, volatile=True)

            y_pred = net(X)
            probs = F.sigmoid(y_pred)

            probs_flat = probs.view(-1)
            y_flat = y.view(-1)
            w_flat = w.view(-1)

            # Apply extra weight to the edge pixels
            weight = (w_flat.float() / 255.) * (options.weight - 1) + 1.
            loss = weighted_binary_cross_entropy(probs_flat,
                                                 y_flat.float() / 255., weight)
            validation_loss += loss.data[0]

            # Post-processing
            # NOTE(review): the reshape to (H, W) and the per-batch indexing
            # of original_sizes / iddataset['val'] below only line up when
            # val_batchsize is 1 -- confirm against the caller.
            y_hat = np.asarray((probs > 0.5).data)
            y_hat = y_hat.reshape((y_hat.shape[2], y_hat.shape[3]))
            y_truth = np.asarray(y.data)

            # Noise removal & binarization
            #dst_img = remove_noise(probs_resized, (original_height, original_width))
            #dst_img = (dst_img * 255).astype(np.uint8)

            # calculate validatation score
            if (options.calc_score_step !=
                    0) and (epoch + 1) % options.calc_score_step == 0:
                score, scores, _ = validate(y_hat, y_truth)
                validation_score += score
                validation_scores += scores
                n_scored += 1
                print("Image No.", i, ": score ", score)

            logger.save_output_mask(y_hat, original_sizes[i],
                                    iddataset['val'][i])
            if options.save_probs is not None:
                logger.save_output_prob(np.asarray(probs.data[0][0]),
                                        original_sizes[i], iddataset['val'][i])

        print('Val Loss: {}'.format(validation_loss / N_batch_per_epoch_val))
        logger.save_loss(validation_loss / N_batch_per_epoch_val, phase="val")

        # Save the score
        if (options.calc_score_step !=
                0) and (epoch + 1) % options.calc_score_step == 0:
            # Average over the number of scored batches. Dividing by the last
            # loop index `i` was off by one and divided by zero when there
            # was exactly one validation batch.
            print('score: {}'.format(validation_score / max(n_scored, 1)))
            logger.save_score(validation_scores, validation_score,
                              N_batch_per_epoch_val, epoch)

        # Save the model and optimizer state.
        if (epoch + 1) % 10 == 0:
            torch.save(
                net.state_dict(), dir_save_model + str(options.id) +
                '_CP{}.model'.format(epoch + 1))
            torch.save(
                optimizer.state_dict(), dir_save_state + str(options.id) +
                '_CP{}.state'.format(epoch + 1))
            print('Checkpoint {} saved !'.format(epoch + 1))

        # draw loss graph
        logger.draw_loss_graph("./results/loss")
        # draw score graph
        if (options.calc_score_step !=
                0) and (epoch + 1) % options.calc_score_step == 0:
            logger.draw_score_graph("./results/score" + str(options.id) +
                                    ".png")
Exemplo n.º 29
0
Arquivo: train.py Projeto: Eikor/Unet
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(net.parameters(), 0.1)
        optimizer.step()

        global_step += 1

        # writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
        # writer.add_images('images', imgs, global_step)
        # writer.add_images('masks/true', true_masks, global_step)
        # writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

    ###### validation ######
    # torch.cuda.empty_cache()
    net.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            imgs = batch['image']
            label = batch['mask']
            imgs = imgs.cuda()
            label = label.cuda()
            pred = net(imgs)
            val_loss += np.sum(criterion(pred, label)) / len(val_loader)
        scheduler.step(val_loss)
    print('Validation: {}'.format(val_loss))
    ########################

    # writer.add_scalar('test', val_score, epoch)
torch.save(net.state_dict(), '320_rotate' + f'epoch{epoch + 1}.pth')
def main(args):
    """Train a UNet with a Dice loss, validating after every epoch.

    For each epoch the model runs one pass over the training loader and one
    over the validation loader. Training loss is logged every tenth step;
    validation loss, sample images and the mean per-volume Dice score are
    logged after each validation pass. The weights are written to
    ``<args.weights>/unet.pt`` whenever the validation Dice score improves.

    Args:
        args: settings object; the attributes read here are ``device``,
            ``lr``, ``batch_size``, ``vis_images``, ``vis_freq``, ``logs``,
            ``epochs`` and ``weights``.
    """
    makedirs(args)
    snapshotargs(args)
    # Fall back to the CPU when CUDA is unavailable.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)
    print("Learning rate = ", args.lr)
    print("Batch-size = ", args.batch_size)
    print("Number of visualization images to save in log file = ", args.vis_images)

    logger = Logger(args.logs)
    train_losses = []
    valid_losses = []

    # Counts training batches only; also used as the x-axis for every log call.
    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ("train", "valid"):
            training = phase == "train"
            if training:
                unet.train()
            else:
                unet.eval()

            # Per-slice predictions and targets gathered during validation.
            collected_pred = []
            collected_true = []

            for i, batch in enumerate(loaders[phase]):
                if training:
                    step += 1

                x, y_true = batch
                x = x.to(device)
                y_true = y_true.to(device)

                optimizer.zero_grad()

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(training):
                    y_pred = unet(x)
                    loss = dsc_loss(y_pred, y_true)

                    if training:
                        train_losses.append(loss.item())
                        loss.backward()
                        optimizer.step()
                    else:
                        valid_losses.append(loss.item())
                        pred_np = y_pred.detach().cpu().numpy()
                        collected_pred.extend(
                            pred_np[k] for k in range(pred_np.shape[0])
                        )
                        true_np = y_true.detach().cpu().numpy()
                        collected_true.extend(
                            true_np[k] for k in range(true_np.shape[0])
                        )

                        # Log images on every vis_freq-th epoch and on the last one.
                        is_vis_epoch = (epoch % args.vis_freq == 0) or (
                            epoch == args.epochs - 1
                        )
                        if is_vis_epoch:
                            offset = i * args.batch_size
                            if offset < args.vis_images:
                                # Cap the slice so at most vis_images are logged.
                                remaining = args.vis_images - offset
                                logger.image_list_summary(
                                    "image/{}".format(i),
                                    log_images(x, y_true, y_pred)[:remaining],
                                    step,
                                )

                if training and (step + 1) % 10 == 0:
                    log_loss_summary(logger, train_losses, step)
                    train_losses = []

            if training:
                continue

            # End of a validation pass: log the loss and Dice score, and keep
            # the best weights seen so far.
            log_loss_summary(logger, valid_losses, step, prefix="val_")
            per_volume = dsc_per_volume(
                collected_pred,
                collected_true,
                loader_valid.dataset.patient_slice_index,
            )
            mean_dsc = np.mean(per_volume)
            logger.scalar_summary("val_dsc", mean_dsc, step)
            if mean_dsc > best_validation_dsc:
                best_validation_dsc = mean_dsc
                torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
            valid_losses = []

    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))