def __init__(self, params, use_cuda=False):
    super(MModel, self).__init__()
    self.params = params
    self.src_mask_delta_UNet = UNet(params, 3, [64] * 2 + [128] * 9, [128] * 4 + [32])
    self.src_mask_delta_Conv = nn.Conv2d(32, 11, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
    self.fg_UNet = UNet(params, 30, [64] * 2 + [128] * 9, [128] * 4 + [64])
    self.fg_tgt_Conv = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
    self.fg_mask_Conv = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
    self.bg_UNet = UNet(params, 4, [64] * 2 + [128] * 9, [128] * 4 + [64])
    self.bg_tgt_Conv = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
    self.use_cuda = use_cuda
def build_model(self):
    """Build generator and discriminator."""
    if self.model_type == 'UNet':
        self.unet = UNet(n_channels=1, n_classes=1)
    elif self.model_type == 'R2U_Net':
        self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)  # TODO: changed for green image channel
    elif self.model_type == 'AttU_Net':
        self.unet = AttU_Net(img_ch=1, output_ch=1)
    elif self.model_type == 'R2AttU_Net':
        self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
    elif self.model_type == 'Iternet':
        self.unet = Iternet(n_channels=1, n_classes=1)
    elif self.model_type == 'AttUIternet':
        self.unet = AttUIternet(n_channels=1, n_classes=1)
    elif self.model_type == 'R2UIternet':
        self.unet = R2UIternet(n_channels=3, n_classes=1)
    elif self.model_type == 'NestedUNet':
        self.unet = NestedUNet(in_ch=1, out_ch=1)
    elif self.model_type == "AG_Net":
        self.unet = AG_Net(n_classes=1, bn=True, BatchNorm=False)

    self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr, betas=tuple(self.beta_list))
    self.unet.to(self.device)
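# Hedged sketch (not from the original repo): the same dispatch expressed as a lookup
# table instead of an if/elif chain. The constructor names and channel arguments are
# copied from build_model above; MODEL_REGISTRY and build_model_from_registry are
# hypothetical names introduced here.
MODEL_REGISTRY = {
    'UNet': lambda self: UNet(n_channels=1, n_classes=1),
    'R2U_Net': lambda self: R2U_Net(img_ch=3, output_ch=1, t=self.t),
    'AttU_Net': lambda self: AttU_Net(img_ch=1, output_ch=1),
    'NestedUNet': lambda self: NestedUNet(in_ch=1, out_ch=1),
}

def build_model_from_registry(self):
    """Hypothetical variant of build_model using a registry dict."""
    try:
        self.unet = MODEL_REGISTRY[self.model_type](self)
    except KeyError:
        raise ValueError('Unknown model_type: {}'.format(self.model_type))
    self.optimizer = optim.Adam(self.unet.parameters(), self.lr, betas=tuple(self.beta_list))
    self.unet.to(self.device)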
def log_param_and_grad(net: UNet, writer: tensorboardX.SummaryWriter, step):
    for name, param in net.named_parameters():
        writer.add_histogram(f"grad/{name}", param.grad.detach().cpu().numpy(), step)
        # NOTE: despite the tag, this logs the norm of the parameter itself, not of its gradient
        writer.add_histogram(f"grad_norm/{name}", np.sqrt((param**2).sum().detach().cpu().numpy()), step)
        writer.add_histogram(f"param/{name}", param.detach().cpu().numpy(), step)
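# Hedged usage sketch (not part of the original file): calling log_param_and_grad every
# `log_every` steps, after backward() so that param.grad is populated. The tiny model and
# random data below are placeholders standing in for the UNet and the real training loop.
import numpy as np
import tensorboardX
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3, padding=1),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(8, 1, 3, padding=1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
writer = tensorboardX.SummaryWriter(log_dir="runs/param_grad_debug")
log_every = 10

for step in range(100):
    x = torch.randn(2, 1, 64, 64)
    y = torch.randn(2, 1, 64, 64)
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                      # gradients exist from here on
    if step % log_every == 0:
        log_param_and_grad(model, writer, step)
    optimizer.step()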
def get_net(params):
    if params['network'].lower() == 'pan':
        net = PAN(params)
    elif params['network'].lower() == 'shortres':
        net = ShortRes(params)
    elif params['network'].lower() == 'unet':
        net = UNet(params)
    return net
def my_app():
    batch_size_app = 1
    loader_ap = loaders(batch_size_app, 2)
    output = UNet()
    output.cuda()
    # output.load_state_dict(torch.load('/home/daisylabs/aritra_project/results/output.pth'))
    output.load_state_dict(
        torch.load('/home/daisylabs/aritra_project/results/output_best.pth'))
    output.eval()
    with torch.set_grad_enabled(False):
        for u, (inputs, targets) in enumerate(loader_ap):
            if (u == 0):
                inputs = inputs.reshape((batch_size_app, 3, 256, 256))
                targets = targets.reshape((batch_size_app, 256, 256, 256))
                out_1, out_2 = output(inputs)
                out_1 = out_1.reshape((batch_size_app, 256, 256, 256))
                targets = targets.cpu().numpy()
                out_1 = out_1.cpu().numpy()
                tgt_shp = targets.shape[1]
                for slice_number in range(tgt_shp):
                    targets_1 = targets[0][slice_number].reshape((256, 256))
                    out_1_1 = out_1[0][slice_number].reshape((256, 256))
                    plt.figure()
                    plt.subplot(1, 2, 1)
                    plt.title('Original Slice')
                    plt.imshow(targets_1, cmap=plt.get_cmap('gray'), vmin=0, vmax=1)
                    plt.subplot(1, 2, 2)
                    plt.title('Reconstructed Slice')
                    plt.imshow(out_1_1, cmap=plt.get_cmap('gray'), vmin=0, vmax=1)
                    plt.savefig(
                        '/home/daisylabs/aritra_project/results/slices/%d.png'
                        % (slice_number + 1, ))
            else:
                break
def main():
    # filename = 'mixture2.wav'
    filename = 'aimer/1-02 花の唄.wav'
    # filename = 'amazarashi/03 季節は次々死んでいく.wav'
    batch_length = 512
    fs = 44100
    frame_size = 4096
    shift_size = 2048
    modelname = 'model/fs%d_frame%d_shift%d_batch%d.model' % (
        fs, frame_size, shift_size, batch_length)
    statname = 'stat/fs%d_frame%d_shift%d_batch%d.npy' % (
        fs, frame_size, shift_size, batch_length)
    max_norm = float(np.load(statname))

    # load network
    model = UNet()
    model.load_state_dict(torch.load(modelname))
    model.eval()
    torch.backends.cudnn.benchmark = True

    # gpu
    if torch.cuda.is_available():
        model.cuda()
    else:
        print('gpu is not available.')
        sys.exit(1)

    # load wave file
    wave = load(filename, sr=fs)[0]
    spec = stft(wave, frame_size, shift_size)
    soft_vocal, soft_accom, hard_vocal, hard_accom = extract(
        spec, model, max_norm, fs, frame_size, shift_size)

    write_wav(os.path.splitext(os.path.basename(filename))[0] + '_original.wav', wave, fs)
    write_wav(os.path.splitext(os.path.basename(filename))[0] + '_soft_vocal.wav', soft_vocal, fs)
    write_wav(os.path.splitext(os.path.basename(filename))[0] + '_soft_accom.wav', soft_accom, fs)
    write_wav(os.path.splitext(os.path.basename(filename))[0] + '_hard_vocal.wav', hard_vocal, fs)
    write_wav(os.path.splitext(os.path.basename(filename))[0] + '_hard_accom.wav', hard_accom, fs)
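# Hedged sketch (not the repo's extract()): one plausible way a soft and a hard
# vocal/accompaniment split can be produced from a magnitude spectrogram and a
# mask-predicting UNet. The normalisation, the 0.5 threshold, and the function name
# are assumptions; an istft of each complex spectrogram would then give the
# time-domain stems written out above.
import numpy as np
import torch

def separate_with_mask(mag, phase, model, max_norm):
    """mag, phase: (freq, time) numpy arrays; model outputs a mask in [0, 1]."""
    with torch.no_grad():
        x = torch.from_numpy(mag / max_norm).float()[None, None]   # (1, 1, F, T)
        mask = model(x).squeeze().cpu().numpy()                    # soft mask in [0, 1]
    soft_vocal = mask * mag * np.exp(1j * phase)
    soft_accom = (1.0 - mask) * mag * np.exp(1j * phase)
    hard = (mask > 0.5).astype(np.float32)                         # binarised mask
    hard_vocal = hard * mag * np.exp(1j * phase)
    hard_accom = (1.0 - hard) * mag * np.exp(1j * phase)
    return soft_vocal, soft_accom, hard_vocal, hard_accom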
def TrainUNet(X, Y, model_=None, optimizer_=None, epoch=40, alpha=0.001,
              gpu_id=0, loop=1, earlystop=True):
    assert (len(X) == len(Y))
    d_time = datetime.datetime.now().strftime("%m-%d-%H-%M-%S")

    # 1. Model load.
    # print(sum(p.data.size for p in model.unet.params()))
    if model_ is not None:
        model = Regressor(model_)
        print("## model loaded.")
    else:
        model = Regressor(UNet())
    model.compute_accuracy = False
    if gpu_id >= 0:
        model.to_gpu(gpu_id)

    # 2. Optimizer load.
    if optimizer_ is not None:
        opt = optimizer_
        print("## optimizer loaded.")
    else:
        opt = optimizers.Adam(alpha=alpha)
    opt.setup(model)

    # 3. Data split.
    dataset = Unet_DataSet(X, Y)
    print("# number of patterns", len(dataset))
    train, valid = \
        split_dataset_random(dataset, int(len(dataset) * 0.8), seed=0)

    # 4. Iterator
    train_iter = SerialIterator(train, batch_size=C.BATCH_SIZE)
    test_iter = SerialIterator(valid, batch_size=C.BATCH_SIZE,
                               repeat=False, shuffle=False)

    # 5. Config: train, enable backprop
    chainer.config.train = True
    chainer.config.enable_backprop = True

    # 6. UnetUpdater
    updater = UnetUpdater(train_iter, opt, model, device=gpu_id)

    # 7. EarlyStopping
    if earlystop:
        stop_trigger = triggers.EarlyStoppingTrigger(
            monitor='validation/main/loss',
            max_trigger=(epoch, 'epoch'),
            patients=5)
    else:
        stop_trigger = (epoch, 'epoch')

    # 8. Trainer
    trainer = training.Trainer(updater, stop_trigger, out=C.PATH_TRAINRESULT)

    # 8.1. UnetEvaluator
    trainer.extend(UnetEvaluator(test_iter, model, device=gpu_id))
    trainer.extend(SaveRestore(),
                   trigger=triggers.MinValueTrigger('validation/main/loss'))

    # 8.2. Extension: LogReport
    trainer.extend(extensions.LogReport())

    # 8.3. Extension: Snapshot
    # trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
    # trainer.extend(extensions.snapshot_object(model.unet, filename='loop' + str(loop) + '.model'))

    # 8.4. Print Report
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch', 'main/loss', 'validation/main/loss', 'elapsed_time', 'lr'
        ]))

    # 8.5. Extension: Graph
    trainer.extend(
        extensions.PlotReport(['main/loss', 'validation/main/loss'],
                              x_key='epoch',
                              file_name='loop-' + str(loop) + '-loss' + d_time + '.png'))
    # trainer.extend(extensions.dump_graph('main/loss'))

    # 8.6. Progress Bar
    trainer.extend(extensions.ProgressBar())

    # 9. Trainer run
    trainer.run()

    chainer.serializers.save_npz(C.PATH_TRAINRESULT / ('loop' + str(loop)), model.unet)
    return model.unet, opt
                    required=True, help='Image Directory')
parser.add_argument('-g', '--gpu', type=int, default=0, help='GPU selection')
parser.add_argument('-r', '--resolution', type=int, required=True,
                    help='Resolution for Square Image')
args = parser.parse_args()

# Height
model_h = HUNet(128)
pretrained_model_h = torch.load(
    '/content/drive/My Drive/Colab Notebooks/AI_Australia/Models/model_ep_48.pth.tar')

# Weight
model_w = UNet(128, 32, 32)
pretrained_model_w = torch.load(
    '/content/drive/My Drive/Colab Notebooks/AI_Australia/Models/model_ep_37.pth.tar')

model_h.load_state_dict(pretrained_model_h["state_dict"])
model_w.load_state_dict(pretrained_model_w["state_dict"])

if torch.cuda.is_available():
    model = model_w.cuda(args.gpu)
def main(args):
    modelname = os.path.join(args.dst_dir, os.path.splitext(args.src_file)[0])
    if not os.path.exists(args.dst_dir):
        os.makedirs(args.dst_dir)

    # define transforms
    max_norm = float(np.load(args.stats_file))
    transform = transforms.Compose([lambda x: x / max_norm])

    # load data
    with open(args.src_file, 'r') as f:
        files = f.readlines()
    filelist = [file.replace('\n', '') for file in files]

    # define sampler
    index = list(range(len(filelist)))
    train_index = sample(index, round(len(index) * args.ratio))
    valid_index = list(set(index) - set(train_index))
    train_sampler = SubsetRandomSampler(train_index)
    valid_sampler = SubsetRandomSampler(valid_index)

    # define dataloader
    trainset = MagSpecDataset(filelist, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset=trainset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               sampler=train_sampler,
                                               num_workers=args.num_worker)
    valid_loader = torch.utils.data.DataLoader(dataset=trainset,
                                               batch_size=1,
                                               shuffle=False,
                                               sampler=valid_sampler,
                                               num_workers=args.num_worker)

    # fix seed
    torch.manual_seed(args.seed)

    # define network
    model = UNet()

    # define optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # define loss (size_average=False sums the element-wise L1 error)
    criterion = nn.L1Loss(size_average=False)

    # gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        print('gpu is not available.')
        sys.exit(1)

    # training
    for epoch in range(args.num_epoch):
        model.train()
        train_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, targets, silence = data
            # wrap them in Variable
            inputs = Variable(inputs[:, None, ...]).cuda()
            targets = Variable(targets[:, None, ...]).cuda()
            silence = Variable(silence[:, None, ...]).cuda()

            # zero the parameter gradients
            optimizer.zero_grad()
            outputs = model(inputs)
            batch_train_loss = criterion(inputs * outputs * silence,
                                         targets * silence)
            batch_train_loss.backward()
            optimizer.step()

            # print statistics
            train_loss += batch_train_loss.item()

        model.eval()
        valid_loss = 0.0
        for i, data in enumerate(valid_loader):
            inputs, targets, silence = data
            # wrap them in Variable
            inputs = Variable(inputs[:, None, ...]).cuda()
            targets = Variable(targets[:, None, ...]).cuda()
            silence = Variable(silence[:, None, ...]).cuda()

            outputs = model(inputs)
            batch_valid_loss = criterion(inputs * outputs * silence,
                                         targets * silence)

            # print statistics
            valid_loss += batch_valid_loss.item()

        print('[{}/{}] training loss: {:.3f}; validation loss: {:.3f}'.format(
            epoch + 1, args.num_epoch, train_loss, valid_loss))

        # save model
        if epoch % args.num_interval == args.num_interval - 1:
            torch.save(model.state_dict(),
                       modelname + '_batch{}_ep{}.model'.format(args.batch_length, epoch + 1))

    torch.save(model.state_dict(), modelname + '.model')
elif args.loss == 'mae':
    height_loss = nn.L1Loss()
elif args.loss == 'huber':
    height_loss = nn.SmoothL1Loss()

train = DataLoader(Images(args.dataset, 'TRAINING.csv', True),
                   batch_size=args.batch_size, num_workers=8, shuffle=True)
valid = DataLoader(Images(args.dataset, 'VAL.csv', True),
                   batch_size=1, num_workers=8, shuffle=False)

print("Training on " + str(len(train) * args.batch_size) + " images.")
print("Validating on " + str(len(valid)) + " images.")

net = UNet(args.min_neuron)
start_epoch = 0

# pretrained_model = torch.load(glob('models/IMDB_MODEL_06102019_121502/*')[0])
# state_dict = pretrained_model["state_dict"]
# own_state = net.state_dict()
# for name, param in state_dict.items():
#     if name not in own_state:
#         continue
#     if isinstance(param, Parameter):
#         # backwards compatibility for serialized parameters
#         param = param.data
#     if not (("height_1" in name) or ("height_2" in name)):
def main(args):
    torch.backends.cudnn.benchmark = True
    seed_all(args.seed)
    num_classes = 1
    d = Dataset(train_set_size=args.train_set_sz, num_cls=num_classes)
    train = d.train_set
    valid = d.test_set

    net = UNet(in_dim=1, out_dim=4).cuda()
    snake_approx_net = UNet(in_dim=1, out_dim=1, wf=3, padding=True,
                            first_layer_pad=None, depth=4,
                            last_layer_resize=True).cuda()
    best_val_dice = -np.inf

    optimizer = torch.optim.Adam(params=net.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    snake_approx_optimizer = torch.optim.Adam(
        params=snake_approx_net.parameters(), lr=args.lr,
        weight_decay=args.weight_decay)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10,
                                              total_epoch=50, after_scheduler=None)

    # load model
    if args.ckpt:
        loaded = _pickle.load(open(args.ckpt, 'rb'))
        net.load_state_dict(loaded[0])
        optimizer.load_state_dict(loaded[1])
        snake_approx_net.load_state_dict(loaded[2])
        snake_approx_optimizer.load_state_dict(loaded[3])

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_dir=args.log_dir)

    snake = SnakePytorch(args.delta, args.batch_sz * args.num_samples,
                         args.num_lines, args.radius)
    snake_eval = SnakePytorch(args.delta, args.batch_sz, args.num_lines, args.radius)

    noises = torch.zeros(
        (args.batch_sz, args.num_samples, args.num_lines, args.radius)).cuda()

    step = 1
    start = timeit.default_timer()
    for epoch in range(1, args.n_epochs + 1):
        for iteration in range(
                1, int(np.ceil(train.dataset_sz() / args.batch_sz)) + 1):
            scheduler_warmup.step()
            imgs, masks, onehot_masks, centers, dts_modified, dts_original, jitter_radius, bboxes = \
                train.next_batch(args.batch_sz)
            xs = make_batch_input(imgs)
            xs = torch.cuda.FloatTensor(xs)

            net.train()
            unet_logits = net(xs)

            center_jitters, angle_jitters = [], []
            for img, mask, center in zip(imgs, masks, centers):
                c_j, a_j = get_random_jitter_by_mask(mask, center, [1], args.theta_jitter)
                if not args.use_center_jitter:
                    c_j = np.zeros_like(c_j)
                center_jitters.append(c_j)
                angle_jitters.append(a_j)
            center_jitters = np.asarray(center_jitters)
            angle_jitters = np.asarray(angle_jitters)

            # args.radius + 1 because we need additional outermost points for the gradient
            gs_logits_whole_img = unet_logits[:, 3, ...]
            gs_logits, coords_r, coords_c = get_star_pattern_values(
                gs_logits_whole_img, None, centers, args.num_lines, args.radius + 1,
                center_jitters=center_jitters, angle_jitters=angle_jitters)

            # currently only class 1 is foreground
            # if there's multiple foreground classes use a for loop
            gs = gs_logits[:, :, 1:] - gs_logits[:, :, :-1]  # compute the gradient

            # noises here is only used for random exploration so no need mirrored sampling
            noises.normal_(0, 1)
            gs_noisy = torch.unsqueeze(gs, 1) + noises

            def batch_eval_snake(snake, inputs, batch_sz):
                n_inputs = len(inputs)
                assert n_inputs % batch_sz == 0
                n_batches = int(np.ceil(n_inputs / batch_sz))
                ind_sets = []
                for j in range(n_batches):
                    inps = inputs[j * batch_sz:(j + 1) * batch_sz]
                    batch_ind_sets = snake(inps).data.cpu().numpy()
                    ind_sets.append(batch_ind_sets)
                ind_sets = np.concatenate(ind_sets, 0)
                return ind_sets

            gs_noisy = gs_noisy.reshape((args.batch_sz * args.num_samples,
                                         args.num_lines, args.radius))
            ind_sets = batch_eval_snake(snake, gs_noisy,
                                        args.batch_sz * args.num_samples)
            ind_sets = ind_sets.reshape(
                (args.batch_sz * args.num_samples, args.num_lines))
            ind_sets = np.expand_dims(
                smooth_ind(ind_sets, args.smoothing_window), -1)

            # loss layers
            m = torch.nn.LogSoftmax(dim=1)
            loss = torch.nn.NLLLoss()

            # ===========================================================================
            # Inner loop: Train dice loss prediction network
            snake_approx_net.train()
            for _ in range(args.dice_approx_train_steps):
                snake_approx_logits = snake_approx_net(
                    gs_noisy.reshape(args.batch_sz * args.num_samples, 1,
                                     args.num_lines, args.radius).detach())
                snake_approx_train_loss = loss(
                    m(snake_approx_logits.squeeze().transpose(2, 1)),
                    torch.cuda.LongTensor(ind_sets.squeeze()))
                snake_approx_optimizer.zero_grad()
                snake_approx_train_loss.backward()
                snake_approx_optimizer.step()
            # ===========================================================================

            # ===========================================================================
            # Now, minimize the approximate dice loss
            snake_approx_net.eval()
            gt_indices = []
            for mask, center, cj, aj in zip(masks, centers, center_jitters, angle_jitters):
                gt_ind = mask_to_indices(mask, center, args.radius,
                                         args.num_lines, cj, aj)
                gt_indices.append(gt_ind)
            gt_indices = np.asarray(gt_indices).astype(int)
            gt_indices = gt_indices.reshape((args.batch_sz, args.num_lines))
            gt_indices = torch.cuda.LongTensor(gt_indices)

            snake_approx_logits = snake_approx_net(
                gs.reshape((args.batch_sz, 1, args.num_lines, args.radius)))
            nll_approx_loss = loss(
                m(snake_approx_logits.squeeze().transpose(2, 1)), gt_indices)
            total_loss = nll_approx_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            # ===========================================================================

            snake_approx_train_loss = snake_approx_train_loss.data.cpu().numpy()
            nll_approx_loss = nll_approx_loss.data.cpu().numpy()
            total_loss = snake_approx_train_loss + nll_approx_loss

            if step % args.log_freq == 0:
                stop = timeit.default_timer()
                print(f"step={step}\tepoch={epoch}\titer={iteration}"
                      f"\tloss={total_loss}"
                      f"\tsnake_approx_train_loss={snake_approx_train_loss}"
                      f"\tnll_approx_loss={nll_approx_loss}"
                      f"\tlr={optimizer.param_groups[0]['lr']}"
                      f"\ttime={stop-start}")
                start = stop
                writer.add_scalar("total_loss", total_loss, step)
                writer.add_scalar("nll_approx_loss", nll_approx_loss, step)
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], step)

            if step % args.train_eval_freq == 0:
                train_dice = do_eval(
                    net, snake_eval, train.images, train.masks, train.centers,
                    args.batch_sz, args.num_lines, args.radius,
                    smoothing_window=args.smoothing_window).data.cpu().numpy()
                writer.add_scalar("train_dice", train_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\ttrain_eval: train_dice={train_dice}"
                )

            if step % args.val_eval_freq == 0:
                val_dice = do_eval(
                    net, snake_eval, valid.images, valid.masks, valid.centers,
                    args.batch_sz, args.num_lines, args.radius,
                    smoothing_window=args.smoothing_window).data.cpu().numpy()
                writer.add_scalar("val_dice", val_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\tvalid_dice={val_dice}"
                )
                if val_dice > best_val_dice:
                    best_val_dice = val_dice
                    _pickle.dump([
                        net.state_dict(),
                        optimizer.state_dict(),
                        snake_approx_net.state_dict(),
                        snake_approx_optimizer.state_dict()
                    ], open(os.path.join(args.log_dir, 'best_model.pth.tar'), 'wb'))
                    f = open(
                        os.path.join(args.log_dir, f"best_val_dice{step}.txt"), 'w')
                    f.write(str(best_val_dice))
                    f.close()
                    print(f"better val dice detected.")

            step += 1

    return best_val_dice
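# Hedged sketch (smooth_ind is referenced above but not shown here): one plausible
# implementation that smooths each row of snake indices with a centred, circular
# moving average of width `window`; the behaviour in the original repo may differ.
import numpy as np

def smooth_ind(ind_sets, window):
    """ind_sets: (batch, num_lines) integer radial indices; returns the same shape."""
    if window <= 1:
        return ind_sets
    kernel = np.ones(window) / window
    smoothed = np.empty(ind_sets.shape, dtype=np.float64)
    for i, row in enumerate(ind_sets):
        # wrap around because the star-pattern lines form a closed contour
        padded = np.concatenate([row[-(window // 2):], row, row[:window // 2]])
        smoothed[i] = np.convolve(padded, kernel, mode='valid')[:len(row)]
    return np.round(smoothed).astype(ind_sets.dtype)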
####### CREATE/LOAD TRAINING INFO JSON
if trainMode.lower() == 'start':
    print('\nSTARTING TRAINING...')
    trainCFG = dict()
    init_epoch = 0
    trainCFG['trainTFRdir'] = trainTFRdir
    trainCFG['valTFRdir'] = valTFRdir
    trainCFG['input_shape'] = input_shape  # [int(x) for x in args.inputShape.split(',')]
    trainCFG['batch_size'] = batch_size
    trainCFG['train_epochs'] = epochs
    trainCFG['initLR'] = initLR
    trainCFG['best_val_loss'] = np.inf

    ##### INITIALIZE MODEL
    model = UNet(input_shape=input_shape)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=initLR),
                  loss=tf.losses.CategoricalCrossentropy())

elif trainMode.lower() == 'resume':
    print('RESUMING TRAINING FROM: ' + trainInfoPath)
    trainCFG = load_json(trainInfoPath)
    init_epoch = trainCFG['last_epoch'] + 1
    if init_epoch >= trainCFG['train_epochs']:
        raise Exception(
            '\nInitial training epoch value is higher than the max. number of training epochs specified')
    model = tf.keras.models.load_model(lastModelPath)

dataset_info = load_json(os.path.join(dataDir, 'data_info.json'))
        new_file = pad(new_file)
        # new_file = new_file/torch.max(new_file)
        dataset = torch.stack((dataset, new_file))
    else:
        new_file, fs = torchaudio.load(filepath)
        new_file = pad(new_file)
        # new_file = new_file/torch.max(new_file)
        dataset = torch.cat((dataset, new_file.unsqueeze(0)))
    n_files = n_files + 1

print("finished loading: {} files loaded, Total Time: {}".format(n_files, time.time() - start_time))

G = UNet(1, 2)
G.cuda()

# load the model
G.load_state_dict(torch.load("g_param.pth"))
G.eval()

results_path = "val_out"
for j in range(dataset.size()[0]):
    input_stereo = dataset[j, :, :].cuda()
    input_wav = torch.mean(input_stereo, dim=0).unsqueeze(0)
    output_wav = G(input_wav.unsqueeze(0)).cpu().detach()
    torchaudio.save(results_path + os.sep + "test_output_" + str(j) + ".wav",
                    output_wav.squeeze(), fs)
        dataset = torch.stack((dataset, new_file))
    else:
        new_file, fs = torchaudio.load(filepath)
        new_file = pad(new_file)
        new_file = new_file / torch.max(new_file)
        dataset = torch.cat((dataset, new_file.unsqueeze(0)))
    n_files = n_files + 1

print("finished loading: {} files loaded".format(n_files))

# setup the network (refer to the github page)
# network input/output (N,C,H,W) = (1,1,256,256) => (1,2,256,256)
model = UNet(1, 2)
model.cuda()
model.train()

# NOTE: despite the name, this is the Adam optimizer, not a loss criterion
criterion = torch.optim.Adam(model.parameters(), lr=.0001, betas=(.5, .999))

# for each epoch:
keep_training = True
training_losses = []
counter = 1
results_path = "output"
print("training start!")
while keep_training:
    epoch_losses = []
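    # Hedged sketch (the original loop body is not shown in this excerpt): one way the
    # epoch inside `while keep_training:` could look for this (1,1,256,256) -> (1,2,256,256)
    # UNet, assuming `dataset` holds per-file 2x256x256 tensors whose mono mix is the input
    # and whose two channels are the targets. The stop condition is a placeholder.
    l1_loss = torch.nn.L1Loss()
    for j in range(dataset.size()[0]):
        stereo = dataset[j].cuda()                        # (2, 256, 256) target stems
        mono = torch.mean(stereo, dim=0, keepdim=True)    # (1, 256, 256) mixture input
        pred = model(mono.unsqueeze(0))                   # (1, 2, 256, 256)
        batch_loss = l1_loss(pred, stereo.unsqueeze(0))
        criterion.zero_grad()                             # `criterion` is the Adam optimizer
        batch_loss.backward()
        criterion.step()
        epoch_losses.append(batch_loss.item())
    training_losses.append(sum(epoch_losses) / len(epoch_losses))
    counter = counter + 1
    keep_training = counter <= 10   # placeholder; the original stopping criterion is not shown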
def main():
    global args, best_result, output_directory

    # set random seed
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = create_loader(args)

    if args.resume:
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        start_epoch = 0
        # start_epoch = checkpoint['epoch'] + 1
        # best_result = checkpoint['best_result']
        # optimizer = checkpoint['optimizer']

        # solve 'out of memory'
        model = checkpoint['model']
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
        # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))

        # clear memory
        del checkpoint
        # del model_dict
        torch.cuda.empty_cache()
    else:
        print("=> creating Model")
        # input_shape = [args.batch_size,3,256,512]
        model = UNet(3, 1)
        print("=> model created.")
        start_epoch = 0

        print('Number of model parameters: {}'.format(
            sum([p.data.nelement() for p in model.parameters()])))

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
        # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # You can use DataParallel() whether you use Multi-GPUs or not
    model = nn.DataParallel(model).cuda()

    # when training, use ReduceLROnPlateau to reduce learning rate
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=args.lr_patience)

    # loss function
    criterion = criteria.myL1Loss()
    # criterion = nn.SmoothL1Loss()

    # create directory path
    output_directory = utils.get_output_directory(args)
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    best_txt = os.path.join(output_directory, 'best.txt')
    config_txt = os.path.join(output_directory, 'config.txt')

    # write training parameters to config file
    if not os.path.exists(config_txt):
        with open(config_txt, 'w') as txtfile:
            args_ = vars(args)
            args_str = ''
            for k, v in args_.items():
                args_str = args_str + str(k) + ':' + str(v) + ',\t\n'
            txtfile.write(args_str)

    for epoch in range(start_epoch, args.epochs):
        # remember change of the learning rate
        old_lr = 0.0
        # adjust_learning_rate(optimizer, epoch)
        for i, param_group in enumerate(optimizer.param_groups):
            old_lr = float(param_group['lr'])
        print("lr: %f" % old_lr)

        train(train_loader, model, criterion, optimizer, epoch)  # train for one epoch
        result, img_merge = validate(val_loader, model, epoch)  # evaluate on validation set

        # remember best mae and save checkpoint
        is_best = result.mae < best_result.mae
        if is_best:
            best_result = result
            with open(best_txt, 'w') as txtfile:
                txtfile.write("epoch={}, mae={:.3f}, "
                              "t_gpu={:.4f}".format(epoch, result.mae, result.gpu_time))
            if img_merge is not None:
                img_filename = output_directory + '/comparison_best.png'
                utils.save_image(img_merge, img_filename)

        # save checkpoint for each epoch
        utils.save_checkpoint(
            {
                'args': args,
                'epoch': epoch,
                'model': model,
                'best_result': best_result,
                'optimizer': optimizer,
            }, is_best, epoch, output_directory)

        # when mae doesn't fall, reduce learning rate
        scheduler.step(result.mae)
def main(args):
    torch.backends.cudnn.benchmark = True
    seed_all(args.seed)

    d = Dataset(train_set_size=args.train_set_sz, num_cls=args.num_cls,
                remove_nan_center=False)
    train = d.train_set
    valid = d.test_set

    num_cls = args.num_cls + 1  # +1 for background
    net = UNet(in_dim=1, out_dim=num_cls).cuda()
    best_net = UNet(in_dim=1, out_dim=num_cls)
    best_val_dice = -np.inf
    best_cls_val_dices = None

    optimizer = torch.optim.Adam(params=net.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=10,
                                              total_epoch=50, after_scheduler=None)

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir, exist_ok=True)
    writer = tensorboardX.SummaryWriter(log_dir=args.log_dir)

    step = 1
    for epoch in range(1, args.n_epochs + 1):
        for iteration in range(
                1, int(np.ceil(train.dataset_sz() / args.batch_sz)) + 1):
            net.train()
            imgs, masks, one_hot_masks, centers, _, _, _, _ = train.next_batch(
                args.batch_sz)
            imgs = make_batch_input(imgs)
            imgs = torch.cuda.FloatTensor(imgs)
            one_hot_masks = torch.cuda.FloatTensor(one_hot_masks)

            pred_logit = net(imgs)
            pred_softmax = F.softmax(pred_logit, dim=1)

            if args.use_ce:
                ce = torch.nn.CrossEntropyLoss()
                loss = ce(pred_logit, torch.cuda.LongTensor(masks))
            else:
                loss = dice_loss(pred_softmax, one_hot_masks,
                                 keep_background=False).mean()

            scheduler_warmup.step()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % args.log_freq == 0:
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\tloss={loss.data.cpu().numpy()}"
                )
                writer.add_scalar("cnn_dice_loss", loss.data.cpu().numpy(), step)
                writer.add_scalar("lr", optimizer.param_groups[0]["lr"], step)

            if step % args.train_eval_freq == 0:
                train_dice, cls_train_dices = do_eval(net, train.images,
                                                      train.onehot_masks,
                                                      args.batch_sz, num_cls)
                train_dice = train_dice.cpu().numpy()
                cls_train_dices = cls_train_dices.cpu().numpy()
                writer.add_scalar("train_dice", train_dice, step)
                # lr_sched.step(1-train_dice)
                for j, cls_train_dice in enumerate(cls_train_dices):
                    writer.add_scalar(f"train_dice/{j}", cls_train_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\ttrain_eval: train_dice={train_dice}"
                )

            if step % args.val_eval_freq == 0:
                _pickle.dump(
                    net.state_dict(),
                    open(os.path.join(args.log_dir, 'model.pth.tar'), 'wb'))
                val_dice, cls_val_dices = do_eval(net, valid.images,
                                                  valid.onehot_masks,
                                                  args.batch_sz, num_cls)
                val_dice = val_dice.cpu().numpy()
                cls_val_dices = cls_val_dices.cpu().numpy()
                writer.add_scalar("val_dice", val_dice, step)
                for j, cls_val_dice in enumerate(cls_val_dices):
                    writer.add_scalar(f"val_dice/{j}", cls_val_dice, step)
                print(
                    f"step={step}\tepoch={epoch}\titer={iteration}\tvalid_dice={val_dice}"
                )
                if val_dice > best_val_dice:
                    best_val_dice = val_dice
                    best_cls_val_dices = cls_val_dices
                    best_net.load_state_dict(net.state_dict().copy())
                    _pickle.dump(
                        best_net.state_dict(),
                        open(os.path.join(args.log_dir, 'best_model.pth.tar'), 'wb'))
                    f = open(
                        os.path.join(args.log_dir, f"best_val_dice{step}.txt"), 'w')
                    f.write(str(best_val_dice) + "\n")
                    f.write(" ".join([
                        str(dice_score) for dice_score in best_cls_val_dices
                    ]))
                    f.close()
                    print(f"better val dice detected.")

            # if step % 5000 == 0:
            #     _pickle.dump(net.state_dict(), open(os.path.join(args.log_dir, '{}.pth.tar'.format(step)), 'wb'))

            step += 1

    return best_val_dice, best_cls_val_dices
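# Hedged sketch (dice_loss is imported elsewhere in this repo and not shown here):
# a common soft Dice loss over one-hot masks with the same call signature as used
# above. The keep_background flag and the per-sample reduction are assumptions.
import torch

def dice_loss(pred_softmax, one_hot_masks, keep_background=False, eps=1e-6):
    """pred_softmax, one_hot_masks: (N, C, H, W) tensors; returns per-sample loss (N,)."""
    if not keep_background:
        pred_softmax = pred_softmax[:, 1:]      # drop class 0
        one_hot_masks = one_hot_masks[:, 1:]
    dims = (2, 3)
    intersection = (pred_softmax * one_hot_masks).sum(dims)
    union = pred_softmax.sum(dims) + one_hot_masks.sum(dims)
    dice = (2.0 * intersection + eps) / (union + eps)   # (N, C)
    return 1.0 - dice.mean(dim=1)                       # (N,)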
        ax2.title.set_text('probability prediction')
        plt.show(block=False)
        plt.pause(0.01)

    def show_gamma(self):
        plt.figure(3)
        plt.subplot(1, 1, 1)
        plt.imshow(self.gamma[0])
        plt.title('Gamma')
        plt.show(block=False)
        plt.pause(0.01)

    def show_s(self):
        plt.figure(4)
        plt.subplot(1, 1, 1)
        plt.imshow(self.s[0])
        plt.show(block=False)
        plt.pause(0.01)


if __name__ == "__main__":
    net = UNet(num_classes=2)
    net_ = networks(net, 10, 100)
    for i in range(10):  # range() replaces the Python 2-only xrange()
        # print(net_)
        limage = torch.randn(1, 1, 256, 256)
        uimage = torch.randn(1, 1, 256, 256)
        lmask = torch.randint(0, 2, (1, 256, 256), dtype=torch.long)
        net_.update((limage, lmask), uimage)
util.mix_voice_noise(voicedir, noisedir, mixeddir, num_data, fs=16000)

# get list each of which the path name is written
voicepath_list = util.get_wavlist(voicedir)
mixedpath_list = util.get_wavlist(mixeddir)

# make spectrogram (n x F x T x 1)
V = util.make_dataset(voicepath_list, fftsize, hopsize, nbit)
X = util.make_dataset(mixedpath_list, fftsize, hopsize, nbit)

#%% model training
height, width = fftsize // 2, fftsize // 2  # CNN height x width
X_train = X[:, :height, :width, ...]  # voice + noise (input)
Y_train = V[:, :height, :width, ...]  # voice only (target)
num_filt_first = 16
unet = UNet(height, width, num_filt_first)
model = unet.get_model()
model.compile(optimizer='adam', loss='mean_squared_error')
history = model.fit(X_train, Y_train, epochs=5, batch_size=32)

#%% model testing
absY, phsY, max_Y, min_Y = util.make_spectrogram(mixedpath_list[0], fftsize,
                                                 hopsize, nbit, istest=True)
P = np.squeeze(model.predict(absY[np.newaxis, :height, :width, ...]))
P = np.hstack((P, absY[:height, width:]))  # t-axis
P = np.vstack((P, absY[height, :]))  # f-axis
Y = (absY * (max_Y - min_Y) + min_Y) * phsY
# note: this inverts the (unscaled) mixture spectrogram absY, not the prediction P
y = librosa.core.istft(absY * phsY, hop_length=hopsize, win_length=fftsize)
print("debug:", DEBUG) if DEBUG: task = 'DEBUG' + task num_train = 10 num_val = 2 save_model_freq = 1 # set up the model and define the graph with tf.variable_scope(tf.get_variable_scope()): input=tf.placeholder(tf.float32,shape=[None,None,None,5]) reflection=tf.placeholder(tf.float32,shape=[None,None,None,5]) target=tf.placeholder(tf.float32,shape=[None,None,None,5]) overexp_mask = utils.tf_overexp_mask(input) tf_input, tf_reflection, tf_target, real_input = utils.prepare_real_input(input, target, reflection, overexp_mask, ARGS) reflection_layer=UNet(real_input, ext='Ref_') #real_reflect = build_one_hyper(reflection_layer[...,4:5]) transmission_layer=UNet(tf.concat([real_input, reflection_layer],axis=3),ext='Tran_') lossDict = {} lossDict["percep_t"]=0.2*loss.compute_percep_loss(0.5 * tf_target[...,4:5], 0.5*transmission_layer[...,4:5], overexp_mask, reuse=False ) lossDict["percep_r"]=0.2*loss.compute_percep_loss(0.5 * tf_reflection[...,4:5], 0.5*reflection_layer[...,4:5], overexp_mask, reuse=True) lossDict["pncc"] = 6*loss.compute_percep_ncc_loss(tf.multiply(0.5*transmission_layer[...,4:5],overexp_mask), tf.multiply(0.5*reflection_layer[...,4:5],overexp_mask)) lossDict["reconstruct"]= loss.mask_reconstruct_loss(tf_input[...,4:5], transmission_layer[...,4:5], reflection_layer[...,4:5], overexp_mask) lossDict["reflection"] = lossDict["percep_r"] lossDict["transmission"]=lossDict["percep_t"] lossDict["all_loss"] = lossDict["reflection"] + lossDict["transmission"] + lossDict["pncc"]
def score_data(input_folder, output_folder, model_path, args,
               do_postprocessing=False, gt_exists=True, evaluate_all=False,
               random_center_ratio=None):
    num_classes = args.num_cls
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = num_classes + 1

    net = UNet(in_dim=1, out_dim=4).cuda()
    ckpt_path = os.path.join(model_path, 'best_model.pth.tar')
    net.load_state_dict(_pickle.load(open(ckpt_path, 'rb'))[0])

    if args.unet_ckpt:
        pretrained_unet = UNet(in_dim=1, out_dim=4).cuda()
        pretrained_unet.load_state_dict(_pickle.load(open(args.unet_ckpt, 'rb')))

    snake = SnakePytorch(args.delta, 1, args.num_lines, args.radius)

    evaluate_test_set = not gt_exists

    total_time = 0
    total_volumes = 0

    for folder in os.listdir(input_folder):
        folder_path = os.path.join(input_folder, folder)
        if os.path.isdir(folder_path):
            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test == 'test':
                infos = {}
                for line in open(os.path.join(folder_path, 'Info.cfg')):
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')
                patient_id = folder.lstrip('patient')
                ED_frame = int(infos['ED'])
                ES_frame = int(infos['ES'])

                for file in glob.glob(
                        os.path.join(folder_path, 'patient???_frame??.nii.gz')):
                    logging.info(' ----- Doing image: -------------------------')
                    logging.info('Doing: %s' % file)
                    logging.info(' --------------------------------------------')

                    file_base = file.split('.nii.gz')[0]
                    frame = int(file_base.split('frame')[-1])
                    img_dat = utils.load_nii(file)
                    img = img_dat[0].copy()
                    img = image_utils.normalise_image(img)

                    if gt_exists:
                        file_mask = file_base + '_gt.nii.gz'
                        mask_dat = utils.load_nii(file_mask)
                        mask = mask_dat[0]

                    start_time = time.time()

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] / exp_config.target_resolution[0],
                                    pixel_size[1] / exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):
                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')
                        x, y = slice_rescaled.shape
                        slice_cropped, x_s, y_s, x_c, y_c = get_slice(
                            slice_rescaled, nx, ny)

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        network_input = np.transpose(network_input, [0, 3, 1, 2])
                        network_input = torch.cuda.FloatTensor(network_input)
                        with torch.no_grad():
                            net.eval()
                            logit = net(network_input)

                        # get the center
                        if args.unet_ckpt != '':
                            unet_mask = torch.argmax(
                                pretrained_unet(network_input),
                                dim=1).data.cpu().numpy()[0]
                        else:
                            assert gt_exists
                            mask_copy = mask[:, :, zz].copy()
                            unet_mask = get_slice(mask_copy, nx, ny)[0]
                        unet_mask = image_utils.keep_largest_connected_components(unet_mask)

                        from data_iterator import get_center_of_mass
                        if num_classes == 2:
                            lv_center = get_center_of_mass(unet_mask, [3])
                            mo_center = get_center_of_mass(unet_mask, [2])
                        else:
                            lv_center = get_center_of_mass(unet_mask, [3])
                            mo_center = np.asarray([[np.nan, np.nan]])
                        lv_center = np.asarray(lv_center)
                        mo_center = np.asarray(mo_center)

                        lv_mask = np.zeros((nx, ny))
                        if not np.isnan(lv_center[0, 0]):
                            if random_center_ratio:
                                dt, _ = get_distance_transform(unet_mask == 3, None)
                                max_radius = dt[0, int(lv_center[0][0]),
                                                int(lv_center[0][1])]
                                radius = int(max_radius * random_center_ratio)
                                c_j, _ = get_random_jitter(radius, 0)
                            else:
                                c_j = None

                            lv_logit, _, _ = get_star_pattern_values(
                                logit[:, 3, ...], None, lv_center,
                                args.num_lines, args.radius + 1,
                                center_jitters=c_j)
                            lv_gs = lv_logit[:, :, 1:] - lv_logit[:, :, :-1]  # compute the gradient

                            # run DP algo
                            # can only put batch with fixed shape into the snake algorithm
                            lv_ind = snake(lv_gs).data.cpu().numpy()
                            lv_ind = np.expand_dims(
                                smooth_ind(lv_ind.squeeze(-1), args.smoothing_window), -1)
                            lv_mask = star_pattern_ind_to_mask(
                                lv_ind, lv_center, nx, ny, args.num_lines, args.radius)

                        if num_classes == 1:
                            pred_mask = lv_mask * 3
                        else:
                            mo_mask = np.zeros((nx, ny))
                            if not np.isnan(mo_center[0]):
                                c_j = None
                                mo_logit, _, _ = get_star_pattern_values(
                                    logit[:, 2, ...], None, lv_center,
                                    args.num_lines, args.radius + 1,
                                    center_jitters=c_j)
                                mo_gs = mo_logit[:, :, 1:] - mo_logit[:, :, :-1]  # compute the gradient
                                mo_ind = snake(mo_gs).data.cpu().numpy()
                                mo_ind = mo_ind[:len(mo_gs), ...]
                                mo_ind = np.expand_dims(
                                    smooth_ind(mo_ind.squeeze(-1), args.smoothing_window), -1)
                                mo_mask = star_pattern_ind_to_mask(
                                    mo_ind, lv_center, nx, ny, args.num_lines, args.radius)
                            pred_mask = lv_mask * 3 + (1 - lv_mask) * mo_mask * 2  # 3 is lv class, 2 is mo class

                        prediction_cropped = pred_mask.squeeze()

                        # ASSEMBLE BACK THE SLICES
                        prediction = np.zeros((x, y))

                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            prediction[x_s:x_s + nx, y_s:y_s + ny] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                prediction[:, y_s:y_s + ny] = prediction_cropped[x_c:x_c + x, :]
                            elif x > nx and y <= ny:
                                prediction[x_s:x_s + nx, :] = prediction_cropped[:, y_c:y_c + y]
                            else:
                                prediction[:, :] = prediction_cropped[x_c:x_c + x, y_c:y_c + y]

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                prediction, (mask.shape[0], mask.shape[1]),
                                order=0, preserve_range=True, mode='constant')
                        else:
                            # This can occasionally lead to wrong volume size, therefore if gt_exists
                            # we use the gt mask size for resizing.
                            prediction = transform.rescale(
                                prediction,
                                (1.0 / scale_vector[0], 1.0 / scale_vector[1]),
                                order=0, preserve_range=True,
                                multichannel=False, mode='constant')

                        # prediction = np.uint8(np.argmax(prediction, axis=-1))
                        prediction = np.uint8(prediction)
                        predictions.append(prediction)

                        gt_binary = (mask[..., zz] == 3) * 1
                        pred_binary = (prediction == 3) * 1
                        from medpy.metric.binary import hd, dc, assd
                        lv_center = lv_center[0]
                        # i=0; plt.imshow(network_input[0, 0]); plt.plot(lv_center[1], lv_center[0], 'ro'); plt.show(); plt.imshow(unet_mask); plt.plot(lv_center[1], lv_center[0], 'ro'); plt.show(); plt.imshow(logit[0, 0]); plt.plot(lv_center[1], lv_center[0], 'ro'); plt.show(); plt.imshow(lv_logit[0]); plt.show(); plt.imshow(lv_gs[0]); plt.show(); plt.imshow(prediction_cropped); plt.plot(lv_center[1], lv_center[0], 'r.'); plt.show();

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                    # This is the same for 2D and 3D again
                    if do_postprocessing:
                        prediction_arr = image_utils.keep_largest_connected_components(prediction_arr)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1
                    logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                    if frame == ED_frame:
                        frame_suffix = '_ED'
                    elif frame == ES_frame:
                        frame_suffix = '_ES'
                    else:
                        raise ValueError(
                            "Frame doesn't correspond to ED or ES. frame = %d, ED = %d, ES = %d"
                            % (frame, ED_frame, ES_frame))

                    # Save predicted mask
                    out_file_name = os.path.join(
                        output_folder, 'prediction',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    if gt_exists:
                        out_affine = mask_dat[1]
                        out_header = mask_dat[2]
                    else:
                        out_affine = img_dat[1]
                        out_header = img_dat[2]

                    logging.info('saving to: %s' % out_file_name)
                    utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                    # Save image data to the same folder for convenience
                    image_file_name = os.path.join(
                        output_folder, 'image',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % image_file_name)
                    utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                    if gt_exists:
                        # Save GT image
                        gt_file_name = os.path.join(
                            output_folder, 'ground_truth',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % gt_file_name)
                        utils.save_nii(gt_file_name, mask, out_affine, out_header)

                        # Save difference mask between predictions and ground truth
                        difference_mask = np.where(
                            np.abs(prediction_arr - mask) > 0, [1], [0])
                        difference_mask = np.asarray(difference_mask, dtype=np.uint8)
                        diff_file_name = os.path.join(
                            output_folder, 'difference',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % diff_file_name)
                        utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    return None
from network import UNet
import numpy as np
from numpy import *
import random
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.externals import joblib  # on newer scikit-learn, use "import joblib" instead
from sklearn import cluster
import cv2

cmap = plt.cm.jet
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use single GPU

args = utils.parse_command()
print(args)

model = UNet(3, 1)
model = nn.DataParallel(model, device_ids=[0])
model.cuda()  # if setting a gpu id, then a single GPU is used
print('Single GPU Mode.')


def create_loader(args):
    root_dir = ''
    test_set = KittiFolder(root_dir, mode='test', size=(256, 512))
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)
    return test_loader
valid_loader = data.DataLoader(
    dataset=DataFolder('dataset/valid_images_256/', 'dataset/valid_masks_256/', 'validation'),
    batch_size=args.eval_batch_size, shuffle=False, num_workers=2
)

eval_loader = data.DataLoader(
    dataset=DataFolder('dataset/eval_images_256/', 'dataset/eval_masks_256/', 'evaluate'),
    batch_size=args.eval_batch_size, shuffle=False, num_workers=2
)

model = UNet(1, shrink=1).cuda()
nets = [model]

params = [{'params': net.parameters()} for net in nets]
solver = optim.Adam(params, lr=args.lr)
criterion = nn.CrossEntropyLoss()

es = EarlyStopping(min_delta=args.min_delta, patience=args.patience)

for epoch in range(1, args.epochs + 1):
    train_loss = []
    valid_loss = []

    for batch_idx, (img, mask, _) in enumerate(train_loader):
        solver.zero_grad()
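# Hedged sketch (the EarlyStopping class used above is defined elsewhere in this repo):
# a minimal monitor that signals a stop once the validation loss has not improved by
# more than min_delta for `patience` consecutive epochs. The exact semantics of the
# original class may differ.
class EarlyStopping:
    def __init__(self, min_delta=0.0, patience=10):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric):
        """Return True when training should stop."""
        if metric < self.best - self.min_delta:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs > self.patience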
def score_data(input_folder, output_folder, model_path, num_classes=3,
               do_postprocessing=False, gt_exists=True, evaluate_all=False):
    nx, ny = exp_config.image_size[:2]
    batch_size = 1
    num_channels = num_classes + 1

    net = UNet(in_dim=1, out_dim=num_classes + 1).cuda()
    ckpt_path = os.path.join(model_path, 'best_model.pth.tar')
    net.load_state_dict(_pickle.load(open(ckpt_path, 'rb')))

    evaluate_test_set = not gt_exists

    total_time = 0
    total_volumes = 0

    for folder in os.listdir(input_folder):
        folder_path = os.path.join(input_folder, folder)
        if os.path.isdir(folder_path):
            if evaluate_test_set or evaluate_all:
                train_test = 'test'  # always test
            else:
                train_test = 'test' if (int(folder[-3:]) % 5 == 0) else 'train'

            if train_test == 'test':
                infos = {}
                for line in open(os.path.join(folder_path, 'Info.cfg')):
                    label, value = line.split(':')
                    infos[label] = value.rstrip('\n').lstrip(' ')
                patient_id = folder.lstrip('patient')
                ED_frame = int(infos['ED'])
                ES_frame = int(infos['ES'])

                for file in glob.glob(
                        os.path.join(folder_path, 'patient???_frame??.nii.gz')):
                    logging.info(' ----- Doing image: -------------------------')
                    logging.info('Doing: %s' % file)
                    logging.info(' --------------------------------------------')

                    file_base = file.split('.nii.gz')[0]
                    frame = int(file_base.split('frame')[-1])
                    img_dat = utils.load_nii(file)
                    img = img_dat[0].copy()
                    img = image_utils.normalise_image(img)

                    if gt_exists:
                        file_mask = file_base + '_gt.nii.gz'
                        mask_dat = utils.load_nii(file_mask)
                        mask = mask_dat[0]

                    start_time = time.time()

                    pixel_size = (img_dat[2].structarr['pixdim'][1],
                                  img_dat[2].structarr['pixdim'][2])
                    scale_vector = (pixel_size[0] / exp_config.target_resolution[0],
                                    pixel_size[1] / exp_config.target_resolution[1])

                    predictions = []

                    for zz in range(img.shape[2]):
                        slice_img = np.squeeze(img[:, :, zz])
                        slice_rescaled = transform.rescale(slice_img,
                                                           scale_vector,
                                                           order=1,
                                                           preserve_range=True,
                                                           multichannel=False,
                                                           mode='constant')
                        x, y = slice_rescaled.shape
                        x_s = (x - nx) // 2
                        y_s = (y - ny) // 2
                        x_c = (nx - x) // 2
                        y_c = (ny - y) // 2

                        # Crop section of image for prediction
                        if x > nx and y > ny:
                            slice_cropped = slice_rescaled[x_s:x_s + nx, y_s:y_s + ny]
                        else:
                            slice_cropped = np.zeros((nx, ny))
                            if x <= nx and y > ny:
                                slice_cropped[x_c:x_c + x, :] = slice_rescaled[:, y_s:y_s + ny]
                            elif x > nx and y <= ny:
                                slice_cropped[:, y_c:y_c + y] = slice_rescaled[x_s:x_s + nx, :]
                            else:
                                slice_cropped[x_c:x_c + x, y_c:y_c + y] = slice_rescaled[:, :]

                        # GET PREDICTION
                        network_input = np.float32(
                            np.tile(np.reshape(slice_cropped, (nx, ny, 1)),
                                    (batch_size, 1, 1, 1)))
                        network_input = np.transpose(network_input, [0, 3, 1, 2])
                        network_input = torch.cuda.FloatTensor(network_input)
                        with torch.no_grad():
                            net.eval()
                            logits_out = net(network_input)
                            softmax_out = F.softmax(logits_out, dim=1)
                            # mask_out = torch.argmax(logits_out, dim=1)
                        softmax_out = softmax_out.data.cpu().numpy()
                        softmax_out = np.transpose(softmax_out, [0, 2, 3, 1])
                        # prediction_cropped = np.squeeze(softmax_out[0,...])
                        prediction_cropped = np.squeeze(softmax_out)

                        # ASSEMBLE BACK THE SLICES
                        slice_predictions = np.zeros((x, y, num_channels))

                        # insert cropped region into original image again
                        if x > nx and y > ny:
                            slice_predictions[x_s:x_s + nx, y_s:y_s + ny, :] = prediction_cropped
                        else:
                            if x <= nx and y > ny:
                                slice_predictions[:, y_s:y_s + ny, :] = prediction_cropped[x_c:x_c + x, :, :]
                            elif x > nx and y <= ny:
                                slice_predictions[x_s:x_s + nx, :, :] = prediction_cropped[:, y_c:y_c + y, :]
                            else:
                                slice_predictions[:, :, :] = prediction_cropped[x_c:x_c + x, y_c:y_c + y, :]

                        # RESCALING ON THE LOGITS
                        if gt_exists:
                            prediction = transform.resize(
                                slice_predictions,
                                (mask.shape[0], mask.shape[1], num_channels),
                                order=1, preserve_range=True, mode='constant')
                        else:
                            # This can occasionally lead to wrong volume size, therefore if gt_exists
                            # we use the gt mask size for resizing.
                            prediction = transform.rescale(
                                slice_predictions,
                                (1.0 / scale_vector[0], 1.0 / scale_vector[1], 1),
                                order=1, preserve_range=True,
                                multichannel=False, mode='constant')
                            # prediction = transform.resize(slice_predictions,
                            #                               (mask.shape[0], mask.shape[1], num_channels),
                            #                               order=1,
                            #                               preserve_range=True,
                            #                               mode='constant')

                        prediction = np.uint8(np.argmax(prediction, axis=-1))
                        if num_classes == 1:
                            prediction[prediction == 1] = 3
                        elif num_classes == 2:
                            prediction[prediction == 2] = 3
                            prediction[prediction == 1] = 2
                        predictions.append(prediction)

                    prediction_arr = np.transpose(
                        np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

                    # This is the same for 2D and 3D again
                    if do_postprocessing:
                        assert num_classes == 1
                        from skimage.measure import regionprops
                        lv_obj = (mask_dat[0] == 3).astype(np.uint8)
                        prop = regionprops(lv_obj)
                        assert len(prop) == 1
                        prop = prop[0]
                        centroid = prop.centroid
                        centroid = (int(centroid[0]), int(centroid[1]), int(centroid[2]))
                        prediction_arr = image_utils.keep_largest_connected_components(
                            prediction_arr, centroid)

                    elapsed_time = time.time() - start_time
                    total_time += elapsed_time
                    total_volumes += 1
                    logging.info('Evaluation of volume took %f secs.' % elapsed_time)

                    if frame == ED_frame:
                        frame_suffix = '_ED'
                    elif frame == ES_frame:
                        frame_suffix = '_ES'
                    else:
                        raise ValueError(
                            "Frame doesn't correspond to ED or ES. frame = %d, ED = %d, ES = %d"
                            % (frame, ED_frame, ES_frame))

                    # Save predicted mask
                    out_file_name = os.path.join(
                        output_folder, 'prediction',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    if gt_exists:
                        out_affine = mask_dat[1]
                        out_header = mask_dat[2]
                    else:
                        out_affine = img_dat[1]
                        out_header = img_dat[2]

                    logging.info('saving to: %s' % out_file_name)
                    utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

                    # Save image data to the same folder for convenience
                    image_file_name = os.path.join(
                        output_folder, 'image',
                        'patient' + patient_id + frame_suffix + '.nii.gz')
                    logging.info('saving to: %s' % image_file_name)
                    utils.save_nii(image_file_name, img_dat[0], out_affine, out_header)

                    if gt_exists:
                        # Save GT image
                        gt_file_name = os.path.join(
                            output_folder, 'ground_truth',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % gt_file_name)
                        utils.save_nii(gt_file_name, mask, out_affine, out_header)

                        # Save difference mask between predictions and ground truth
                        difference_mask = np.where(
                            np.abs(prediction_arr - mask) > 0, [1], [0])
                        difference_mask = np.asarray(difference_mask, dtype=np.uint8)
                        diff_file_name = os.path.join(
                            output_folder, 'difference',
                            'patient' + patient_id + frame_suffix + '.nii.gz')
                        logging.info('saving to: %s' % diff_file_name)
                        utils.save_nii(diff_file_name, difference_mask, out_affine, out_header)

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    return None
os.environ["CUDA_VISIBLE_DEVICES"]=str(np.argmax( [int(x.split()[2]) for x in subprocess.Popen("nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, stdout=subprocess.PIPE).stdout.readlines()])) else: os.environ["CUDA_VISIBLE_DEVICES"]='' test_names= sorted(glob(ARGS.test_dir + "/*png")) print('Data load succeed!') # set up the model and define the graph with tf.variable_scope(tf.get_variable_scope()): input=tf.placeholder(tf.float32,shape=[None,None,None,5]) reflection=tf.placeholder(tf.float32,shape=[None,None,None,5]) target=tf.placeholder(tf.float32,shape=[None,None,None,5]) overexp_mask = utils.tf_overexp_mask(input) tf_input, tf_reflection, tf_target, real_input = utils.prepare_real_input(input, target, reflection, overexp_mask, ARGS) reflection_layer=UNet(real_input, ext='Ref_') transmission_layer = UNet(tf.concat([real_input, reflection_layer],axis=3),ext='Tran_') ######### Session ######### saver=tf.train.Saver(max_to_keep=10) config = tf.ConfigProto() config.gpu_options.allow_growth = True sess=tf.Session(config=config) sess.run(tf.global_variables_initializer()) var_restore = [v for v in tf.trainable_variables()] saver_restore=tf.train.Saver(var_restore) for var in tf.trainable_variables(): print("Listing trainable variables ... ") print(var)
        print(f'Dice loss in step {step} is {dice_loss}')
        for i in range(len(label)):
            mask = pred[i, 0, :, :].data.cpu().numpy()
            mask = np.where(mask > 0.4, 1, 0)
            name = batch['name'][i]
            img = sitk.ReadImage(os.path.join(img_path, train_phrase, name))
            img = sitk.GetArrayFromImage(img)
            display(name.split('.')[0], mask, mask)
            # display(name.split('.')[0], mask, img)

    aver_dice = aver_dice / len(dataloader)
    print(f'average dice is {aver_dice}.')


if __name__ == "__main__":
    dataset = EmbDataset(train_phrase='train')
    channels_in = len(dataset.model_set) + 1
    dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

    state = torch.load('/home/zhangqianru/data/ly/ckpt_folder/retrain_2/epoch4.pth')
    epoch = state['epoch']
    print(f'Load epoch {epoch}.')

    net = UNet(channels_in, 1)
    net.load_state_dict(state['net'])
    eval_net(net, dataloader, dataset.train_phrase, save_path='/home/zhangqianru/')
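# Hedged sketch (not the dice computation used above): a simple soft Dice score for a
# single-channel sigmoid prediction against a binary ground-truth mask, which is one
# common way per-batch dice values like those printed above can be produced. The
# epsilon and the mean reduction are assumptions.
import torch

def soft_dice(pred, target, eps=1e-6):
    """pred, target: (N, 1, H, W) tensors with values in [0, 1]; returns a scalar."""
    pred = pred.reshape(pred.size(0), -1)
    target = target.reshape(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2.0 * intersection + eps) / (pred.sum(dim=1) + target.sum(dim=1) + eps)
    return dice.mean()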
from data_loader import loaders
from train import my_train
from eval import my_eval
from visualize import my_vis
from app import my_app
import numpy as np
import ray

# data loading
batch_size = 2
loader_tr = loaders(batch_size, 0)
loader_vl = loaders(batch_size, 1)

# networks
output = UNet()
output.cuda()

# optimizer
optimizer = optim.Adam(output.parameters(), lr=.00003, weight_decay=1e-4)

# training
metric_values, metric1_values, val_metric_values, val_metric1_values, epoch_values, loss_values, val_loss_values = (
    [] for i in range(7))
no_of_epochs = 1000
no_of_batches = len(loader_tr)
no_of_batches_1 = len(loader_vl)