def test_rand_3d_elastic(self, input_param, input_data, expected_val):
    g = Rand3DElastic(**input_param)
    g.set_random_state(123)
    result = g(**input_data)
    self.assertEqual(isinstance(result, torch.Tensor),
                     isinstance(expected_val, torch.Tensor))
    if isinstance(result, torch.Tensor):
        np.testing.assert_allclose(result.cpu().numpy(),
                                   expected_val.cpu().numpy(),
                                   rtol=1e-4, atol=1e-4)
    else:
        np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)

def test_rand_3d_elastic(self, input_param, input_data, expected_val):
    g = Rand3DElastic(**input_param)
    set_track_meta(False)
    g.set_random_state(123)
    result = g(**input_data)
    self.assertNotIsInstance(result, MetaTensor)
    self.assertIsInstance(result, torch.Tensor)
    set_track_meta(True)
    g.set_random_state(123)
    result = g(**input_data)
    assert_allclose(result, expected_val, type_test=False, rtol=1e-1, atol=1e-1)

def test_rand_3d_elastic(self, input_param, input_data, expected_val):
    g = Rand3DElastic(**input_param)
    g.set_random_state(123)
    result = g(**input_data)
    assert_allclose(result, expected_val, rtol=1e-1, atol=1e-1)
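
# The three test method variants above expect (input_param, input_data,
# expected_val) tuples supplied by a parameterized test runner. The scaffolding
# below is a minimal sketch of how such a test is typically wired up in a
# MONAI-style test module; the class name, the single TEST_CASES entry, and the
# exact import paths are illustrative assumptions, not taken from this excerpt.
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Rand3DElastic

# one illustrative case: prob=0.0 leaves the all-ones volume unchanged except
# for the resampling to spatial_size
TEST_CASES = [
    (
        {"magnitude_range": (0.3, 2.3), "sigma_range": (1.0, 20.0), "prob": 0.0, "device": None},
        {"img": torch.ones((2, 3, 3, 3)), "spatial_size": (2, 2, 2)},
        torch.ones((2, 2, 2, 2)),
    ),
]


class TestRand3DElasticSketch(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_rand_3d_elastic(self, input_param, input_data, expected_val):
        g = Rand3DElastic(**input_param)
        g.set_random_state(123)
        result = g(**input_data)
        np.testing.assert_allclose(
            np.asarray(result), np.asarray(expected_val), rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    unittest.main()
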
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:

        def print_pass(*args):
            pass

        builtins.print = print_pass

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)
    if args.rank == 0:
        configure(os.path.join('./exp', args.exp_name))

    # create model
    model_patch = moco.builder_v3.MoCo(Encoder, args.num_patch, args.moco_dim,
                                       args.moco_k_patch, args.moco_m,
                                       args.moco_t, args.mlp)
    model_graph = moco.builder_graph.MoCo(GraphNet, args.gpu, args.moco_dim,
                                          args.moco_k_graph, args.moco_m,
                                          args.moco_t, args.mlp)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model_patch.cuda(args.gpu)
            model_graph.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size_patch = int(args.batch_size_patch / ngpus_per_node)
            args.batch_size_graph = int(args.batch_size_graph / ngpus_per_node)
            args.workers_patch = int(
                (args.workers_patch + ngpus_per_node - 1) / ngpus_per_node)
            args.workers_graph = int(
                (args.workers_graph + ngpus_per_node - 1) / ngpus_per_node)
            model_patch = torch.nn.parallel.DistributedDataParallel(
                model_patch, device_ids=[args.gpu])
            model_graph = torch.nn.parallel.DistributedDataParallel(
                model_graph, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model_patch.cuda()
            model_graph.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model_patch = torch.nn.parallel.DistributedDataParallel(model_patch)
            model_graph = torch.nn.parallel.DistributedDataParallel(model_graph)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model_patch = model_patch.cuda(args.gpu)
        model_graph = model_graph.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer_patch = torch.optim.SGD(model_patch.parameters(),
                                      args.lr,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)
    optimizer_graph = torch.optim.SGD(model_graph.parameters(),
                                      args.lr,
                                      momentum=args.momentum,
                                      weight_decay=args.weight_decay)

    # save the initial model
    if not args.resume:
        save_checkpoint(
            {
                'epoch': 0,
                'arch': args.arch,
                'state_dict': model_patch.state_dict(),
                'optimizer': optimizer_patch.state_dict(),
            },
            is_best=False,
            filename=os.path.join(os.path.join('./exp', args.exp_name),
                                  'checkpoint_patch_init.pth.tar'))
        save_checkpoint(
            {
                'epoch': 0,
                'arch': args.arch,
                'state_dict': model_graph.state_dict(),
                'optimizer': optimizer_graph.state_dict(),
            },
            is_best=False,
            filename=os.path.join(os.path.join('./exp', args.exp_name),
                                  'checkpoint_graph_init.pth.tar'))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint_patch = torch.load(args.resume)
                checkpoint_graph = torch.load(args.resume_graph)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint_patch = torch.load(args.resume, map_location=loc)
                checkpoint_graph = torch.load(args.resume_graph, map_location=loc)
            args.start_epoch = checkpoint_patch['epoch']
            model_patch.load_state_dict(checkpoint_patch['state_dict'])
            model_graph.load_state_dict(checkpoint_graph['state_dict'])
            optimizer_patch.load_state_dict(checkpoint_patch['optimizer'])
            optimizer_graph.load_state_dict(checkpoint_graph['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint_patch['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit()

    # data augmentation: random contrast / Gaussian noise / elastic deformation
    transform_re = Rand3DElastic(
        mode='bilinear',
        prob=1.0,
        sigma_range=(8, 12),
        magnitude_range=(0, 1024 + 240),  # [-1024, 240] -> [0, 1024 + 240]
        spatial_size=(32, 32, 32),
        translate_range=(12, 12, 12),
        rotate_range=(np.pi / 18, np.pi / 18, np.pi / 18),
        scale_range=(0.1, 0.1, 0.1),
        padding_mode='border',
        device=torch.device('cuda:' + str(args.gpu)))
    transform_rgn = RandGaussianNoise(prob=0.25, mean=0.0, std=50)
    transform_rac = RandAdjustContrast(prob=0.25)
    train_transforms = Compose([transform_rac, transform_rgn, transform_re])

    train_dataset_patch = COPD_dataset_patch(
        "train", args, moco.loader.TwoCropsTransform(train_transforms))
    train_dataset_graph = COPD_dataset_graph(
        "train", args, moco.loader.TwoCropsTransform(train_transforms))

    if args.distributed:
        train_sampler_patch = torch.utils.data.distributed.DistributedSampler(
            train_dataset_patch)
        train_sampler_graph = torch.utils.data.distributed.DistributedSampler(
            train_dataset_graph)
    else:
        train_sampler_patch = None
        train_sampler_graph = None

    train_loader_patch = torch.utils.data.DataLoader(
        train_dataset_patch,
        batch_size=args.batch_size_patch,
        shuffle=(train_sampler_patch is None),
        num_workers=args.workers_patch,
        pin_memory=True,
        sampler=train_sampler_patch,
        drop_last=True)
    train_loader_graph = torch.utils.data.DataLoader(
        train_dataset_graph,
        batch_size=args.batch_size_graph,
        shuffle=(train_sampler_graph is None),
        num_workers=args.workers_graph,
        pin_memory=True,
        sampler=train_sampler_graph,
        drop_last=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler_patch.set_epoch(epoch)
        adjust_learning_rate(optimizer_patch, epoch, args)

        # train for one epoch
        train_patch(train_loader_patch, model_patch, criterion, optimizer_patch,
                    epoch, args)

        # save model for every epoch
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model_patch.state_dict(),
                    'optimizer': optimizer_patch.state_dict(),
                },
                is_best=False,
                filename=os.path.join(
                    os.path.join('./exp', args.exp_name),
                    'checkpoint_patch_{:04d}.pth.tar'.format(epoch)))

        for sub_epoch in range(args.num_sub_epoch):
            if args.distributed:
                train_sampler_graph.set_epoch(args.num_sub_epoch * epoch + sub_epoch)
            adjust_learning_rate(optimizer_graph,
                                 args.num_sub_epoch * epoch + sub_epoch, args)
            train_graph(train_loader_graph, model_graph, model_patch, criterion,
                        optimizer_graph, args.num_sub_epoch * epoch + sub_epoch,
                        args)

            if not args.multiprocessing_distributed or (
                    args.multiprocessing_distributed
                    and args.rank % ngpus_per_node == 0):
                save_checkpoint(
                    {
                        'epoch': args.num_sub_epoch * epoch + sub_epoch + 1,
                        'arch': args.arch,
                        'state_dict': model_graph.state_dict(),
                        'optimizer': optimizer_graph.state_dict(),
                    },
                    is_best=False,
                    filename=os.path.join(
                        os.path.join('./exp', args.exp_name),
                        'checkpoint_graph_{:04d}.pth.tar'.format(
                            args.num_sub_epoch * epoch + sub_epoch)))
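
# main_worker() relies on save_checkpoint(), adjust_learning_rate(),
# train_patch() and train_graph() defined elsewhere in this script. The two
# small helpers below are a minimal sketch following the MoCo reference
# implementation; the actual definitions in this repository may differ (e.g.
# whether args.cos / args.schedule exist, or where checkpoints are copied).
import math
import shutil

import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # persist the full training state; keep a copy of the best model if asked
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with a cosine or stepwise schedule."""
    lr = args.lr
    if getattr(args, 'cos', False):
        # cosine annealing over the full training run
        lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs))
    else:
        # stepwise decay at the listed milestone epochs
        for milestone in getattr(args, 'schedule', []):
            lr *= 0.1 if epoch >= milestone else 1.0
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr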