Example #1
 def test_rand_3d_elastic(self, input_param, input_data, expected_val):
     g = Rand3DElastic(**input_param)
     # fix the RNG so the random deformation is reproducible
     g.set_random_state(123)
     result = g(**input_data)
     # the output container type must match the expected value's type
     self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor))
     if isinstance(result, torch.Tensor):
         np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)
     else:
         np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)
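The snippet above is a parameterized test method, so the surrounding fixtures are not shown. Below is a minimal, self-contained sketch of how such a case might be driven; it assumes MONAI's Rand3DElastic and the third-party parameterized package, and the TEST_CASES entry is illustrative rather than the project's real fixture (prob=0.0 keeps the transform deterministic, so the output equals the input).

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Rand3DElastic

# Illustrative case only: (constructor kwargs, call kwargs, expected output).
TEST_CASES = [
    [
        {"sigma_range": (5, 8), "magnitude_range": (100, 200), "prob": 0.0},
        {"img": torch.ones(1, 2, 2, 2), "spatial_size": (2, 2, 2)},
        torch.ones(1, 2, 2, 2),
    ],
]


class TestRand3DElastic(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_rand_3d_elastic(self, input_param, input_data, expected_val):
        g = Rand3DElastic(**input_param)
        g.set_random_state(123)
        result = g(**input_data)
        np.testing.assert_allclose(result.cpu().numpy(), expected_val.cpu().numpy(), rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    unittest.main()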
Example #2
 def test_rand_3d_elastic(self, input_param, input_data, expected_val):
     g = Rand3DElastic(**input_param)
     # with metadata tracking disabled the transform should return a plain tensor
     set_track_meta(False)
     g.set_random_state(123)
     result = g(**input_data)
     self.assertNotIsInstance(result, MetaTensor)
     self.assertIsInstance(result, torch.Tensor)
     # re-enable tracking and re-run with the same seed to compare values
     set_track_meta(True)
     g.set_random_state(123)
     result = g(**input_data)
     assert_allclose(result, expected_val, type_test=False, rtol=1e-1, atol=1e-1)
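This variant exercises MONAI's metadata tracking switch: with tracking off the transform hands back a plain torch.Tensor, with it back on a MetaTensor. A minimal standalone illustration of that behaviour, assuming a recent MONAI that exports monai.data.set_track_meta; the input and parameters here are arbitrary:

import torch

from monai.data import MetaTensor, set_track_meta
from monai.transforms import Rand3DElastic

img = torch.ones(1, 2, 2, 2)
g = Rand3DElastic(sigma_range=(5, 8), magnitude_range=(100, 200), prob=0.0)

set_track_meta(False)
out = g(img, spatial_size=(2, 2, 2))
print(isinstance(out, MetaTensor), isinstance(out, torch.Tensor))  # False True

set_track_meta(True)
out = g(img, spatial_size=(2, 2, 2))
print(isinstance(out, MetaTensor))  # True with metadata tracking on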
Example #3
 def test_rand_3d_elastic(self, input_param, input_data, expected_val):
     g = Rand3DElastic(**input_param)
     g.set_random_state(123)
     result = g(**input_data)
     assert_allclose(result, expected_val, rtol=1e-1, atol=1e-1)
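Examples #2 and #3 compare values through an assert_allclose test helper imported from the project's test utilities rather than calling NumPy directly. Its real implementation is not reproduced here; the sketch below only approximates what such a helper does, with a signature inferred from the calls above:

import numpy as np
import torch


def assert_allclose(actual, desired, type_test=True, rtol=1e-7, atol=0.0):
    # Rough stand-in for the test helper used above, not the real implementation.
    if type_test:
        # require comparable container types (e.g. both torch.Tensor) before comparing values
        assert isinstance(actual, type(desired)) or isinstance(desired, type(actual)), (
            f"type mismatch: {type(actual)} vs {type(desired)}")

    def to_numpy(x):
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else np.asarray(x)

    np.testing.assert_allclose(to_numpy(actual), to_numpy(desired), rtol=rtol, atol=atol)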
Example #4
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    if args.rank == 0:
        configure(os.path.join('./exp', args.exp_name))

    # create model
    model_patch = moco.builder_v3.MoCo(Encoder, args.num_patch, args.moco_dim,
                                       args.moco_k_patch, args.moco_m,
                                       args.moco_t, args.mlp)

    model_graph = moco.builder_graph.MoCo(GraphNet, args.gpu, args.moco_dim,
                                          args.moco_k_graph, args.moco_m,
                                          args.moco_t, args.mlp)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model_patch.cuda(args.gpu)
            model_graph.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size_patch = int(args.batch_size_patch / ngpus_per_node)
            args.batch_size_graph = int(args.batch_size_graph / ngpus_per_node)
            args.workers_patch = int(
                (args.workers_patch + ngpus_per_node - 1) / ngpus_per_node)
            args.workers_graph = int(
                (args.workers_graph + ngpus_per_node - 1) / ngpus_per_node)
            model_patch = torch.nn.parallel.DistributedDataParallel(
                model_patch, device_ids=[args.gpu])
            model_graph = torch.nn.parallel.DistributedDataParallel(
                model_graph,
                device_ids=[args.gpu],
                find_unused_parameters=True)
        else:
            model_patch.cuda()
            model_graph.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model_patch = torch.nn.parallel.DistributedDataParallel(model_patch)
            model_graph = torch.nn.parallel.DistributedDataParallel(
                model_graph, find_unused_parameters=True)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model_patch = model_patch.cuda(args.gpu)
        model_graph = model_graph.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer_patch = torch.optim.SGD(model_patch.parameters(),
                                      args.lr,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)
    optimizer_graph = torch.optim.SGD(model_graph.parameters(),
                                      args.lr,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)
    # save the initial model
    if not args.resume:
        save_checkpoint(
            {
                'epoch': 0,
                'arch': args.arch,
                'state_dict': model_patch.state_dict(),
                'optimizer': optimizer_patch.state_dict(),
            },
            is_best=False,
            filename=os.path.join(os.path.join('./exp', args.exp_name),
                                  'checkpoint_patch_init.pth.tar'))
        save_checkpoint(
            {
                'epoch': 0,
                'arch': args.arch,
                'state_dict': model_graph.state_dict(),
                'optimizer': optimizer_graph.state_dict(),
            },
            is_best=False,
            filename=os.path.join(os.path.join('./exp', args.exp_name),
                                  'checkpoint_graph_init.pth.tar'))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint_patch = torch.load(args.resume)
                checkpoint_graph = torch.load(args.resume_graph)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint_patch = torch.load(args.resume, map_location=loc)
                checkpoint_graph = torch.load(args.resume_graph,
                                              map_location=loc)
            args.start_epoch = checkpoint_patch['epoch']
            model_patch.load_state_dict(checkpoint_patch['state_dict'])
            model_graph.load_state_dict(checkpoint_graph['state_dict'])
            optimizer_patch.load_state_dict(checkpoint_patch['optimizer'])
            optimizer_graph.load_state_dict(checkpoint_graph['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint_patch['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit()

    transform_re = Rand3DElastic(
        mode='bilinear',
        prob=1.0,
        sigma_range=(8, 12),
        magnitude_range=(0, 1024 + 240),  #[-1024, 240] -> [0, 1024+240]
        spatial_size=(32, 32, 32),
        translate_range=(12, 12, 12),
        rotate_range=(np.pi / 18, np.pi / 18, np.pi / 18),
        scale_range=(0.1, 0.1, 0.1),
        padding_mode='border',
        device=torch.device('cuda:' + str(args.gpu)))
    transform_rgn = RandGaussianNoise(prob=0.25, mean=0.0, std=50)
    transform_rac = RandAdjustContrast(prob=0.25)

    train_transforms = Compose([transform_rac, transform_rgn, transform_re])

    # TwoCropsTransform applies the augmentation pipeline twice per sample to
    # produce the query/key views used by MoCo
    train_dataset_patch = COPD_dataset_patch(
        "train", args, moco.loader.TwoCropsTransform(train_transforms))
    train_dataset_graph = COPD_dataset_graph(
        "train", args, moco.loader.TwoCropsTransform(train_transforms))

    if args.distributed:
        train_sampler_patch = torch.utils.data.distributed.DistributedSampler(
            train_dataset_patch)
        train_sampler_graph = torch.utils.data.distributed.DistributedSampler(
            train_dataset_graph)
    else:
        train_sampler_patch = None
        train_sampler_graph = None

    train_loader_patch = torch.utils.data.DataLoader(
        train_dataset_patch,
        batch_size=args.batch_size_patch,
        shuffle=(train_sampler_patch is None),
        num_workers=args.workers_patch,
        pin_memory=True,
        sampler=train_sampler_patch,
        drop_last=True)
    train_loader_graph = torch.utils.data.DataLoader(
        train_dataset_graph,
        batch_size=args.batch_size_graph,
        shuffle=(train_sampler_graph is None),
        num_workers=args.workers_graph,
        pin_memory=True,
        sampler=train_sampler_graph,
        drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler_patch.set_epoch(epoch)
        adjust_learning_rate(optimizer_patch, epoch, args)
        # train for one epoch
        train_patch(train_loader_patch, model_patch, criterion,
                    optimizer_patch, epoch, args)
        # save model for every epoch
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model_patch.state_dict(),
                    'optimizer': optimizer_patch.state_dict(),
                },
                is_best=False,
                filename=os.path.join(
                    os.path.join('./exp', args.exp_name),
                    'checkpoint_patch_{:04d}.pth.tar'.format(epoch)))

        for sub_epoch in range(args.num_sub_epoch):
            if args.distributed:
                train_sampler_graph.set_epoch(args.num_sub_epoch * epoch +
                                              sub_epoch)
            adjust_learning_rate(optimizer_graph,
                                 args.num_sub_epoch * epoch + sub_epoch, args)
            train_graph(train_loader_graph, model_graph, model_patch,
                        criterion, optimizer_graph,
                        args.num_sub_epoch * epoch + sub_epoch, args)
            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                save_checkpoint(
                    {
                        'epoch': args.num_sub_epoch * epoch + sub_epoch + 1,
                        'arch': args.arch,
                        'state_dict': model_graph.state_dict(),
                        'optimizer': optimizer_graph.state_dict(),
                    },
                    is_best=False,
                    filename=os.path.join(
                        os.path.join('./exp', args.exp_name),
                        'checkpoint_graph_{:04d}.pth.tar'.format(
                            args.num_sub_epoch * epoch + sub_epoch)))
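For context, main_worker is the per-process entry point of a MoCo-style training script; it is normally launched once per GPU from a main() that uses torch.multiprocessing.spawn. The sketch below shows such a launcher using the same args fields as above; the project's real entry point may differ.

import torch
import torch.multiprocessing as mp


def main(args):
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size initially counts nodes; scale it to the total number of processes
        args.world_size = ngpus_per_node * args.world_size
        # spawn calls main_worker(gpu, ngpus_per_node, args) with gpu = 0..ngpus_per_node-1
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # single-process path; main_worker rejects it unless DistributedDataParallel is configured
        main_worker(args.gpu, ngpus_per_node, args)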