def train():
    """Train a binary or multi-class segmentation model.

    Builds the network selected by ``hp.mode`` ('2d' or '3d'), optionally
    resumes from ``args.ckpt``, then runs the epoch loop: forward pass,
    ``Binary_Loss`` backward, per-iteration TensorBoard metrics, one
    ``StepLR`` step per epoch, and a checkpoint plus sample-volume export
    every ``args.epochs_per_checkpoint`` epochs.

    Relies on module-level names defined elsewhere in this file:
    ``hp``, ``devicess``, ``metric``, ``source_train_dir``,
    ``label_train_dir``, ``parse_training_args``.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Medical Segmentation Training')
    parser = parse_training_args(parser)
    # A single parse_args() replaces the original redundant
    # parse_known_args() + parse_args() pair (the first result was
    # immediately overwritten).
    args = parser.parse_args()

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    from data_function import MedData_train
    os.makedirs(args.output_dir, exist_ok=True)

    if hp.mode == '2d':
        from models.two_d.unet import Unet
        model = Unet(in_channels=hp.in_class, classes=hp.out_class)

        # Alternative 2-D architectures (swap in as needed):
        # from models.two_d.miniseg import MiniSeg
        # model = MiniSeg(in_input=hp.in_class, classes=hp.out_class)
        # from models.two_d.fcn import FCN32s as fcn
        # model = fcn(in_class =hp.in_class,n_class=hp.out_class)
        # from models.two_d.segnet import SegNet
        # model = SegNet(input_nbr=hp.in_class,label_nbr=hp.out_class)
        # from models.two_d.deeplab import DeepLabV3
        # model = DeepLabV3(in_class=hp.in_class,class_num=hp.out_class)
        # from models.two_d.unetpp import ResNet34UnetPlus
        # model = ResNet34UnetPlus(num_channels=hp.in_class,num_class=hp.out_class)
        # from models.two_d.pspnet import PSPNet
        # model = PSPNet(in_class=hp.in_class,n_classes=hp.out_class)

    elif hp.mode == '3d':
        from models.three_d.unet3d import UNet3D
        model = UNet3D(in_channels=hp.in_class,
                       out_channels=hp.out_class,
                       init_features=32)

        # Alternative 3-D architectures (swap in as needed):
        # from models.three_d.residual_unet3d import UNet
        # model = UNet(in_channels=hp.in_class, n_classes=hp.out_class, base_n_filter=2)
        # from models.three_d.fcn3d import FCN_Net
        # model = FCN_Net(in_channels =hp.in_class,n_class =hp.out_class)
        # from models.three_d.highresnet import HighRes3DNet
        # model = HighRes3DNet(in_channels=hp.in_class,out_channels=hp.out_class)
        # from models.three_d.densenet3d import SkipDenseNet3D
        # model = SkipDenseNet3D(in_channels=hp.in_class, classes=hp.out_class)
        # from models.three_d.densevoxelnet3d import DenseVoxelNet
        # model = DenseVoxelNet(in_channels=hp.in_class, classes=hp.out_class)
        # from models.three_d.vnet3d import VNet
        # model = VNet(in_channels=hp.in_class, classes=hp.out_class)
    else:
        # Robustness: an unrecognized mode previously fell through and
        # crashed later with NameError on `model`.
        raise ValueError(
            f"Unsupported hp.mode: {hp.mode!r} (expected '2d' or '3d')")

    model = torch.nn.DataParallel(model, device_ids=devicess)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr)

    # scheduler = ReduceLROnPlateau(optimizer, 'min',factor=0.5, patience=20, verbose=True)
    scheduler = StepLR(optimizer,
                       step_size=hp.scheduer_step_size,
                       gamma=hp.scheduer_gamma)
    # scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=5e-6)

    if args.ckpt is not None:
        print("load model:", args.ckpt)
        print(os.path.join(args.output_dir, args.latest_checkpoint_file))
        ckpt = torch.load(os.path.join(args.output_dir,
                                       args.latest_checkpoint_file),
                          map_location=lambda storage, loc: storage)

        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optim"])

        # Optimizer state was loaded onto the CPU (map_location above);
        # move any tensors back to the GPU before training resumes.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        # NOTE(review): scheduler state is saved but never restored here, so
        # the LR schedule restarts from epoch 0 on resume — confirm intended.
        # scheduler.load_state_dict(ckpt["scheduler"])
        elapsed_epochs = ckpt["epoch"]
    else:
        elapsed_epochs = 0

    model.cuda()

    from loss_function import Binary_Loss, DiceLoss
    criterion = Binary_Loss().cuda()

    writer = SummaryWriter(args.output_dir)

    train_dataset = MedData_train(source_train_dir, label_train_dir)
    train_loader = DataLoader(train_dataset.queue_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    model.train()

    epochs = args.epochs - elapsed_epochs
    iteration = elapsed_epochs * len(train_loader)

    for epoch in range(1, epochs + 1):
        # Shift to the absolute epoch number *before* logging, so a resumed
        # run prints/saves the true epoch (the original printed the relative
        # epoch first, then shifted).
        epoch += elapsed_epochs
        print("epoch:" + str(epoch))

        num_iters = 0

        for i, batch in enumerate(train_loader):

            if hp.debug:
                if i >= 1:
                    break

            print(f"Batch: {i}/{len(train_loader)} epoch {epoch}")

            optimizer.zero_grad()

            if (hp.in_class == 1) and (hp.out_class == 1):
                x = batch['source']['data']
                y = batch['label']['data']

                x = x.type(torch.FloatTensor).cuda()
                y = y.type(torch.FloatTensor).cuda()

            else:
                x = batch['source']['data']
                y_atery = batch['atery']['data']
                y_lung = batch['lung']['data']
                y_trachea = batch['trachea']['data']
                # BUGFIX: the vein target previously re-read the artery
                # channel (batch['atery']), so the 4th output channel was
                # trained against artery labels instead of vein labels.
                y_vein = batch['vein']['data']

                x = x.type(torch.FloatTensor).cuda()

                # Channel order must match the saved outputs below:
                # artery, lung, trachea, vein.
                y = torch.cat((y_atery, y_lung, y_trachea, y_vein), 1)
                y = y.type(torch.FloatTensor).cuda()

            if hp.mode == '2d':
                # Patches come in 5-D (B, C, W, H, D); drop the depth axis
                # for 2-D networks.
                x = x.squeeze(4)
                y = y.squeeze(4)

                # Binarize the ground truth: any non-zero label becomes 1.
                y[y != 0] = 1

            outputs = model(x)

            # Hard 0/1 predictions are used for the metrics only; the loss
            # below consumes the raw logits.
            logits = torch.sigmoid(outputs)
            labels = logits.clone()
            labels[labels > 0.5] = 1
            labels[labels <= 0.5] = 0

            loss = criterion(outputs, y)

            num_iters += 1
            loss.backward()

            optimizer.step()
            iteration += 1

            false_positive_rate, false_negtive_rate, dice = metric(
                y.cpu(), labels.cpu())
            ## log
            writer.add_scalar('Training/Loss', loss.item(), iteration)
            writer.add_scalar('Training/false_positive_rate',
                              false_positive_rate, iteration)
            writer.add_scalar('Training/false_negtive_rate',
                              false_negtive_rate, iteration)
            writer.add_scalar('Training/dice', dice, iteration)

            print("loss:" + str(loss.item()))
            # Use the public accessor instead of the private _last_lr attr.
            print('lr:' + str(scheduler.get_last_lr()[0]))

        scheduler.step()

        # Store latest checkpoint in each epoch
        torch.save(
            {
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
            },
            os.path.join(args.output_dir, args.latest_checkpoint_file),
        )

        # Save a numbered checkpoint periodically
        if epoch % args.epochs_per_checkpoint == 0:

            torch.save(
                {
                    "model": model.state_dict(),
                    "optim": optimizer.state_dict(),
                    # Consistency: keep scheduler state in numbered
                    # checkpoints too, matching the "latest" checkpoint.
                    "scheduler": scheduler.state_dict(),
                    "epoch": epoch,
                },
                os.path.join(args.output_dir, f"checkpoint_{epoch:04d}.pt"),
            )

            # Export the first sample of the last batch (source / ground
            # truth / prediction) as volumes for visual inspection.
            with torch.no_grad():
                if hp.mode == '2d':
                    # Restore the depth axis dropped above so torchio can
                    # save a 4-D (C, W, H, D) image.
                    x = x.unsqueeze(4)
                    y = y.unsqueeze(4)
                    outputs = outputs.unsqueeze(4)

                x = x[0].cpu().detach().numpy()
                y = y[0].cpu().detach().numpy()
                outputs = outputs[0].cpu().detach().numpy()
                affine = batch['source']['affine'][0].numpy()

                if (hp.in_class == 1) and (hp.out_class == 1):
                    source_image = torchio.ScalarImage(tensor=x, affine=affine)
                    source_image.save(
                        os.path.join(args.output_dir,
                                     f"step-{epoch:04d}-source" +
                                     hp.save_arch))

                    label_image = torchio.ScalarImage(tensor=y, affine=affine)
                    label_image.save(
                        os.path.join(args.output_dir,
                                     f"step-{epoch:04d}-gt" + hp.save_arch))

                    output_image = torchio.ScalarImage(tensor=outputs,
                                                       affine=affine)
                    output_image.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict" + hp.save_arch))
                else:
                    # Split per-structure channels into single-channel
                    # volumes: index 0=artery, 1=lung, 2=trachea, 3=vein.
                    y = np.expand_dims(y, axis=1)
                    outputs = np.expand_dims(outputs, axis=1)

                    source_image = torchio.ScalarImage(tensor=x, affine=affine)
                    source_image.save(
                        os.path.join(args.output_dir,
                                     f"step-{epoch:04d}-source" +
                                     hp.save_arch))

                    label_image_artery = torchio.ScalarImage(tensor=y[0],
                                                             affine=affine)
                    label_image_artery.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_artery" + hp.save_arch))

                    output_image_artery = torchio.ScalarImage(
                        tensor=outputs[0], affine=affine)
                    output_image_artery.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_artery" + hp.save_arch))

                    label_image_lung = torchio.ScalarImage(tensor=y[1],
                                                           affine=affine)
                    label_image_lung.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_lung" + hp.save_arch))

                    output_image_lung = torchio.ScalarImage(tensor=outputs[1],
                                                            affine=affine)
                    output_image_lung.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_lung" + hp.save_arch))

                    label_image_trachea = torchio.ScalarImage(tensor=y[2],
                                                              affine=affine)
                    label_image_trachea.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_trachea" + hp.save_arch))

                    output_image_trachea = torchio.ScalarImage(
                        tensor=outputs[2], affine=affine)
                    output_image_trachea.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_trachea" +
                            hp.save_arch))

                    label_image_vein = torchio.ScalarImage(tensor=y[3],
                                                           affine=affine)
                    label_image_vein.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-gt_vein" + hp.save_arch))

                    output_image_vein = torchio.ScalarImage(tensor=outputs[3],
                                                            affine=affine)
                    output_image_vein.save(
                        os.path.join(
                            args.output_dir,
                            f"step-{epoch:04d}-predict_vein" + hp.save_arch))

    writer.close()
def test():
    """Run sliding-window (patch-based) inference on the test set.

    Loads the latest training checkpoint, aggregates patch predictions per
    subject with torchio's GridSampler/GridAggregator, and saves both the
    sigmoid probability volume ("float") and the 0.5-thresholded label
    volume ("int") for each subject into ``output_dir_test``.

    Relies on module-level names defined elsewhere in this file:
    ``hp``, ``devicess``, ``device``, ``source_test_dir``,
    ``label_test_dir``, ``output_dir_test``, ``parse_training_args``.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Medical Segmentation Testing')
    parser = parse_training_args(parser)
    # A single parse_args() replaces the original redundant
    # parse_known_args() + parse_args() pair.
    args = parser.parse_args()

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    from data_function import MedData_test

    os.makedirs(output_dir_test, exist_ok=True)

    if hp.mode == '2d':
        from models.two_d.unet import Unet
        model = Unet(in_channels=hp.in_class, classes=hp.out_class)

        # Alternative 2-D architectures (swap in as needed):
        # from models.two_d.miniseg import MiniSeg
        # model = MiniSeg(in_input=hp.in_class, classes=hp.out_class)
        # from models.two_d.fcn import FCN32s as fcn
        # model = fcn(in_class =hp.in_class,n_class=hp.out_class)
        # from models.two_d.segnet import SegNet
        # model = SegNet(input_nbr=hp.in_class,label_nbr=hp.out_class)
        # from models.two_d.deeplab import DeepLabV3
        # model = DeepLabV3(in_class=hp.in_class,class_num=hp.out_class)
        # from models.two_d.unetpp import ResNet34UnetPlus
        # model = ResNet34UnetPlus(num_channels=hp.in_class,num_class=hp.out_class)
        # from models.two_d.pspnet import PSPNet
        # model = PSPNet(in_class=hp.in_class,n_classes=hp.out_class)

    elif hp.mode == '3d':
        # BUGFIX/consistency: train() builds UNet3D(..., init_features=32),
        # but the original test() built a different network
        # (UNet with n_classes/base_n_filter) whose state_dict does not
        # match the training checkpoint, so load_state_dict() would fail.
        from models.three_d.unet3d import UNet3D
        model = UNet3D(in_channels=hp.in_class,
                       out_channels=hp.out_class,
                       init_features=32)

        # Alternative 3-D architectures (swap in as needed):
        # from models.three_d.fcn3d import FCN_Net
        # model = FCN_Net(in_channels =hp.in_class,n_class =hp.out_class)
        # from models.three_d.highresnet import HighRes3DNet
        # model = HighRes3DNet(in_channels=hp.in_class,out_channels=hp.out_class)
        # from models.three_d.densenet3d import SkipDenseNet3D
        # model = SkipDenseNet3D(in_channels=hp.in_class, classes=hp.out_class)
        # from models.three_d.densevoxelnet3d import DenseVoxelNet
        # model = DenseVoxelNet(in_channels=hp.in_class, classes=hp.out_class)
        # from models.three_d.vnet3d import VNet
        # model = VNet(in_channels=hp.in_class, classes=hp.out_class)
    else:
        # Robustness: an unrecognized mode previously fell through and
        # crashed later with NameError on `model`.
        raise ValueError(
            f"Unsupported hp.mode: {hp.mode!r} (expected '2d' or '3d')")

    # BUGFIX: output_device=[1] was dropped — DataParallel expects a single
    # device (int/torch.device), not a list; train() does not set it either.
    model = torch.nn.DataParallel(model, device_ids=devicess)

    print("load model:", args.ckpt)
    print(os.path.join(args.output_dir, args.latest_checkpoint_file))
    ckpt = torch.load(os.path.join(args.output_dir,
                                   args.latest_checkpoint_file),
                      map_location=lambda storage, loc: storage)

    model.load_state_dict(ckpt["model"])

    model.cuda()
    model.eval()  # hoisted: no need to re-enter eval mode per subject

    test_dataset = MedData_test(source_test_dir, label_test_dir)
    znorm = ZNormalization()

    # Patch geometry is identical for both modes; the original duplicated
    # this assignment in an if/elif.
    patch_overlap = hp.patch_overlap
    patch_size = hp.patch_size

    for i, subj in enumerate(test_dataset.subjects):
        subj = znorm(subj)
        grid_sampler = torchio.inference.GridSampler(
            subj,
            patch_size,
            patch_overlap,
        )

        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=16)
        aggregator = torchio.inference.GridAggregator(grid_sampler)
        aggregator_1 = torchio.inference.GridAggregator(grid_sampler)
        with torch.no_grad():
            for patches_batch in tqdm(patch_loader):

                # NOTE(review): `device` is a module-level name defined
                # outside this function — presumably torch.device('cuda').
                input_tensor = patches_batch['source'][torchio.DATA].to(device)
                locations = patches_batch[torchio.LOCATION]

                if hp.mode == '2d':
                    input_tensor = input_tensor.squeeze(4)
                outputs = model(input_tensor)

                if hp.mode == '2d':
                    outputs = outputs.unsqueeze(4)
                logits = torch.sigmoid(outputs)

                # Hard 0/1 labels from the probabilities (0.5 threshold).
                labels = logits.clone()
                labels[labels > 0.5] = 1
                labels[labels <= 0.5] = 0

                aggregator.add_batch(logits, locations)
                aggregator_1.add_batch(labels, locations)
        output_tensor = aggregator.get_output_tensor()
        output_tensor_1 = aggregator_1.get_output_tensor()

        affine = subj['source']['affine']
        if (hp.in_class == 1) and (hp.out_class == 1):
            # BUGFIX (all filenames below): f"{str(i):04d}" raises
            # ValueError — format code 'd' cannot apply to a str; format the
            # int directly as f"{i:04d}".
            label_image = torchio.ScalarImage(tensor=output_tensor.numpy(),
                                              affine=affine)
            label_image.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float" + hp.save_arch))

            output_image = torchio.ScalarImage(tensor=output_tensor_1.numpy(),
                                               affine=affine)
            output_image.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int" + hp.save_arch))
        else:
            # Split per-structure channels (0=artery, 1=lung, 2=trachea,
            # 3=vein) into single-channel volumes.
            output_tensor = output_tensor.unsqueeze(1)
            output_tensor_1 = output_tensor_1.unsqueeze(1)

            output_image_artery_float = torchio.ScalarImage(
                tensor=output_tensor[0].numpy(), affine=affine)
            output_image_artery_float.save(
                os.path.join(
                    output_dir_test,
                    f"{i:04d}-result_float_artery" + hp.save_arch))

            output_image_artery_int = torchio.ScalarImage(
                tensor=output_tensor_1[0].numpy(), affine=affine)
            output_image_artery_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_artery" + hp.save_arch))

            output_image_lung_float = torchio.ScalarImage(
                tensor=output_tensor[1].numpy(), affine=affine)
            output_image_lung_float.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float_lung" + hp.save_arch))

            output_image_lung_int = torchio.ScalarImage(
                tensor=output_tensor_1[1].numpy(), affine=affine)
            output_image_lung_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_lung" + hp.save_arch))

            output_image_trachea_float = torchio.ScalarImage(
                tensor=output_tensor[2].numpy(), affine=affine)
            output_image_trachea_float.save(
                os.path.join(
                    output_dir_test,
                    f"{i:04d}-result_float_trachea" + hp.save_arch))

            output_image_trachea_int = torchio.ScalarImage(
                tensor=output_tensor_1[2].numpy(), affine=affine)
            output_image_trachea_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_trachea" +
                             hp.save_arch))

            output_image_vein_float = torchio.ScalarImage(
                tensor=output_tensor[3].numpy(), affine=affine)
            output_image_vein_float.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_float_vein" + hp.save_arch))

            output_image_vein_int = torchio.ScalarImage(
                tensor=output_tensor_1[3].numpy(), affine=affine)
            output_image_vein_int.save(
                os.path.join(output_dir_test,
                             f"{i:04d}-result_int_vein" + hp.save_arch))