def train():
    args = setup_run_arguments()  # args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[INFO] Initializing UNet model using: {device}")

    net = UNet(n_channels=args.n_channels, n_classes=args.n_classes, bilinear=True)
    if args.from_pretrained:
        net.load_state_dict(torch.load(args.from_pretrained, map_location=device))
    net.to(device=device)

    training_loop.run(network=net,
                      epochs=args.epochs,
                      batch_size=args.batch_size,
                      lr=args.learning_rate,
                      device=device,
                      n_classes=args.n_classes,
                      val_percent=args.val_percent,
                      image_dir=args.image_dir,
                      mask_dir=args.mask_dir,
                      checkpoint_path=args.checkpoint_path,
                      loss=args.loss,
                      num_workers=args.num_workers)
def detect_noise_regions(image, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load noise segmentation network (U-Net)
    unet_model_path = os.path.join(args.checkpoints, 'unet', 'UNet.pth')
    net = UNet(n_channels=3, n_classes=1).to(device)
    net.load_state_dict(torch.load(unet_model_path, map_location=device))
    net.eval()

    # predict noise regions
    predict = predict_img(net, device, image)

    # search inpaint patches
    patches, labels, _, absolute_position, relative_position = search_inpaint_area(
        np.array(image), np.array(predict.convert('RGB')))

    # save inpaint patches
    patches_dir = os.path.join(args.output, 'patches')
    labels_dir = os.path.join(args.output, 'labels')
    os.makedirs(patches_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    filename = os.path.basename(args.input).split('.')[0]
    for counter, (patch, label) in enumerate(zip(patches, labels)):
        Image.fromarray(patch).save(os.path.join(patches_dir, '{}-{:0>3d}.png'.format(filename, counter)))
        Image.fromarray(label).save(os.path.join(labels_dir, '{}-{:0>3d}.png'.format(filename, counter)))

    return patches_dir, labels_dir, absolute_position, relative_position
def validate(state_dict_path, use_gpu, device):
    model = UNet(n_channels=1, n_classes=2)
    model.load_state_dict(torch.load(state_dict_path,
                                     map_location=device if use_gpu else 'cpu'))
    model.to(device)
    model.eval()  # set once, outside the batch loop

    val_transforms = transforms.Compose([ToTensor(), NormalizeBRATS()])
    BraTS_val_ds = BRATS2018('./BRATS2018',
                             data_set='val',
                             seg_type='et',
                             scan_type='t1ce',
                             transform=val_transforms)
    data_loader = DataLoader(BraTS_val_ds, batch_size=2, shuffle=False, num_workers=0)

    running_dice_score = 0.
    for batch_ind, batch in enumerate(data_loader):
        imgs, targets = batch
        imgs = imgs.to(device)
        targets = targets.to(device)

        with torch.no_grad():
            outputs = model(imgs)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)

        running_dice_score += dice_score(preds, targets) * targets.size(0)
        print('running dice score: {:.6f}'.format(running_dice_score))

    dice = running_dice_score / len(BraTS_val_ds)
    print('mean dice score of the validation set: {:.6f}'.format(dice))
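# `dice_score` is referenced above but not defined in this snippet. A minimal
# sketch of a per-batch Dice coefficient for binary masks, assuming `preds`
# and `targets` are integer tensors of the same shape (the helper name and
# the smoothing constant are assumptions):
import torch

def dice_score(preds, targets, eps=1e-6):
    """Mean Dice coefficient over a batch of binary masks (assumed helper)."""
    preds = preds.float().view(preds.size(0), -1)
    targets = targets.float().view(targets.size(0), -1)
    intersection = (preds * targets).sum(dim=1)
    union = preds.sum(dim=1) + targets.sum(dim=1)
    return ((2. * intersection + eps) / (union + eps)).mean()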
def load_finetuned_model(args, baseline_model):
    """Load the augmentation net, sample reweighting net, and baseline model."""
    # augment_net = Net(0, 0.0, 32, 3, 0.0, num_classes=32**2 * 3, do_res=True)
    augment_net = UNet(in_channels=3,
                       n_classes=3,
                       depth=1,
                       wf=2,
                       padding=True,
                       batch_norm=False,
                       do_noise_channel=True,
                       up_mode='upsample',
                       use_identity_residual=True)  # TODO(PV): Initialize UNet properly
    # TODO (JON): Depth 1 worked well. Changed upconv to upsample. Use a wf of 2.

    # This net outputs scalar weights to be applied element-wise to the per-example losses
    from models.simple_models import CNN, Net
    imsize, in_channel, num_classes = 32, 3, 10
    reweighting_net = Net(0, 0.0, imsize, in_channel, 0.0, num_classes=1)
    # resnet_cifar.resnet20(num_classes=1)

    if args.load_finetune_checkpoint:
        checkpoint = torch.load(args.load_finetune_checkpoint)
        baseline_model.load_state_dict(checkpoint['elementary_model_state_dict'])
        augment_net.load_state_dict(checkpoint['augment_model_state_dict'])
        try:
            reweighting_net.load_state_dict(checkpoint['reweighting_model_state_dict'])
        except KeyError:
            pass

    augment_net, reweighting_net, baseline_model = (
        augment_net.cuda(), reweighting_net.cuda(), baseline_model.cuda())
    augment_net.train(), reweighting_net.train(), baseline_model.train()
    return augment_net, reweighting_net, baseline_model
class EventGANBase(object):
    def __init__(self, options):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = UNet(num_input_channels=2 * options.n_image_channels,
                              num_output_channels=options.n_time_bins * 2,
                              skip_type='concat',
                              activation='relu',
                              num_encoders=4,
                              base_num_channels=32,
                              num_residual_blocks=2,
                              norm='BN',
                              use_upsample_conv=True,
                              with_activation=True,
                              sn=options.sn,
                              multi=False)
        latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
        checkpoint = torch.load(latest_checkpoint, map_location=self.device)
        self.generator.load_state_dict(checkpoint["gen"])
        self.generator.to(self.device)

    def forward(self, images, is_train=False):
        if len(images.shape) == 3:
            images = images[None, ...]
        assert len(images.shape) == 4 and images.shape[1] == 2, \
            "Input images must be either 2xHxW or Bx2xHxW."
        if not is_train:
            with torch.no_grad():
                self.generator.eval()
                event_volume = self.generator(images)
                self.generator.train()
        else:
            event_volume = self.generator(images)
        return event_volume
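# A minimal usage sketch for EventGANBase. The option values and checkpoint
# directory below are assumptions for illustration, and a trained checkpoint
# must already exist under `checkpoint_dir` for the constructor to succeed.
from types import SimpleNamespace
import torch

options = SimpleNamespace(n_image_channels=1,   # grayscale inputs (assumed)
                          n_time_bins=9,        # assumed value
                          sn=False,
                          checkpoint_dir='checkpoints/')
event_gan = EventGANBase(options)
image_pair = torch.rand(2, 256, 256)        # two stacked frames, 2xHxW
events = event_gan.forward(image_pair)      # -> (1, 2 * n_time_bins, 256, 256)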
def prediction_to_json(image_path, chkp_path, net=None) -> dict:
    """Convert a mask prediction to JSON.

    The format matches the format in the training annotation data:
    {'filename': file_name,
     'labels': [{'name': label_name,
                 'annotations': [{'id': some_unique_integer_id,
                                  'segmentation': [x, y, x, y, x, y, ...]},
                                 ...]},
                ...]}
    """
    file_name = os.path.basename(image_path)
    annotation = {'filename': file_name, 'labels': []}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not net:
        net = UNet(n_channels=3, n_classes=4)
        net.to(device=device)
        net.load_state_dict(torch.load(chkp_path, map_location=device))

    img = Image.open(image_path)
    msk = predict_on_image(net=net, device=device, src_img=img)
    msk = msk.transpose((1, 2, 0))
    h, w, n_labels = msk.shape
    rgb_mask = np.ones((h, w, 3), dtype=np.uint8)
    annotation['height'] = h
    annotation['width'] = w

    for label in range(1, n_labels):
        color = hex_labels[str(label)]
        category = category_labels[str(label)]
        c_label = {'color': color, 'name': category, 'annotations': []}

        label_mask = msk[:, :, label].astype(int).astype(np.uint8)
        contours, hierarchy = cv2.findContours(label_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            vector_points = []
            for x, y in contour.reshape((len(contour), 2)):
                vector_points += [float(x), float(y)]
            c_label['annotations'].append({'segmentation': vector_points})

        idx = np.where(msk[:, :, label].astype(int) == 1)
        rgb_mask[idx] = colors_from_hex[str(label)]
        annotation['labels'].append(c_label)

    return annotation
class Visualizer(object):
    def __init__(self, input_topic, output_topic, resize_width, resize_height,
                 model_path, force_cpu):
        self.bridge = CvBridge()
        self.graph = UNet([3, resize_width, resize_height], 3)
        # fall back to CPU whenever CUDA is unavailable
        self.force_cpu = force_cpu or not torch.cuda.is_available()
        self.graph.load_state_dict(
            torch.load(model_path, map_location='cpu' if self.force_cpu else None))
        self.resize_width, self.resize_height = resize_width, resize_height
        if not self.force_cpu:
            self.graph.cuda()
        self.graph.eval()
        self.to_tensor = transforms.Compose([transforms.ToTensor()])
        self.publisher = rospy.Publisher(output_topic, ImMsg, queue_size=1)
        self.raw_subscriber = rospy.Subscriber(input_topic, CompressedImage, self.image_cb,
                                               queue_size=1, buff_size=10**8)

    def convert_to_tensor(self, image):
        np_arr = np.frombuffer(image.data, np.uint8)
        image_np = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
        image_np = cv2.resize(image_np, dsize=(self.resize_width, self.resize_height))
        img_to_tensor = PIL.Image.fromarray(image_np)
        img_tensor = self.to_tensor(img_to_tensor)
        if not self.force_cpu:
            return Variable(img_tensor.unsqueeze(0)).cuda()
        else:
            return Variable(img_tensor.unsqueeze(0))

    def image_cb(self, image):
        img_tensor = self.convert_to_tensor(image)

        # Inference
        output = self.graph(img_tensor)
        output_data = output.cpu().data.numpy()[0][0]

        # Convert from 32FC1 (0 - 1) to 8UC1 (0 - 255)
        cv_output = np.uint8(255 * output_data)
        cv_output = cv2.applyColorMap(cv_output, cv2.COLORMAP_JET)

        # Convert to a ROS message and publish
        msg_out = self.bridge.cv2_to_imgmsg(cv_output, 'bgr8')
        msg_out.header.stamp = image.header.stamp
        self.publisher.publish(msg_out)
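# A minimal node bootstrap for the Visualizer above; the node name, topic
# names, image size, and model path are assumptions for illustration.
if __name__ == '__main__':
    rospy.init_node('unet_visualizer')
    Visualizer(input_topic='/camera/image/compressed',
               output_topic='/unet/heatmap',
               resize_width=640, resize_height=480,
               model_path='unet.pth', force_cpu=False)
    rospy.spin()  # hand control to the ROS callback loop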
def load_finetuned_model(self, baseline_model):
    """Loads the augmentation net, sample reweighting net, and baseline model.

    Note: sets all these models to train mode.
    """
    # augment_net = Net(0, 0.0, 32, 3, 0.0, num_classes=32**2 * 3, do_res=True)
    if self.args.dataset == DATASET_MNIST:
        imsize, in_channel, num_classes = 28, 1, 10
    else:
        imsize, in_channel, num_classes = 32, 3, 10

    augment_net = UNet(in_channels=in_channel,
                       n_classes=in_channel,
                       depth=2,
                       wf=3,
                       padding=True,
                       batch_norm=False,
                       do_noise_channel=True,
                       up_mode='upconv',
                       use_identity_residual=True)  # TODO(PV): Initialize UNet properly
    # TODO (JON): Depth 1 worked well. Changed upconv to upsample. Use a wf of 2.

    # This net outputs scalar weights to be applied element-wise to the per-example losses
    reweighting_net = Net(1, 0.0, imsize, in_channel, 0.0, num_classes=1)
    # resnet_cifar.resnet20(num_classes=1)

    if self.args.load_finetune_checkpoint:
        checkpoint = torch.load(self.args.load_finetune_checkpoint)
        # baseline_model.load_state_dict(checkpoint['elementary_model_state_dict'])
        if 'weight_decay' in checkpoint:
            baseline_model.weight_decay = checkpoint['weight_decay']
        augment_net.load_state_dict(checkpoint['augment_model_state_dict'])
        try:
            reweighting_net.load_state_dict(checkpoint['reweighting_model_state_dict'])
        except KeyError:
            pass

    augment_net, reweighting_net, baseline_model = (
        augment_net.to(self.device),
        reweighting_net.to(self.device),
        baseline_model.to(self.device))
    augment_net.train(), reweighting_net.train(), baseline_model.train()
    return augment_net, reweighting_net, baseline_model
def train(frame_num,
          layer_nums,
          input_channels,
          output_channels,
          discriminator_num_filters,
          bn=False,
          pretrain=False,
          generator_pretrain_path=None,
          discriminator_pretrain_path=None):
    generator = UNet(n_channels=input_channels,
                     layer_nums=layer_nums,
                     output_channel=output_channels,
                     bn=bn)
    discriminator = PixelDiscriminator(output_channels, discriminator_num_filters, use_norm=False)
    generator = generator.cuda()
    discriminator = discriminator.cuda()

    flow_network = Network()
    flow_network.load_state_dict(torch.load(lite_flow_model_path))
    flow_network.cuda().eval()

    adversarial_loss = Adversarial_Loss().cuda()
    discriminate_loss = Discriminate_Loss().cuda()
    gd_loss = Gradient_Loss(alpha, num_channels).cuda()
    op_loss = Flow_Loss().cuda()
    int_loss = Intensity_Loss(l_num).cuda()

    step = 0
    if not pretrain:
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)
    else:
        assert generator_pretrain_path is not None and discriminator_pretrain_path is not None
        generator.load_state_dict(torch.load(generator_pretrain_path))
        discriminator.load_state_dict(torch.load(discriminator_pretrain_path))
        step = int(generator_pretrain_path.split('-')[-1])
        print('pretrained model loaded!')

    print('initializing the model with Generator-Unet {} layers, '
          'PixelDiscriminator with filters {}'.format(layer_nums, discriminator_num_filters))

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=g_lr)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=d_lr)

    writer = SummaryWriter(writer_path)

    dataset = img_dataset.ano_pred_Dataset(training_data_folder, frame_num)
    dataset_loader = DataLoader(dataset=dataset, batch_size=batch_size,
                                shuffle=True, num_workers=1, drop_last=True)

    test_dataset = img_dataset.ano_pred_Dataset(testing_data_folder, frame_num)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                                 shuffle=True, num_workers=1, drop_last=True)

    for epoch in range(epochs):
        for (input, _), (test_input, _) in zip(dataset_loader, test_dataloader):
            # restore train mode; generator.eval() is called below for the test pass
            generator = generator.train()
            discriminator = discriminator.train()

            target = input[:, -1, :, :, :].cuda()
            input = input[:, :-1, ]
            input_last = input[:, -1, ].cuda()
            input = input.view(input.shape[0], -1, input.shape[-2], input.shape[-1]).cuda()

            test_target = test_input[:, -1, ].cuda()
            test_input = test_input[:, :-1].view(test_input.shape[0], -1,
                                                 test_input.shape[-2],
                                                 test_input.shape[-1]).cuda()

            # ------- update optim_G --------------
            G_output = generator(input)
            pred_flow_esti_tensor = torch.cat([input_last, G_output], 1)
            gt_flow_esti_tensor = torch.cat([input_last, target], 1)
            flow_gt = batch_estimate(gt_flow_esti_tensor, flow_network)
            flow_pred = batch_estimate(pred_flow_esti_tensor, flow_network)

            g_adv_loss = adversarial_loss(discriminator(G_output))
            g_op_loss = op_loss(flow_pred, flow_gt)
            g_int_loss = int_loss(G_output, target)
            g_gd_loss = gd_loss(G_output, target)
            g_loss = lam_adv * g_adv_loss + lam_gd * g_gd_loss + lam_op * g_op_loss + lam_int * g_int_loss

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            train_psnr = psnr_error(G_output, target)

            # ----------- update optim_D --------------
            optimizer_D.zero_grad()
            d_loss = discriminate_loss(discriminator(target), discriminator(G_output.detach()))
            d_loss.backward()
            optimizer_D.step()

            # ----------- compute test psnr --------------
            test_generator = generator.eval()
            with torch.no_grad():
                test_output = test_generator(test_input)
            test_psnr = psnr_error(test_output, test_target).cuda()

            if step % 10 == 0:
                print("[{}/{}]: g_loss: {} d_loss {}".format(step, epoch, g_loss, d_loss))
                print('\tgd_loss {}, op_loss {}, int_loss {}'.format(g_gd_loss, g_op_loss, g_int_loss))
                print('\ttrain psnr {}, test psnr {}'.format(train_psnr, test_psnr))

                writer.add_scalar('psnr/train_psnr', train_psnr, global_step=step)
                writer.add_scalar('psnr/test_psnr', test_psnr, global_step=step)
                writer.add_scalar('total_loss/g_loss', g_loss, global_step=step)
                writer.add_scalar('total_loss/d_loss', d_loss, global_step=step)
                writer.add_scalar('g_loss/adv_loss', g_adv_loss, global_step=step)
                writer.add_scalar('g_loss/op_loss', g_op_loss, global_step=step)
                writer.add_scalar('g_loss/int_loss', g_int_loss, global_step=step)
                writer.add_scalar('g_loss/gd_loss', g_gd_loss, global_step=step)

                writer.add_image('image/train_target', target[0], global_step=step)
                writer.add_image('image/train_output', G_output[0], global_step=step)
                writer.add_image('image/test_target', test_target[0], global_step=step)
                writer.add_image('image/test_output', test_output[0], global_step=step)

            step += 1

            if step % 500 == 0:
                utils.saver(generator.state_dict(), model_generator_save_path, step, max_to_save=10)
                utils.saver(discriminator.state_dict(), model_discriminator_save_path, step, max_to_save=10)
                if step >= 2000:
                    print('==== begin evaluate the model of {} ===='.format(
                        model_generator_save_path + '-' + str(step)))
                    auc = evaluate(frame_num=5,
                                   layer_nums=4,
                                   input_channels=12,
                                   output_channels=3,
                                   model_path=model_generator_save_path + '-' + str(step),
                                   evaluate_name='compute_auc')
                    writer.add_scalar('results/auc', auc, global_step=step)
                           num_workers=5, pin_memory=True)
test_b_loader = DataLoader(test_b_dataset, batch_size=1, shuffle=False,
                           num_workers=5, pin_memory=True)

# net and optimizer
ds_unet = UNet(1, 1, domain_specific=True)
ds_unet.cuda()

labeller = UNet(1, 1)
# import weights here...
labeller_path = './results/unet_sobel_eadan_in/net'
labeller.load_state_dict(
    torch.load(labeller_path, map_location=lambda storage, loc: storage))
labeller.cuda()

optimiser = optim.Adam(ds_unet.parameters(), lr=learning_rate)

print('Project name ', project_name)

train_dices = []
train_losses = []
val_a_dices = []
val_a_losses = []
val_b_dices = []
val_b_losses = []
def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()
    gpu0 = args.gpu

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    model = UNet(3, n_classes=args.num_classes)
    saved_state_dict = torch.load(args.restore_from)
    model.load_state_dict(saved_state_dict)
    model.cuda(gpu0)
    model.train()

    testloader = data.DataLoader(REFUGE(False, domain='REFUGE_TEST', is_transform=True),
                                 batch_size=args.batch_size, shuffle=False, pin_memory=True)

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(460, 460), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(460, 460), mode='bilinear')

    for index, batch in enumerate(testloader):
        if index % 100 == 0:
            print('%d processed' % index)
        image, label, _, _, name = batch
        if args.model == 'Unet':
            _, _, _, _, output2 = model(Variable(image, volatile=True).cuda(gpu0))
            output = interp(output2).cpu().data.numpy()

        for idx, one_name in enumerate(name):
            pred = output[idx]
            pred = pred.transpose(1, 2, 0)
            pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8)
            output_col = colorize_mask(pred)

            if is_polar:
                # map class ids to gray levels, then undo the polar transform
                output_col = np.array(output_col)
                output_col[output_col == 0] = 0
                output_col[output_col == 1] = 128
                output_col[output_col == 2] = 255
                output_col = cv2.linearPolar(rotate(output_col, 90),
                                             (args.ROI_size / 2, args.ROI_size / 2),
                                             args.ROI_size / 2,
                                             cv2.WARP_FILL_OUTLIERS + cv2.WARP_INVERSE_MAP)
                output_col = np.array(output_col * 255, dtype=np.uint8)
                output_col[output_col > 200] = 210
                output_col[output_col == 0] = 255
                output_col[output_col == 210] = 0
                output_col[(output_col > 0) & (output_col < 255)] = 128
                output_col = Image.fromarray(output_col)

            one_name = one_name.split('/')[-1]
            one_name = one_name[:-4]  # strip the file extension
            output_col = output_col.convert('L')
            print(output_col.size)
            output_col.save('%s/%s.bmp' % (args.save, one_name.split('.')[0]))
    }, {
        'cmap': 'jet',
        'vmin': 0,
        'vmax': eval_label.max()
    })

net_is_3d = False

if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs.")
    device_ids = [i for i in range(torch.cuda.device_count())]
    model = nn.DataParallel(model, device_ids=device_ids)
model = model.to(device)

if experiment == "Unet":
    model.load_state_dict(torch.load("best_weights.pth"))
elif experiment == "DeepLab":
    model.load_state_dict(torch.load(f"best_weights_{backbone}_deeplab.pth"))
model.eval()

eval_images, eval_labels, eval_label_corners = batch_generator(
    eval_image, eval_label, **windowing_params, return_corners=True)
eval_dataset = PlateletDataset(eval_images, eval_labels, train=False)
prob_maps = stitch(model, eval_images, eval_labels, eval_label.shape,
                   eval_label_corners, windowing_params, net_is_3d,
                   n_classes, device, channels)
def evaluate(frame_num, layer_nums, input_channels, output_channels,
             model_path, evaluate_name, bn=False):
    """Evaluate a trained generator; evaluate_name is e.g. 'compute_auc'."""
    generator = UNet(n_channels=input_channels,
                     layer_nums=layer_nums,
                     output_channel=output_channels,
                     bn=bn).cuda().eval()
    generator.load_state_dict(torch.load(model_path))

    video_dirs = os.listdir(testing_data_folder)
    video_dirs.sort()
    num_videos = len(video_dirs)

    time_stamp = time.time()
    psnr_records = []
    total = 0

    for dir in video_dirs:
        _temp_test_folder = os.path.join(testing_data_folder, dir)
        dataset = img_dataset.test_dataset(_temp_test_folder, clip_length=frame_num)

        len_dataset = dataset.pics_len
        test_iters = len_dataset - frame_num + 1
        test_counter = 0

        data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1)
        psnrs = np.empty(shape=(len_dataset,), dtype=np.float32)

        with torch.no_grad():
            for test_input, _ in data_loader:
                test_target = test_input[:, -1].cuda()
                test_input = test_input[:, :-1].view(test_input.shape[0], -1,
                                                     test_input.shape[-2],
                                                     test_input.shape[-1]).cuda()

                g_output = generator(test_input)
                test_psnr = psnr_error(g_output, test_target)
                test_psnr = test_psnr.tolist()
                psnrs[test_counter + frame_num - 1] = test_psnr

                test_counter += 1
                total += 1
                if test_counter >= test_iters:
                    # pad the first (frame_num - 1) slots with the first valid score
                    psnrs[:frame_num - 1] = psnrs[frame_num - 1]
                    psnr_records.append(psnrs)
                    print('finish test video set {}'.format(_temp_test_folder))
                    break

    result_dict = {
        'dataset': dataset_name,
        'psnr': psnr_records,
        'flow': [],
        'names': [],
        'diff_mask': []
    }

    used_time = time.time() - time_stamp
    print('total time = {}, fps = {}'.format(used_time, total / used_time))

    pickle_path = os.path.join(psnr_dir, os.path.split(model_path)[-1])
    with open(pickle_path, 'wb') as writer:
        pickle.dump(result_dict, writer, pickle.HIGHEST_PROTOCOL)

    results = eval_metric.evaluate(evaluate_name, pickle_path)
    print(results)
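# `psnr_error` is referenced by the training and evaluation snippets above but
# not defined here. A minimal sketch, assuming frames are scaled to [-1, 1]
# (the helper name and the value range are assumptions):
import torch

def psnr_error(generated, ground_truth, max_val=2.0):
    """Peak signal-to-noise ratio between two image batches (assumed helper)."""
    mse = torch.mean((generated - ground_truth) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)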
def train_eval_model(opts):
    # parse model configuration
    num_epochs = opts["num_epochs"]
    train_batch_size = opts["train_batch_size"]
    val_batch_size = opts["eval_batch_size"]
    dataset_type = opts["dataset_type"]
    opti_mode = opts["optimizer"]
    loss_criterion = opts["loss_criterion"]
    lr = opts["lr"]
    lr_decay = opts["lr_decay"]
    wd = opts["weight_decay"]

    gpus = opts["gpu_list"].split(',')
    os.environ['CUDA_VISIBLE_DEVICES'] = opts["gpu_list"]

    train_dir = opts["log_dir"]
    train_data_dir = opts["train_data_dir"]
    eval_data_dir = opts["eval_data_dir"]

    pretrained = opts["pretrained_model"]
    resume = opts["resume"]
    display_iter = opts["display_iter"]
    save_epoch = opts["save_every_epoch"]
    show = opts["vis"]

    # back up train configs and code
    log_file = os.path.join(train_dir, "log_file.txt")
    os.makedirs(train_dir, exist_ok=True)
    model_dir = os.path.join(train_dir, "code_backup")
    os.makedirs(model_dir, exist_ok=True)
    if resume is None and os.path.exists(log_file):
        os.remove(log_file)
    shutil.copy("./models/unet.py", os.path.join(model_dir, "unet.py"))
    shutil.copy("./trainer_unet.py", os.path.join(model_dir, "trainer_unet.py"))
    shutil.copy("./datasets/dataset.py", os.path.join(model_dir, "dataset.py"))

    ckt_dir = os.path.join(train_dir, "checkpoints")
    os.makedirs(ckt_dir, exist_ok=True)

    # format and print configs
    print("*" * 50)
    table_key = []
    table_value = []
    n = 0
    for key, value in opts.items():
        table_key.append(key)
        table_value.append(str(value))
        n += 1
    print_table([table_key, ["="] * n, table_value])

    # format gpu list
    gpu_list = [int(str_id) for str_id in gpus]

    # dataloader
    print("==> Create dataloader")
    dataloaders_dict = {
        "train": er_data_loader(train_data_dir, train_batch_size, dataset_type, is_train=True),
        "eval": er_data_loader(eval_data_dir, val_batch_size, dataset_type, is_train=False)
    }

    # define the network
    print("==> Create network")
    num_channels = 1
    num_classes = 1
    model = UNet(num_channels, num_classes)
    init_weights(model)

    # loss layer
    criterion = create_criterion(criterion=loss_criterion)

    best_acc = 0.0
    start_epoch = 0

    # load pretrained model
    if pretrained is not None and os.path.isfile(pretrained):
        print("==> Train from model '{}'".format(pretrained))
        checkpoint_gan = torch.load(pretrained)
        model.load_state_dict(checkpoint_gan['model_state_dict'])
        print("==> Loaded checkpoint '{}'".format(pretrained))
        for param in model.parameters():
            param.requires_grad = False

    # resume training
    elif resume is not None and os.path.isfile(resume):
        print("==> Resume from checkpoint '{}'".format(resume))
        checkpoint = torch.load(resume)
        start_epoch = checkpoint['epoch'] + 1
        best_acc = checkpoint['best_acc']
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v for k, v in checkpoint['model_state_dict'].items()
            if k in model_dict and v.size() == model_dict[k].size()
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print("==> Loaded checkpoint '{}' (epoch {})".format(resume, checkpoint['epoch'] + 1))

    # train from scratch
    else:
        print("==> Train from initial or random state.")

    # multiple-gpu mode
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.cuda()
    model = nn.DataParallel(model)

    # print learnable parameters
    print("==> List learnable parameters")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print("\t{}, size {}".format(name, param.size()))
    params_to_update = [{'params': model.parameters()}]

    # define optimizer
    print("==> Create optimizer")
    optimizer = create_optimizer(params_to_update, opti_mode, lr=lr, momentum=0.9, wd=wd)
    if resume is not None and os.path.isfile(resume):
        optimizer.load_state_dict(checkpoint['optimizer'])

    # start training; each epoch has a training and a validation phase
    since = time.time()
    print("==> Start training")
    total_steps = 0
    for epoch in range(start_epoch, num_epochs):
        print('-' * 50)
        print("==> Epoch {}/{}".format(epoch + 1, num_epochs))

        total_steps = train_one_epoch(epoch, total_steps, dataloaders_dict['train'],
                                      model, device, criterion, optimizer,
                                      lr, lr_decay, display_iter, log_file, show)
        epoch_acc, epoch_iou, epoch_f1 = eval_one_epoch(epoch, dataloaders_dict['eval'],
                                                        model, device, log_file)

        if best_acc < epoch_acc and epoch >= 5:
            best_acc = epoch_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_acc': best_acc
            }, os.path.join(ckt_dir, "best.pth"))

        if (epoch + 1) % save_epoch == 0 and (epoch + 1) >= 20:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_iou': epoch_iou
            }, os.path.join(ckt_dir, "checkpoints_" + str(epoch + 1) + ".pth"))

    time_elapsed = time.time() - since
    time_message = 'Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60)
    print(time_message)
    with open(log_file, "a+") as fid:
        fid.write('%s\n' % time_message)
    print('==> Best val Acc: {:4f}'.format(best_acc))
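# A hypothetical opts dict for train_eval_model. The keys are exactly those
# read at the top of the function; all values below are illustrative
# assumptions, not the original configuration.
opts = {
    "num_epochs": 50, "train_batch_size": 8, "eval_batch_size": 4,
    "dataset_type": "er", "optimizer": "adam", "loss_criterion": "bce",
    "lr": 1e-3, "lr_decay": 0.9, "weight_decay": 1e-4, "gpu_list": "0",
    "log_dir": "./logs", "train_data_dir": "./data/train",
    "eval_data_dir": "./data/eval", "pretrained_model": None, "resume": None,
    "display_iter": 10, "save_every_epoch": 10, "vis": False,
}
train_eval_model(opts)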
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'{torch.cuda.device_count()} CUDA device(s) available.')
print(f'Using {device} device.')

batch_size = args.batch_size
if torch.cuda.device_count() > 1:
    batch_size *= torch.cuda.device_count()

testset = MaskFolder(args.dataset, transform=Compose([Resize((512, 512)), ToTensor()]))
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                        num_workers=batch_size, pin_memory=True, drop_last=False)

model = UNet(3, 1)
model.load_state_dict(torch.load(args.weights))
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)

prec, reca, f1 = test(model,
                      tqdm(testloader, desc=f'Testing threshold {args.threshold}'),
                      device,
                      threshold=args.threshold)
print(f'Precision {prec}. Recall {reca}. F1 {f1}.')
def inference():
    """Supports two modes: evaluation (on the valid set) or inference
    (on the test set for submission).
    """
    parser = argparse.ArgumentParser(description="Inference mode")
    parser.add_argument('-testf', "--test-filepath", type=str, default=None, required=True,
                        help="testing dataset filepath.")
    parser.add_argument("-eval", "--evaluate", action="store_true", default=False,
                        help="Evaluation mode")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="Load pretrained weights, torch state_dict() (filepath, default: None)")
    parser.add_argument("--load-model", type=str, default=None,
                        help="Load pretrained model, entire model (filepath, default: None)")
    parser.add_argument("--save2dir", type=str, default=None,
                        help="save the prediction labels to the directory (default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    args = parser.parse_args()

    printYellow("=" * 10 + " Inference mode. " + "=" * 10)

    if args.save2dir:
        os.makedirs(args.save2dir, exist_ok=True)
    device = torch.device("cuda:{}".format(args.cuda)
                          if torch.cuda.is_available() and (args.cuda >= 0) else "cpu")

    transform_normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    data_transform = transforms.Compose([transforms.ToTensor(), transform_normalize])
    data_loader_params = {'batch_size': args.batch_size,
                          'shuffle': False,
                          'num_workers': args.num_cpu,
                          'drop_last': False,
                          'pin_memory': False}
    test_set = LiTSDataset(args.test_filepath,
                           dtype=np.float32,
                           pixelwise_transform=data_transform,
                           inference_mode=(not args.evaluate))
    dataloader_test = torch.utils.data.DataLoader(test_set, **data_loader_params)

    # =================== Build model ===================
    if args.load_weights:
        model = UNet(in_ch=1,
                     out_ch=3,  # 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=64,
                     inc_rate=2,
                     kernel_size=3,
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False)
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("Successfully loaded pretrained weights.")
    elif args.load_model:
        # load the entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")
    else:
        raise ValueError("Either --load-weights or --load-model must be given.")
    model.eval()
    model.to(device)

    # n_batch_per_epoch = len(dataloader_test)
    sigmoid_act = torch.nn.Sigmoid()
    st = time.time()

    volume_start_index = test_set.volume_start_index
    spacing = test_set.spacing
    direction = test_set.direction  # use it for the submission
    offset = test_set.offset

    msk_pred_buffer = []
    if args.evaluate:
        msk_gt_buffer = []

    for data_batch in tqdm(dataloader_test):
        if args.evaluate:
            img, msk_gt = data_batch
            msk_gt_buffer.append(msk_gt.cpu().detach().numpy())
        else:
            img = data_batch
        img = img.to(device)
        with torch.no_grad():
            msk_pred = model(img)  # shape (N, 3, H, W)
            msk_pred = sigmoid_act(msk_pred)
        msk_pred_buffer.append(msk_pred.cpu().detach().numpy())
    msk_pred_buffer = np.vstack(msk_pred_buffer)  # shape (N, 3, H, W)
    if args.evaluate:
        msk_gt_buffer = np.vstack(msk_gt_buffer)

    results = []
    for vol_ind, vol_start_ind in enumerate(volume_start_index):
        if vol_ind == len(volume_start_index) - 1:
            volume_msk = msk_pred_buffer[vol_start_ind:]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:]
        else:
            vol_end_ind = volume_start_index[vol_ind + 1]
            volume_msk = msk_pred_buffer[vol_start_ind:vol_end_ind]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:vol_end_ind]

        if args.evaluate:
            # liver
            liver_scores = get_scores(volume_msk[:, 1] >= 0.5, volume_msk_gt >= 1, spacing[vol_ind])
            # tumor
            lesion_scores = get_scores(volume_msk[:, 2] >= 0.5, volume_msk_gt == 2, spacing[vol_ind])
            print("Liver dice", liver_scores['dice'], "Lesion dice", lesion_scores['dice'])
            results.append([vol_ind, liver_scores, lesion_scores])
        else:
            if args.save2dir:
                # reverse the channel order, because we prioritize tumor, then liver, then background
                msk_pred = (volume_msk >= 0.5)[:, ::-1, ...]  # shape (N, 3, H, W)
                msk_pred = np.argmax(msk_pred, axis=1)  # shape (N, H, W) = (z, x, y)
                msk_pred = np.transpose(msk_pred, axes=(1, 2, 0))  # shape (x, y, z)
                # remember to correct 'direction' and np.transpose before the submission!
                if direction[vol_ind][0] == -1:  # x-axis
                    msk_pred = msk_pred[::-1, ...]
                if direction[vol_ind][1] == -1:  # y-axis
                    msk_pred = msk_pred[:, ::-1, :]
                if direction[vol_ind][2] == -1:  # z-axis
                    msk_pred = msk_pred[..., ::-1]

                # save the medical image header as well
                # see: http://loli.github.io/medpy/generated/medpy.io.header.Header.html
                file_header = med_header(spacing=tuple(spacing[vol_ind]),
                                         offset=tuple(offset[vol_ind]),
                                         direction=np.diag(direction[vol_ind]))
                # submission guide:
                # https://github.com/PatrickChrist/LITS-CHALLENGE/blob/master/submission-guide.md
                # file name: test-segmentation-X.nii
                filepath = os.path.join(args.save2dir, f"test-segmentation-{vol_ind}.nii")
                med_save(msk_pred, filepath, hdr=file_header)

    if args.save2dir:
        # outpath = os.path.join(args.save2dir, "results.csv")
        outpath = os.path.join(args.save2dir, "results.pkl")
        with open(outpath, "wb") as file:
            final_result = {'liver': defaultdict(list), 'tumor': defaultdict(list)}
            for vol_ind, liver_scores, lesion_scores in results:
                # [OTC] assuming vol_ind is continuous
                for key in liver_scores:
                    final_result['liver'][key].append(liver_scores[key])
                for key in lesion_scores:
                    final_result['tumor'][key].append(lesion_scores[key])
            pickle.dump(final_result, file, protocol=3)

        # ======== code from the official metric ========
        # # create line for csv file
        # outstr = str(vol_ind) + ','
        # for l in [liver_scores, lesion_scores]:
        #     for k, v in l.items():
        #         outstr += str(v) + ','
        # outstr += '\n'
        # # create header for csv file if necessary
        # if not os.path.isfile(outpath):
        #     headerstr = 'Volume,'
        #     for k, v in liver_scores.items():
        #         headerstr += 'Liver_' + k + ','
        #     for k, v in liver_scores.items():
        #         headerstr += 'Lesion_' + k + ','
        #     headerstr += '\n'
        #     outstr = headerstr + outstr
        # # write to file
        # f = open(outpath, 'a+')
        # f.write(outstr)
        # f.close()
        # ===============================================

    printGreen(f"Total elapsed time: {time.time() - st}")
    return results
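# Hypothetical invocations of the inference entry point above (the script
# name and all paths are assumptions for illustration):
#   python inference.py -testf data/test_volumes.h5 --load-weights unet_lits.pth --save2dir predictions/
#   python inference.py -testf data/val_volumes.h5 --load-weights unet_lits.pth -eval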
datasets = torchvision.datasets.VOCSegmentation(dataroot,
                                                year='2012',
                                                image_set='train',
                                                download=False,
                                                transform=original_transform,
                                                target_transform=teacher_transform)
train_loader = torch.utils.data.DataLoader(datasets, batch_size=1, shuffle=False)

model = UNet(n_channels=3, n_classes=21)  # .cuda()
model.load_state_dict(torch.load(checkpoint_path))
model.eval()

with torch.no_grad():
    abab = 0
    for data, target in train_loader:
        abab += 1
        # data = data.cuda()
        output = model(data).data  # shape [1, 21, 512, 512]

        img = Image.fromarray(
            data[0].detach().cpu().transpose(0, 1).transpose(1, 2).numpy().astype(np.uint8))
        y = output[0].detach().cpu()
        anno_class_img = Image.fromarray(np.uint8(np.argmax(y.numpy(), axis=0)),
                                         mode='P')  # palette-mode class mask
class Trainer:
    def __init__(self, seq_length, color_channels,
                 unet_path="pretrained/unet.mdl",
                 discrim_path="pretrained/dicrim.mdl",
                 facenet_path="pretrained/facenet.mdl",
                 vgg_path="",
                 embedding_size=1000,
                 unet_depth=3,
                 unet_filts=32,
                 facenet_filts=32,
                 resnet=18):
        self.color_channels = color_channels
        self.margin = 0.5
        self.writer = SummaryWriter(log_dir="logs")

        self.unet_path = unet_path
        self.discrim_path = discrim_path
        self.facenet_path = facenet_path

        self.unet = UNet(in_channels=color_channels,
                         out_channels=color_channels,
                         depth=unet_depth,
                         start_filts=unet_filts,
                         up_mode="upsample",
                         merge_mode='concat').to(device)
        self.discrim = FaceNetModel(embedding_size=embedding_size,
                                    start_filts=facenet_filts,
                                    in_channels=color_channels,
                                    resnet=resnet,
                                    pretrained=False).to(device)
        self.facenet = FaceNetModel(embedding_size=embedding_size,
                                    start_filts=facenet_filts,
                                    in_channels=color_channels,
                                    resnet=resnet,
                                    pretrained=False).to(device)

        if os.path.isfile(unet_path):
            self.unet.load_state_dict(torch.load(unet_path))
            print("unet loaded")
        if os.path.isfile(discrim_path):
            self.discrim.load_state_dict(torch.load(discrim_path))
            print("discrim loaded")
        if os.path.isfile(facenet_path):
            self.facenet.load_state_dict(torch.load(facenet_path))
            print("facenet loaded")
        if os.path.isfile(vgg_path):
            self.vgg_loss_network = LossNetwork(vgg_face_dag(vgg_path)).to(device)
            self.vgg_loss_network.eval()
            print("vgg loaded")

        # gradient directions for the WGAN-style relax_* updates below
        self.one = torch.tensor(1.0).to(device)
        self.mone = self.one * -1

        self.mse_loss_function = nn.MSELoss().to(device)
        self.discrim_loss_function = nn.BCELoss().to(device)
        self.triplet_loss_function = TripletLoss(margin=self.margin)

        self.unet_optimizer = torch.optim.Adam(self.unet.parameters(), betas=(0.9, 0.999))
        self.discrim_optimizer = torch.optim.Adam(self.discrim.parameters(), betas=(0.9, 0.999))
        self.facenet_optimizer = torch.optim.Adam(self.facenet.parameters(), betas=(0.9, 0.999))

    def test(self, test_loader, epoch=0):
        X, y = next(iter(test_loader))
        B, D, C, W, H = X.shape
        # X = X.view(B, C * D, W, H)

        self.unet.eval()
        self.facenet.eval()
        self.discrim.eval()
        with torch.no_grad():
            y_ = self.unet(X.to(device))
            mse = self.mse_loss_function(y_, y.to(device))
            loss_G = self.loss_GAN_generator(btch_X=X.to(device))
            loss_D = self.loss_GAN_discrimator(btch_X=X.to(device), btch_y=y.to(device))
            loss_facenet, _, n_bad = self.loss_facenet(X.to(device), y.to(device))

            plt.title(f"epoch {epoch} mse={mse.item():.4} "
                      f"facenet={loss_facenet.item():.4} bad={n_bad / B ** 2}")
            i = np.random.randint(0, B)
            a = np.hstack((y[i].transpose(0, 1).transpose(1, 2),
                           y_[i].transpose(0, 1).transpose(1, 2).to(cpu)))
            b = np.hstack((X[i][0].transpose(0, 1).transpose(1, 2),
                           X[i][-1].transpose(0, 1).transpose(1, 2)))
            plt.imshow(np.vstack((a, b)))
            plt.axis('off')
            plt.show()

            self.writer.add_scalar("test bad_percent", n_bad / B ** 2, global_step=epoch)
            self.writer.add_scalar("test loss", mse.item(), global_step=epoch)
            # self.writer.add_scalars("test GAN", {"discrim": loss_D.item(),
            #                                      "gen": loss_G.item()}, global_step=epoch)

        with torch.no_grad():
            n_for_show = 10
            y_show_ = y_.to(device)
            y_show = y.to(device)
            embeddings_anc, _ = self.facenet(y_show_)
            embeddings_pos, _ = self.facenet(y_show)

            embeds = torch.cat((embeddings_anc[:n_for_show], embeddings_pos[:n_for_show]))
            imgs = torch.cat((y_show_[:n_for_show], y_show[:n_for_show]))
            names = list(range(n_for_show)) * 2
            # self.writer.add_embedding(mat=embeds, metadata=names, label_img=imgs,
            #                           tag="embeddings", global_step=epoch)

            trshs, fprs, tprs = roc_curve(embeddings_anc.detach().to(cpu),
                                          embeddings_pos.detach().to(cpu))
            rnk1 = rank1(embeddings_anc.detach().to(cpu), embeddings_pos.detach().to(cpu))

            plt.step(fprs, tprs)
            # plt.xlim((1e-4, 1))
            plt.yticks(np.arange(0, 1, 0.05))
            plt.xticks(np.arange(min(fprs), max(fprs), 10))
            plt.xscale('log')
            plt.title(f"ROC auc={auc(fprs, tprs)} rnk1={rnk1}")
            self.writer.add_figure("ROC test", plt.gcf(), global_step=epoch)
            self.writer.add_scalar("auc", auc(fprs, tprs), global_step=epoch)
            self.writer.add_scalar("rank1", rnk1, global_step=epoch)

        print(f"\n###### {epoch} TEST mse={mse.item():.4} "
              f"GAN(G/D)={loss_G.item():.4}/{loss_D.item():.4} "
              f"facenet={loss_facenet.item():.4} bad={n_bad / B ** 2:.4} "
              f"auc={auc(fprs, tprs)} rank1={rnk1} #######")

    def test_test(self, test_loader):
        X, ys = next(iter(test_loader))
        true_idx = 0
        x = X[true_idx]
        D, C, W, H = x.shape
        # x = x.view(C * D, W, H)

        dists = list()
        with torch.no_grad():
            y_ = self.unet(x.to(device))
            embedding_anc, _ = self.facenet(y_)
            embeddings_pos, _ = self.facenet(ys)
            for emb_pos_item in embeddings_pos:
                dist = l2_dist.forward(embedding_anc, emb_pos_item)
                dists.append(dist)

        a_sorted = np.argsort(dists)
        a = np.hstack((ys[true_idx].transpose(0, 1).transpose(1, 2),
                       y_.transpose(0, 1).transpose(1, 2).to(cpu).numpy(),
                       ys[a_sorted[0]].transpose(0, 1).transpose(1, 2)))
        b = np.hstack((x[0:3].transpose(0, 1).transpose(1, 2),
                       x[D // 2 * C:D // 2 * C + 3].transpose(0, 1).transpose(1, 2),
                       x[-3:].transpose(0, 1).transpose(1, 2)))
        b_ = b - np.min(b)
        b_ = b_ / np.max(b_)
        b_ = equalize_func([(b_ * 255).astype(np.uint8)], use_clahe=True)[0]
        b = b_.astype(np.float32) / 255

        plt.imshow(cv2.cvtColor(np.vstack((a, b)), cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.show()

    def loss_facenet(self, X, y, is_detached=False):
        B, D, C, W, H = X.shape
        y_ = self.unet(X)
        embeddings_anc, D_fake = self.facenet(y_ if not is_detached else y_.detach())
        embeddings_pos, D_real = self.facenet(y)

        target_real = torch.full_like(D_fake, 1)
        loss_gen = self.discrim_loss_function(D_fake, target_real)

        pos_dist = l2_dist.forward(embeddings_anc, embeddings_pos)
        bad_triplets_loss = None
        n_bad = 0
        for shift in range(1, B):
            embeddings_neg = torch.roll(embeddings_pos, shift, 0)
            neg_dist = l2_dist.forward(embeddings_anc, embeddings_neg)
            bad_triplets_idxs = np.where(
                (neg_dist - pos_dist < self.margin).cpu().numpy().flatten())[0]
            if shift == 1:
                bad_triplets_loss = self.triplet_loss_function.forward(
                    embeddings_anc[bad_triplets_idxs],
                    embeddings_pos[bad_triplets_idxs],
                    embeddings_neg[bad_triplets_idxs]).to(device)
            else:
                bad_triplets_loss += self.triplet_loss_function.forward(
                    embeddings_anc[bad_triplets_idxs],
                    embeddings_pos[bad_triplets_idxs],
                    embeddings_neg[bad_triplets_idxs]).to(device)
            n_bad += len(bad_triplets_idxs)
        bad_triplets_loss /= B
        return bad_triplets_loss, torch.mean(loss_gen), n_bad

    # def loss_mse(self, btch_X, btch_y):
    #     btch_y_ = self.unet(btch_X)
    #     loss_unet = self.mse_loss_function(btch_y_, btch_y)
    #     features_target = self.facenet.forward_mse(btch_y)
    #     features = self.facenet.forward_mse(btch_y_)
    #     loss_first_layer = self.mse_loss_function(features, features_target)
    #     return loss_unet + loss_first_layer

    def loss_mse_vgg(self, btch_X, btch_y, k_mse, k_vgg):
        btch_y_ = self.unet(btch_X)
        perceptual_btch_y_ = self.vgg_loss_network(btch_y_)
        perceptual_btch_y = self.vgg_loss_network(btch_y)
        perceptual_loss = 0.0
        for a, b in zip(perceptual_btch_y_, perceptual_btch_y):
            perceptual_loss += self.mse_loss_function(a, b)
        return k_vgg * perceptual_loss + k_mse * self.mse_loss_function(btch_y_, btch_y)

    def loss_GAN_discrimator(self, btch_X, btch_y):
        btch_y_ = self.unet(btch_X)
        _, y_D_fake_ = self.discrim(btch_y_.detach())
        _, y_D_real_ = self.discrim(btch_y)

        target_fake = torch.full_like(y_D_fake_, 0)
        target_real = torch.full_like(y_D_real_, 1)

        loss_D_fake_ = self.discrim_loss_function(y_D_fake_, target_fake)
        loss_D_real_ = self.discrim_loss_function(y_D_real_, target_real)
        loss_discrim = loss_D_real_ + loss_D_fake_
        return loss_discrim

    def loss_GAN_generator(self, btch_X):
        btch_y_ = self.unet(btch_X)
        _, y_D_fake_ = self.discrim(btch_y_)
        target_real = torch.full_like(y_D_fake_, 1)
        loss_gen = self.discrim_loss_function(y_D_fake_, target_real)
        return loss_gen

    def relax_discriminator(self, btch_X, btch_y):
        self.discrim.zero_grad()

        # train with real
        y_discrim_real_ = self.discrim(btch_y)
        y_discrim_real_ = y_discrim_real_.mean()
        y_discrim_real_.backward(self.mone)

        # train with fake
        btch_y_ = self.unet(btch_X)
        y_discrim_fake_detached_ = self.discrim(btch_y_.detach())
        y_discrim_fake_detached_ = y_discrim_fake_detached_.mean()
        y_discrim_fake_detached_.backward(self.one)

        # gradient penalty
        gradient_penalty = self.discrim_gradient_penalty(btch_y, btch_y_)
        gradient_penalty.backward()

        self.discrim_optimizer.step()

    def relax_generator(self, btch_X):
        self.unet.zero_grad()
        btch_y_ = self.unet(btch_X)
        y_discrim_fake_ = self.discrim(btch_y_)
        y_discrim_fake_ = y_discrim_fake_.mean()
        y_discrim_fake_.backward(self.mone)
        self.unet_optimizer.step()

    def discrim_gradient_penalty(self, real_y, fake_y):
        lambd = 10
        btch_size = real_y.shape[0]

        alpha = torch.rand(btch_size, 1, 1, 1).to(device)
        alpha = alpha.expand_as(real_y)

        interpolates = alpha * real_y + (1 - alpha) * fake_y
        interpolates = interpolates.to(device)
        interpolates = autograd.Variable(interpolates, requires_grad=True)
        interpolates_out = self.discrim(interpolates)

        gradients = autograd.grad(outputs=interpolates_out,
                                  inputs=interpolates,
                                  grad_outputs=torch.ones(interpolates_out.size()).to(device),
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambd
        return gradient_penalty

    def train(self, train_loader, test_loader, batch_size=2, epochs=30,
              k_gen=1, k_discrim=1, k_mse=1, k_facenet=1, k_facenet_back=1, k_vgg=1):
        """Jointly trains the U-Net generator, the discriminator, and FaceNet.

        :param train_loader: yields (X, y) with X shaped (B, D, C, W, H)
        :param epochs: int
        """
        print("\nSTART TRAINING\n")
        for epoch in range(epochs):
            self.test(test_loader, epoch)

            self.unet.train()
            self.facenet.train()
            self.discrim.train()

            # train by batches
            for idx, (btch_X, btch_y) in enumerate(train_loader):
                B, D, C, W, H = btch_X.shape
                # btch_X = btch_X.view(B, C * D, W, H)
                btch_X = btch_X.to(device)
                btch_y = btch_y.to(device)

                # MSE + perceptual loss
                self.unet.zero_grad()
                mse = self.loss_mse_vgg(btch_X, btch_y, k_mse, k_vgg)
                mse.backward()
                self.unet_optimizer.step()

                # facenet_backup = deepcopy(self.facenet.state_dict())
                # for i in range(unrolled_iterations):

                # discriminator step
                self.discrim.zero_grad()
                loss_D = k_discrim * self.loss_GAN_discrimator(btch_X, btch_y)
                loss_D.backward()
                self.discrim_optimizer.step()

                # generator step
                self.discrim.zero_grad()
                self.unet.zero_grad()
                loss_G = k_gen * self.loss_GAN_generator(btch_X)
                loss_G.backward()
                self.unet_optimizer.step()

                # FaceNet step
                self.unet.zero_grad()
                self.facenet.zero_grad()
                facenet_loss, _, n_bad = self.loss_facenet(btch_X, btch_y)
                facenet_loss = k_facenet * facenet_loss
                facenet_loss.backward()
                self.facenet_optimizer.step()

                # back-propagate the facenet loss into the U-Net as well
                self.unet.zero_grad()
                self.facenet.zero_grad()
                facenet_back_loss, _, n_bad = self.loss_facenet(btch_X, btch_y)
                facenet_back_loss = k_facenet_back * facenet_back_loss
                facenet_back_loss.backward()
                self.unet_optimizer.step()

                print(f"btch {idx * batch_size} mse={mse.item():.4} "
                      f"GAN(G/D)={loss_G.item():.4}/{loss_D.item():.4} "
                      f"facenet={facenet_loss.item():.4} bad={n_bad / B ** 2:.4}")

                global_step = epoch * len(train_loader.dataset) // batch_size + idx
                self.writer.add_scalar("train bad_percent", n_bad / B ** 2, global_step=global_step)
                self.writer.add_scalar("train loss", mse.item(), global_step=global_step)
                # self.writer.add_scalars("train GAN", {"discrim": loss_D.item(),
                #                                       "gen": loss_G.item()}, global_step=global_step)

            torch.save(self.unet.state_dict(), self.unet_path)
            torch.save(self.discrim.state_dict(), self.discrim_path)
            torch.save(self.facenet.state_dict(), self.facenet_path)
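# A minimal usage sketch for the Trainer above; the loaders and the sequence
# length are assumptions (they must match (B, D, C, W, H)-shaped batches):
#   trainer = Trainer(seq_length=5, color_channels=3)
#   trainer.train(train_loader, test_loader, batch_size=2, epochs=30)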
class Noise2Noise(object):
    """Implementation of Noise2Noise from Lehtinen et al. (2018)."""

    def __init__(self, params, trainable):
        """Initializes model."""
        self.p = params
        self.trainable = trainable
        self._compile()  # initialize the model

    def _compile(self):
        """Compiles model (architecture, loss function, optimizers, etc.)."""
        print('Noise2Noise: Learning Image Restoration without Clean Data (Lehtinen et al., 2018)')

        # Model (3x3 = 9 channels for Monte Carlo since it uses 3 HDR buffers)
        if self.p.noise_type == 'mc':
            self.is_mc = True
            self.model = UNet(in_channels=9)
        else:
            self.is_mc = False
            self.model = UNet(in_channels=3)

        # Set optimizer and loss, if in training mode
        if self.trainable:
            self.optim = Adam(self.model.parameters(),
                              lr=self.p.learning_rate,
                              betas=self.p.adam[:2],
                              eps=self.p.adam[2])

            # Learning rate adjustment
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim,
                                                            patience=self.p.nb_epochs / 4,
                                                            factor=0.5,
                                                            verbose=True)

            # Loss function
            if self.p.loss == 'hdr':
                assert self.is_mc, 'Using HDR loss on non Monte Carlo images'
                self.loss = HDRLoss()
            elif self.p.loss == 'l2':
                self.loss = nn.MSELoss()
            else:
                self.loss = nn.L1Loss()

        # CUDA support
        self.use_cuda = torch.cuda.is_available() and self.p.cuda
        if self.use_cuda:
            self.model = self.model.cuda()
            if self.trainable:
                self.loss = self.loss.cuda()

    def _print_params(self):
        """Formats parameters to print when training."""
        print('Training parameters: ')
        self.p.cuda = self.use_cuda
        param_dict = vars(self.p)
        pretty = lambda x: x.replace('_', ' ').capitalize()
        print('\n'.join(' {} = {}'.format(pretty(k), str(v)) for k, v in param_dict.items()))
        print()

    def save_model(self, epoch, stats, first=False):
        """Saves model to files; can be overwritten at every epoch to save disk space."""
        # Create directory for model checkpoints, if nonexistent
        if first:
            if self.p.clean_targets:
                ckpt_dir_name = f'{datetime.now():{self.p.noise_type}-clean-%H%M}'
            else:
                ckpt_dir_name = f'{datetime.now():{self.p.noise_type}-%H%M}'
            if self.p.ckpt_overwrite:
                if self.p.clean_targets:
                    ckpt_dir_name = f'{self.p.noise_type}-clean'
                else:
                    ckpt_dir_name = self.p.noise_type

            self.ckpt_dir = os.path.join(self.p.ckpt_save_path, ckpt_dir_name)
            if not os.path.isdir(self.p.ckpt_save_path):
                os.mkdir(self.p.ckpt_save_path)
            if not os.path.isdir(self.ckpt_dir):
                os.mkdir(self.ckpt_dir)

        # Save checkpoint dictionary
        if self.p.ckpt_overwrite:
            fname_unet = '{}/n2n-{}.pt'.format(self.ckpt_dir, self.p.noise_type)
        else:
            valid_loss = stats['valid_loss'][epoch]
            fname_unet = '{}/n2n-epoch{}-{:>1.5f}.pt'.format(self.ckpt_dir, epoch + 1, valid_loss)
        print('Saving checkpoint to: {}\n'.format(fname_unet))
        torch.save(self.model.state_dict(), fname_unet)

        # Save stats to JSON
        fname_dict = '{}/n2n-stats.json'.format(self.ckpt_dir)
        with open(fname_dict, 'w') as fp:
            json.dump(stats, fp, indent=2)

    def load_model(self, ckpt_fname):
        """Loads model from checkpoint file."""
        print('Loading checkpoint from: {}'.format(ckpt_fname))
        if self.use_cuda:
            self.model.load_state_dict(torch.load(ckpt_fname))
        else:
            self.model.load_state_dict(torch.load(ckpt_fname, map_location='cpu'))

    def _on_epoch_end(self, stats, train_loss, epoch, epoch_start, valid_loader):
        """Tracks and saves stats after each epoch."""
        # Evaluate model on validation set
        print('\rTesting model on validation set... ', end='')
        epoch_time = time_elapsed_since(epoch_start)[0]
        valid_loss, valid_time, valid_psnr = self.eval(valid_loader)
        show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr)

        # Decrease learning rate if plateau
        self.scheduler.step(valid_loss)

        # Save checkpoint
        stats['train_loss'].append(train_loss)
        stats['valid_loss'].append(valid_loss)
        stats['valid_psnr'].append(valid_psnr)
        self.save_model(epoch, stats, epoch == 0)

    def test(self, test_loader, show=1):
        """Evaluates denoiser on test set."""
        self.model.train(False)

        source_imgs = []
        denoised_imgs = []
        clean_imgs = []

        # Create directory for denoised images
        denoised_dir = os.path.dirname(self.p.data)
        save_path = os.path.join(denoised_dir, 'denoised')
        if not os.path.isdir(save_path):
            os.mkdir(save_path)

        for batch_idx, (source, target) in enumerate(test_loader):
            # Only do first <show> images
            if show == 0 or batch_idx >= show:
                break

            source_imgs.append(source)
            clean_imgs.append(target)

            if self.use_cuda:
                source = source.cuda()

            # Denoise
            denoised_img = self.model(source).detach()
            denoised_imgs.append(denoised_img)

        # Squeeze tensors
        source_imgs = [t.squeeze(0) for t in source_imgs]
        denoised_imgs = [t.squeeze(0) for t in denoised_imgs]
        clean_imgs = [t.squeeze(0) for t in clean_imgs]

        # Create montage and save images
        print('Saving images and montages to: {}'.format(save_path))
        for i in range(len(source_imgs)):
            img_name = test_loader.dataset.imgs[i]
            create_montage(img_name, self.p.noise_type, save_path,
                           source_imgs[i], denoised_imgs[i], clean_imgs[i], show)

    def eval(self, valid_loader):
        """Evaluates denoiser on validation set."""
        self.model.train(False)

        valid_start = datetime.now()
        loss_meter = AvgMeter()
        psnr_meter = AvgMeter()

        for batch_idx, (source, target) in enumerate(valid_loader):
            if self.use_cuda:
                source = source.cuda()
                target = target.cuda()

            # Denoise
            source_denoised = self.model(source)

            # Update loss
            loss = self.loss(source_denoised, target)
            loss_meter.update(loss.item())

            # Compute PSNR
            if self.is_mc:
                source_denoised = reinhard_tonemap(source_denoised)
            # TODO: Find a way to offload to GPU, and deal with uneven batch sizes
            for i in range(self.p.batch_size):
                source_denoised = source_denoised.cpu()
                target = target.cpu()
                psnr_meter.update(psnr(source_denoised[i], target[i]).item())

        valid_loss = loss_meter.avg
        valid_time = time_elapsed_since(valid_start)[0]
        psnr_avg = psnr_meter.avg

        return valid_loss, valid_time, psnr_avg

    def train(self, train_loader, valid_loader):
        """Trains denoiser on training set."""
        self.model.train(True)
        self._print_params()

        num_batches = len(train_loader)
        assert num_batches % self.p.report_interval == 0, \
            'Report interval must divide total number of batches'

        # Dictionaries of tracked stats
        stats = {'noise_type': self.p.noise_type,
                 'noise_param': self.p.noise_param,
                 'train_loss': [],
                 'valid_loss': [],
                 'valid_psnr': []}

        # Main training loop
        train_start = datetime.now()
        for epoch in range(self.p.nb_epochs):
            print('EPOCH {:d} / {:d}'.format(epoch + 1, self.p.nb_epochs))

            # Some stats trackers
            epoch_start = datetime.now()
            train_loss_meter = AvgMeter()
            loss_meter = AvgMeter()
            time_meter = AvgMeter()

            # Minibatch SGD
            for batch_idx, (source, target) in enumerate(train_loader):
                batch_start = datetime.now()
                progress_bar(batch_idx, num_batches, self.p.report_interval, loss_meter.val)

                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise image
                source_denoised = self.model(source)

                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Zero gradients, perform a backward pass, and update the weights
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                # Report/update statistics
                time_meter.update(time_elapsed_since(batch_start)[1])
                if (batch_idx + 1) % self.p.report_interval == 0 and batch_idx:
                    show_on_report(batch_idx, num_batches, loss_meter.avg, time_meter.avg)
                    train_loss_meter.update(loss_meter.avg)
                    loss_meter.reset()
                    time_meter.reset()

            # Epoch end: save and reset trackers
            self._on_epoch_end(stats, train_loss_meter.avg, epoch, epoch_start, valid_loader)
            train_loss_meter.reset()

        train_elapsed = time_elapsed_since(train_start)[0]
        print('Training done! Total elapsed time: {}\n'.format(train_elapsed))
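# `AvgMeter` is used throughout Noise2Noise above but not defined in this
# snippet. A minimal running-average tracker sketch; the class name and the
# val/sum/count/avg fields are inferred from how it is used (an assumption):
class AvgMeter(object):
    """Tracks the latest value, running sum, count, and average."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count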
def val(cfg, model=None):
    if model:
        test_folder = cfg.test_folder
        print("The test folder", test_folder)
    else:
        model_path = '/project/bo/exp_data/FFP/%s_%d/' % (cfg.dataset_type, cfg.version)
        ckpt_path = model_path + "model-%d.pth" % cfg.ckpt_step
        if cfg.dataset_augment_test_type != "frames/testing/" and "venue" in cfg.dataset_type:
            rain_type = str(cfg.dataset_augment_test_type.strip().split('_')[0])
            brightness = int(cfg.dataset_augment_test_type.strip().split('_')[-1]) / 10
            data_dir = cfg.dataset_path + "Avenue/frames/%s_testing/bright_%.2f/" % (
                rain_type, brightness)
            if not os.path.exists(data_dir):
                aug_data.save_avenue_rain_or_bright(cfg.dataset_path,
                                                    rain_type,
                                                    True,
                                                    "testing",
                                                    bright_space=brightness)
        else:
            data_dir = cfg.dataset_path + '/%s/%s/' % ("Avenue", cfg.dataset_augment_test_type)
            rain_type = "original"
            brightness = 1.0
        test_folder = data_dir

        orig_stdout = sys.stdout
        f = open(os.path.join(model_path,
                              'output_rain_%s_bright_%s.txt' % (rain_type, brightness)), 'w')
        sys.stdout = f

    cfg.gt = np.load('/project/bo/anomaly_data/Avenue/gt_label.npy', allow_pickle=True)

    if model:
        # This is for testing during training.
        generator = model
        generator.eval()
    else:
        generator = UNet(input_channels=12, output_channel=3).cuda().eval()
        generator.load_state_dict(torch.load(ckpt_path)['net_g'])
        # generator.load_state_dict(torch.load('weights/' + cfg.trained_model)['net_g'])
        print("The pre-trained generator has been loaded from", ckpt_path)

    videos = {}
    videos, video_string = input_utils.setup(test_folder, videos)

    fps = 0
    psnr_group = []

    if not model:
        if cfg.show_curve:
            fig = plt.figure("Image")
            manager = plt.get_current_fig_manager()
            manager.window.setGeometry(550, 200, 600, 500)
            # This works for the QT backend; for other backends, see:
            # https://stackoverflow.com/questions/7449585/how-do-you-set-the-absolute-position-of-figure-windows-with-matplotlib
            plt.xlabel('frames')
            plt.ylabel('psnr')
            plt.title('psnr curve')
            plt.grid(ls='--')

            cv2.namedWindow('target frames', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('target frames', 384, 384)
            cv2.moveWindow("target frames", 100, 100)

        if cfg.show_heatmap:
            cv2.namedWindow('difference map', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('difference map', 384, 384)
            cv2.moveWindow('difference map', 100, 550)

    with torch.no_grad():
        for i, folder in enumerate(video_string):
            # build the dataset first: the curve axes below need len(dataset)
            dataset = input_utils.test_dataset(videos[folder]['frame'], [imh, imw])

            if not model:
                name = folder.split('/')[-1]
                fourcc = cv2.VideoWriter_fourcc('X', 'V', 'I', 'D')

                if cfg.show_curve:
                    video_writer = cv2.VideoWriter(f'results/{name}_video.avi', fourcc, 30,
                                                   cfg.img_size)
                    curve_writer = cv2.VideoWriter(f'results/{name}_curve.avi', fourcc, 30,
                                                   (600, 430))

                    js = []
                    plt.clf()
                    ax = plt.axes(xlim=(0, len(dataset)), ylim=(30, 45))
                    line, = ax.plot([], [], '-b')

                if cfg.show_heatmap:
                    heatmap_writer = cv2.VideoWriter(f'results/{name}_heatmap.avi', fourcc, 30,
                                                     cfg.img_size)

            print("Start video %s with %d frames...................." % (folder, len(dataset)))

            psnrs = []
            for j, clip in enumerate(dataset):
                input_np = clip[0:12, :, :]
                target_np = clip[12:15, :, :]
                input_frames = torch.from_numpy(input_np).unsqueeze(0).cuda()
                target_frame = torch.from_numpy(target_np).unsqueeze(0).cuda()

                G_frame = generator(input_frames)
                test_psnr = psnr_error(G_frame, target_frame).cpu().detach().numpy()
                psnrs.append(float(test_psnr))

                if not model:
                    if cfg.show_curve:
                        cv2_frame = ((target_np + 1) * 127.5).transpose(1, 2, 0).astype('uint8')
                        js.append(j)
                        # Keep the existing figure and only update the X- and Y-axis
                        # data, which is faster, but still not perfect.
                        line.set_xdata(js)
                        line.set_ydata(psnrs)
                        plt.pause(0.001)  # show curve

                        cv2.imshow('target frames', cv2_frame)
                        cv2.waitKey(1)  # show video
                        video_writer.write(cv2_frame)  # Write original video frames.

                        buffer = io.BytesIO()  # Write curve frames from buffer.
                        fig.canvas.print_png(buffer)
                        buffer.write(buffer.getvalue())
                        curve_img = np.array(Image.open(buffer))[..., (2, 1, 0)]
                        curve_writer.write(curve_img)

                    if cfg.show_heatmap:
                        diff_map = torch.sum(torch.abs(G_frame - target_frame).squeeze(), 0)
                        diff_map -= diff_map.min()  # Normalize to 0 ~ 255.
                        diff_map /= diff_map.max()
                        diff_map *= 255
                        diff_map = diff_map.cpu().detach().numpy().astype('uint8')
                        heat_map = cv2.applyColorMap(diff_map, cv2.COLORMAP_JET)

                        cv2.imshow('difference map', heat_map)
                        cv2.waitKey(1)
                        heatmap_writer.write(heat_map)  # Write heatmap frames.

                torch.cuda.synchronize()
                end = time.time()
                if j > 1:
                    # Compute fps from the time of one complete iteration; this is more accurate.
                    fps = 1 / (end - temp)
                temp = end
                # print(f'\rDetecting: [{i + 1:02d}] {j + 1}/{len(dataset)}, {fps:.2f} fps.', end='')

            psnr_group.append(np.array(psnrs))

            if not model:
                if cfg.show_curve:
                    video_writer.release()
                    curve_writer.release()
                if cfg.show_heatmap:
                    heatmap_writer.release()

    print('\nAll frames were detected, begin to compute AUC.')
    auc = give_score(psnr_group, cfg.gt)

    if not model:
        sys.stdout = orig_stdout
        f.close()
    return auc
                      default=1,
                      type='int',
                      help='select model (int): (1-Unet, )')
    parser.add_option('-c', '--load',
                      dest='load',
                      default=False,
                      help='load file model')
    (options, args) = parser.parse_args()

    if options.model == 1:
        net = UNet(3, 1)

    if options.load:
        net.load_state_dict(torch.load(options.load))
        print('Model loaded from {}'.format(options.load))

    if options.gpu:
        net.cuda()
        cudnn.benchmark = True

    try:
        train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        print('Saved interrupt')
# Prediction and saving of the results
TEST_DATA_PATH = '/home/kovalexal/Spaces/learning/made/made_cv/competitions/facial_points/data/test/'
test_dataset = ThousandLandmarksDataset(TEST_DATA_PATH, train_transforms, split='test')

# Batch size
TEST_BATCH_SIZE = 2
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, num_workers=8,
                             pin_memory=True, shuffle=False, drop_last=False)

with open('{}_best.pth'.format(MODEL_NAME), 'rb') as fp:
    best_state_dict = torch.load(fp, map_location="cpu")
    model.load_state_dict(best_state_dict)

test_predictions = predict(model, test_dataloader, device)

with open('{}_test_predictions.pkl'.format(MODEL_NAME), 'wb') as fp:
    pickle.dump({
        'image_names': test_dataset.image_names,
        'landmarks': test_predictions
    }, fp)

create_submission(TEST_DATA_PATH, test_predictions, '{}_submit.csv'.format(MODEL_NAME))
def train(input_data_type, grade, seg_type, num_classes, batch_size, epochs,
          use_gpu, learning_rate, w_decay, pre_trained=False):
    logger.info('Start training using {} modal.'.format(input_data_type))

    model = UNet(4, 4, residual=True, expansion=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=w_decay)

    if pre_trained:
        # pre_trained_path and device are assumed to be defined at module level
        checkpoint = torch.load(pre_trained_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    if use_gpu:
        ts = time.time()
        model.to(device)
        print("Finish cuda loading, time elapsed {}".format(time.time() - ts))

    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size,
                                    gamma=gamma)  # decay LR by a factor of `gamma` every `step_size` epochs

    data_set, data_loader = get_dataset_dataloader(input_data_type, seg_type, batch_size, grade=grade)

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0

    epoch_loss = np.zeros((2, epochs))
    epoch_acc = np.zeros((2, epochs))
    epoch_class_acc = np.zeros((2, epochs))
    epoch_mean_iou = np.zeros((2, epochs))
    evaluator = Evaluator(num_classes)

    def term_int_handler(signal_num, frame):
        # save the metrics and the best weights seen so far before terminating
        np.save(os.path.join(score_dir, 'epoch_accuracy'), epoch_acc)
        np.save(os.path.join(score_dir, 'epoch_mean_iou'), epoch_mean_iou)
        np.save(os.path.join(score_dir, 'epoch_loss'), epoch_loss)

        model.load_state_dict(best_model_wts)
        logger.info('Got terminated and saved model.state_dict')
        torch.save(model.state_dict(), os.path.join(score_dir, 'terminated_model.pt'))
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }, os.path.join(score_dir, 'terminated_model.tar'))
        quit()

    signal.signal(signal.SIGINT, term_int_handler)
    signal.signal(signal.SIGTERM, term_int_handler)

    for epoch in range(epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, epochs))
        logger.info('-' * 28)

        for phase_ind, phase in enumerate(['train', 'val']):
            if phase == 'train':
                model.train()
            else:
                model.eval()
            logger.info(phase)

            evaluator.reset()
            running_loss = 0.0

            for batch_ind, batch in enumerate(data_loader[phase]):
                imgs, targets = batch
                imgs = imgs.to(device)
                targets = targets.to(device)

                # zero the learnable parameters' gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(imgs)
                    loss = criterion(outputs, targets)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                preds = torch.argmax(F.softmax(outputs, dim=1), dim=1, keepdim=True)
                # use .item() so the running loss does not keep the autograd graph alive
                running_loss += loss.item() * imgs.size(0)
                logger.debug('Batch {} running loss: {:.4f}'.format(batch_ind, running_loss))

                # evaluate iou and pixelwise accuracy with the evaluator
                preds = torch.squeeze(preds, dim=1)
                preds = preds.cpu().numpy()
                targets = targets.cpu().numpy()
                evaluator.add_batch(targets, preds)

            epoch_loss[phase_ind, epoch] = running_loss / len(data_set[phase])
            epoch_acc[phase_ind, epoch] = evaluator.Pixel_Accuracy()
            epoch_class_acc[phase_ind, epoch] = evaluator.Pixel_Accuracy_Class()
            epoch_mean_iou[phase_ind, epoch] = evaluator.Mean_Intersection_over_Union()

            logger.info('{} loss: {:.4f}, acc: {:.4f}, class acc: {:.4f}, mean iou: {:.6f}'.format(
                phase, epoch_loss[phase_ind, epoch], epoch_acc[phase_ind, epoch],
                epoch_class_acc[phase_ind, epoch], epoch_mean_iou[phase_ind, epoch]))

            if phase == 'val' and epoch_mean_iou[phase_ind, epoch] > best_iou:
                best_iou = epoch_mean_iou[phase_ind, epoch]
                best_model_wts = copy.deepcopy(model.state_dict())

            if phase == 'val' and (epoch + 1) % 10 == 0:
                logger.info('Saved model.state_dict in epoch {}'.format(epoch + 1))
                torch.save(model.state_dict(),
                           os.path.join(score_dir, 'epoch{}_model.pt'.format(epoch + 1)))
        print()

    time_elapsed = time.time() - since
    logger.info('Training completed in {}m {}s'.format(int(time_elapsed / 60), int(time_elapsed) % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)

    # save numpy results
    np.save(os.path.join(score_dir, 'epoch_accuracy'), epoch_acc)
    np.save(os.path.join(score_dir, 'epoch_mean_iou'), epoch_mean_iou)
    np.save(os.path.join(score_dir, 'epoch_loss'), epoch_loss)

    return model, optimizer
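# Hedged sketch: the `Evaluator` used above is not shown here. A common
# confusion-matrix implementation matching the method names called in the loop
# looks roughly like this; the project's actual version may differ in details
# such as NaN handling or frequency-weighted variants.
import numpy as np

class Evaluator:
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((num_class, num_class))

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class, self.num_class))

    def add_batch(self, gt_image, pre_image):
        # accumulate a confusion matrix over a batch of integer label maps
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype(int) + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        self.confusion_matrix += count.reshape(self.num_class, self.num_class)

    def Pixel_Accuracy(self):
        return np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()

    def Pixel_Accuracy_Class(self):
        acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        return np.nanmean(acc)

    def Mean_Intersection_over_Union(self):
        iou = np.diag(self.confusion_matrix) / (
            self.confusion_matrix.sum(axis=1) + self.confusion_matrix.sum(axis=0)
            - np.diag(self.confusion_matrix))
        return np.nanmean(iou)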
                                  batch_size=cfg.bs,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True,
                                  drop_last=False)

    if cfg.model == 'unet':
        model = UNet(input_channels=3).cuda()
        model.apply(model.weights_init_normal)
    else:
        model = DLASeg(cfg).cuda()
    model.train()

    if cfg.resume:
        resume_epoch = int(cfg.resume.split('.')[0].split('_')[1]) + 1
        model.load_state_dict(torch.load('weights/' + cfg.resume), strict=True)
        print(f'Resume training with \'{cfg.resume}\'.')
    else:
        resume_epoch = 0
        print('Training with ImageNet pre-trained weights.')

    criterion = nn.CrossEntropyLoss(ignore_index=255).cuda()

    if cfg.optim == 'sgd':
        optimizer = torch.optim.SGD(model.optim_parameters(), cfg.lr, cfg.momentum,
                                    weight_decay=cfg.decay)
    elif cfg.optim == 'radam':
        optimizer = RAdam(model.optim_parameters(), lr=cfg.lr, weight_decay=cfg.decay)
        model = UNet(input_channels=NUM_INPUT_CHANNELS,
                     output_channels=NUM_OUTPUT_CHANNELS)
    elif args.model == "segnet":
        model = SegNet(input_channels=NUM_INPUT_CHANNELS,
                       output_channels=NUM_OUTPUT_CHANNELS)
    else:
        model = PSPNet(
            layers=50,
            bins=(1, 2, 3, 6),
            dropout=0.1,
            classes=NUM_OUTPUT_CHANNELS,
            use_ppm=True,
            pretrained=True,
        )

    # class_weights = 1.0 / train_dataset.get_class_probability()
    # criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    criterion = torch.nn.CrossEntropyLoss()

    if CUDA:
        model = model.cuda(device=GPU_ID)
        # class_weights = class_weights.cuda(GPU_ID)
        criterion = criterion.cuda(device=GPU_ID)

    if args.checkpoint:
        model.load_state_dict(torch.load(args.checkpoint))

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train()
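# Hedged sketch: the commented-out weighting above balances classes by inverse
# pixel frequency, which helps when background pixels dominate. Assuming
# train_dataset.get_class_probability() returns a FloatTensor of per-class
# frequencies summing to 1, re-enabling it could look like this (illustrative only):
#
# class_weights = 1.0 / train_dataset.get_class_probability()
# class_weights = class_weights / class_weights.sum()  # keep the loss scale stable
# if CUDA:
#     class_weights = class_weights.cuda(GPU_ID)
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights)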
def train(args):
    '''
    -------------------------Hyperparameters--------------------------
    '''
    EPOCHS = args.epochs
    START = 0  # could enter a checkpoint start epoch
    ITER = args.iterations  # per epoch
    LR = args.lr
    MOM = args.momentum
    # LOGInterval = args.log_interval
    BATCHSIZE = args.batch_size
    TEST_BATCHSIZE = args.test_batch_size
    NUMBER_OF_WORKERS = args.workers
    DATA_FOLDER = args.data
    TESTSET_FOLDER = args.testset
    ROOT = args.run
    WEIGHT_DIR = os.path.join(ROOT, "weights")
    CUSTOM_LOG_DIR = os.path.join(ROOT, "additionalLOGS")
    CHECKPOINT = os.path.join(WEIGHT_DIR, str(args.model) + str(args.name) + ".pt")
    useTensorboard = args.tb

    # check existence of the data folder
    if not os.path.isdir(DATA_FOLDER):
        print("data folder nonexistent or in wrong layout.\n\t", DATA_FOLDER)
        exit(0)
    # check existence of the testset folder
    if TESTSET_FOLDER is not None and not os.path.isdir(TESTSET_FOLDER):
        print("testset folder nonexistent or in wrong layout.\n\t", TESTSET_FOLDER)
        exit(0)

    '''
    ---------------------------preparations---------------------------
    '''
    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print("using device: ", str(device))

    # loading the validation samples to make online evaluations
    path_to_valX = args.valX
    path_to_valY = args.valY
    valX = None
    valY = None
    if path_to_valX is not None and path_to_valY is not None \
            and os.path.exists(path_to_valX) and os.path.exists(path_to_valY) \
            and os.path.isfile(path_to_valX) and os.path.isfile(path_to_valY):
        with torch.no_grad():
            valX, valY = torch.load(path_to_valX, map_location='cpu'), \
                         torch.load(path_to_valY, map_location='cpu')

    '''
    ---------------------------loading dataset and normalizing---------------------------
    '''
    # DataLoader parameters
    train_params = {
        'batch_size': BATCHSIZE,
        'shuffle': True,
        'num_workers': NUMBER_OF_WORKERS
    }
    test_params = {
        'batch_size': TEST_BATCHSIZE,
        'shuffle': False,
        'num_workers': NUMBER_OF_WORKERS
    }

    # create folders for the weights and custom logs
    if not os.path.isdir(WEIGHT_DIR):
        os.makedirs(WEIGHT_DIR)
    if not os.path.isdir(CUSTOM_LOG_DIR):
        os.makedirs(CUSTOM_LOG_DIR)

    labelsNorm = None
    # NORMLABEL: normalizing with a trainingset-wide mean and std
    mean = None
    std = None
    if args.norm:
        print('computing mean and std over trainingset')
        # compute mean and std over all ground truths in the dataset to tackle
        # the problem of numerical insignificance
        mean, std = computeMeanStdOverDataset('CONRADataset', DATA_FOLDER, train_params, device)
        print('\niodine (mean/std): {}\t{}'.format(mean[0], std[0]))
        print('water (mean/std): {}\t{}\n'.format(mean[1], std[1]))
        labelsNorm = transforms.Normalize(mean=[0, 0], std=std)
        m2, s2 = computeMeanStdOverDataset('CONRADataset', DATA_FOLDER, train_params, device,
                                           transform=labelsNorm)
        print("new mean and std are:")
        print('\nnew iodine (mean/std): {}\t{}'.format(m2[0], s2[0]))
        print('new water (mean/std): {}\t{}\n'.format(m2[1], s2[1]))

    traindata = CONRADataset(DATA_FOLDER, True, device=device, precompute=True, transform=labelsNorm)
    if TESTSET_FOLDER is not None:
        testdata = CONRADataset(TESTSET_FOLDER, False, device=device, precompute=True, transform=labelsNorm)
    else:
        testdata = CONRADataset(DATA_FOLDER, False, device=device, precompute=True, transform=labelsNorm)

    trainingset = DataLoader(traindata, **train_params)
    testset = DataLoader(testdata, **test_params)

    '''
    ----------------loading model and checkpoints---------------------
    '''
    if args.model == "unet":
        m = UNet(2, 2).to(device)
        print("using the U-Net architecture with {} trainable params; Good Luck!".format(
            count_trainables(m)))
    else:
        m = simpleConvNet(2, 2).to(device)

    o = optim.SGD(m.parameters(), lr=LR, momentum=MOM)
    loss_fn = nn.MSELoss()

    test_loss = None
    train_loss = None

    if len(os.listdir(WEIGHT_DIR)) != 0:
        checkpoints = os.listdir(WEIGHT_DIR)
        checkDir = {}
        latestCheckpoint = 0
        for i, checkpoint in enumerate(checkpoints):
            stepOfCheckpoint = int(checkpoint.split(str(args.model) + str(args.name))[-1].split('.pt')[0])
            checkDir[stepOfCheckpoint] = checkpoint
            latestCheckpoint = max(latestCheckpoint, stepOfCheckpoint)
            print("[{}] {}".format(stepOfCheckpoint, checkpoint))
        # on a development machine, prompt for input; otherwise just take the most recent one
        if 'faui' in os.uname()[1]:
            toUse = int(input("select checkpoint to use: "))
        else:
            toUse = latestCheckpoint
        checkpoint = torch.load(os.path.join(WEIGHT_DIR, checkDir[toUse]))
        m.load_state_dict(checkpoint['model_state_dict'])
        m.to(device)  # pushing weights to gpu
        o.load_state_dict(checkpoint['optimizer_state_dict'])
        train_loss = checkpoint['train_loss']
        test_loss = checkpoint['test_loss']
        START = checkpoint['epoch']
        print("using checkpoint {}:\n\tloss(train/test): {}/{}".format(toUse, train_loss, test_loss))
    else:
        print("starting from scratch")

    '''
    -----------------------------training-----------------------------
    '''
    global_step = 0

    # calculating initial loss
    if test_loss is None or train_loss is None:
        print("calculating initial loss")
        m.eval()
        print("testset...")
        test_loss = calculate_loss(set=testset, loss_fn=loss_fn, length_set=len(testdata),
                                   dev=device, model=m)
        print("trainset...")
        train_loss = calculate_loss(set=trainingset, loss_fn=loss_fn, length_set=len(traindata),
                                    dev=device, model=m)

    # SSIM and R value
    R = []
    SSIM = []
    performanceFLE = os.path.join(CUSTOM_LOG_DIR, "performance.csv")
    with open(performanceFLE, 'w+') as f:
        f.write("step, SSIMiodine, SSIMwater, Riodine, Rwater, train_loss, test_loss\n")
    print("computing ssim and r coefficients to: {}".format(performanceFLE))

    # printing runtime information
    print("starting training at epoch {} for {} epochs, {} iterations each\n\t{} total".format(
        START, EPOCHS, ITER, EPOCHS * ITER))
    print("\tbatchsize: {}\n\tloss: {}\n\twill save results to \"{}\"".format(
        BATCHSIZE, train_loss, CHECKPOINT))
    print("\tmodel: {}\n\tlearningrate: {}\n\tmomentum: {}\n\tnorming output space: {}".format(
        args.model, LR, MOM, args.norm))

    # start actual training loops
    for e in range(START, START + EPOCHS):
        # iterations will not be interrupted by validation and metrics
        for i in range(ITER):
            global_step = (e * ITER) + i

            # training
            m.train()
            iteration_loss = 0
            for x, y in tqdm(trainingset):
                x, y = x.to(device=device, dtype=torch.float), y.to(device=device, dtype=torch.float)
                pred = m(x)
                loss = loss_fn(pred, y)
                iteration_loss += loss.item()
                o.zero_grad()
                loss.backward()
                o.step()
            print("\niteration {}: --accumulated loss {}".format(global_step, iteration_loss))

        # validation, saving and logging
        print("\nvalidating")
        m.eval()  # disable dropout, batchnorm, etc.
        print("testset...")
        test_loss = calculate_loss(set=testset, loss_fn=loss_fn, length_set=len(testdata),
                                   dev=device, model=m)
        print("trainset...")
        train_loss = calculate_loss(set=trainingset, loss_fn=loss_fn, length_set=len(traindata),
                                    dev=device, model=m)

        print("calculating SSIM and R coefficients")
        currSSIM, currR = performance(set=testset, dev=device, model=m, bs=TEST_BATCHSIZE)
        print("SSIM (iod/water): {}/{}\nR (iod/water): {}/{}".format(
            currSSIM[0], currSSIM[1], currR[0], currR[1]))
        with open(performanceFLE, 'a') as f:
            newCSVline = "{}, {}, {}, {}, {}, {}, {}\n".format(
                global_step, currSSIM[0], currSSIM[1], currR[0], currR[1], train_loss, test_loss)
            f.write(newCSVline)
            print("wrote new line to csv:\n\t{}".format(newCSVline))

        '''
        if valX and valY were set in preparations, use them to perform analytics.
        if not, use the first sample from the testset to perform analytics
        '''
        with torch.no_grad():
            truth, pred = None, None
            IMAGE_LOG_DIR = os.path.join(CUSTOM_LOG_DIR, str(global_step))
            if not os.path.isdir(IMAGE_LOG_DIR):
                os.makedirs(IMAGE_LOG_DIR)
            if valX is not None and valY is not None:
                batched = np.zeros((BATCHSIZE, *valX.numpy().shape))
                batched[0] = valX.numpy()
                batched = torch.from_numpy(batched).to(device=device, dtype=torch.float)
                pred = m(batched)
                pred = pred.cpu().numpy()[0]
                truth = valY.numpy()  # still on cpu
                assert pred.shape == truth.shape
            else:
                for x, y in testset:  # x, y in shape [b, c, h, w], e.g. [2, 2, 480, 620]
                    x, y = x.to(device=device, dtype=torch.float), y.to(device=device, dtype=torch.float)
                    pred = m(x)
                    pred = pred.cpu().numpy()[0]  # taking only the first sample of the batch
                    truth = y.cpu().numpy()[0]  # first projection for evaluation
                    break  # one batch is enough for the analytics sample
            advanvedMetrics(truth, pred, mean, std, global_step, args.norm, IMAGE_LOG_DIR)

        print("logging")
        CHECKPOINT = os.path.join(WEIGHT_DIR, str(args.model) + str(args.name) + str(global_step) + ".pt")
        torch.save({
            'epoch': e + 1,  # end of this epoch; so resume at the next one
            'model_state_dict': m.state_dict(),
            'optimizer_state_dict': o.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss
        }, CHECKPOINT)
        print('\tsaved weights to: ', CHECKPOINT)

        if logger is not None and train_loss is not None:
            logger.add_scalar('test_loss', test_loss, global_step=global_step)
            logger.add_scalar('train_loss', train_loss, global_step=global_step)
            logger.add_image("iodine-prediction", pred[0].reshape(1, 480, 620), global_step=global_step)
            logger.add_image("water-prediction", pred[1].reshape(1, 480, 620), global_step=global_step)
            print("\ttensorboard updated with test/train loss and a sample image")
        elif train_loss is not None:
            print("\tloss of global-step {}: {}".format(global_step, train_loss))
        elif not useTensorboard:
            print("\t(tb-logging disabled) test/train loss: {}/{} ".format(test_loss, train_loss))
        else:
            print("\tno loss accumulated yet")

    # saving final results
    print("saving upon exit")
    torch.save({
        'epoch': EPOCHS,
        'model_state_dict': m.state_dict(),
        'optimizer_state_dict': o.state_dict(),
        'train_loss': train_loss,
        'test_loss': test_loss
    }, CHECKPOINT)
    print('\tsaved progress to: ', CHECKPOINT)
    if logger is not None and train_loss is not None:
        logger.add_scalar('test_loss', test_loss, global_step=global_step)
        logger.add_scalar('train_loss', train_loss, global_step=global_step)
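# Hedged sketch: `calculate_loss` is used above but defined elsewhere. A minimal
# version consistent with its keyword arguments (set, loss_fn, length_set, dev, model)
# could average the per-sample loss over a DataLoader like this. The parameter name
# `set` shadows the Python builtin, but is kept to mirror the call sites.
import torch

def calculate_loss(set, loss_fn, length_set, dev, model):
    total = 0.0
    with torch.no_grad():
        for x, y in set:
            x = x.to(device=dev, dtype=torch.float)
            y = y.to(device=dev, dtype=torch.float)
            # weight each batch by its size so the mean is exact for uneven last batches
            total += loss_fn(model(x), y).item() * x.size(0)
    return total / length_set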
def main():
    parser = argparse.ArgumentParser(description="Train the model")
    parser.add_argument('-trainf', "--train-filepath", type=str, default=None, required=True,
                        help="training dataset filepath.")
    parser.add_argument('-validf', "--val-filepath", type=str, default=None,
                        help="validation dataset filepath.")
    parser.add_argument("--shuffle", action="store_true", default=False,
                        help="Shuffle the dataset")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="load pretrained weights")
    parser.add_argument("--load-model", type=str, default=None,
                        help="load pretrained model, entire model (filepath, default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument('--epochs', type=int, default=30,
                        help='number of epochs to train (default: 30)')
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")
    parser.add_argument('--img-shape', type=str, default="(1,512,512)",
                        help='Image shape (default: "(1,512,512)")')
    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    parser.add_argument('--cuda-non-deterministic', action='store_true', default=False,
                        help="sets flags for non-determinism when using CUDA (potentially faster)")
    parser.add_argument('-lr', type=float, default=0.0005,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed (numpy, and cuda if GPU is used).')
    parser.add_argument('--log-dir', type=str, default=None,
                        help='Save the results/model weights/logs under the directory.')
    args = parser.parse_args()

    # TODO: support image reshape
    img_shape = tuple(map(int, args.img_shape.strip()[1:-1].split(",")))

    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        best_model_path = os.path.join(args.log_dir, "model_weights.pth")
    else:
        best_model_path = None

    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.cuda >= 0:
            if args.cuda_non_deterministic:
                printBlue("Warning: using CUDA non-deterministic. Could be faster, "
                          "but results might not be reproducible.")
            else:
                printBlue("Using CUDA deterministic. Using --cuda-non-deterministic "
                          "might accelerate the training a bit.")
            # Make CuDNN deterministic
            torch.backends.cudnn.deterministic = not args.cuda_non_deterministic
            # torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

    # TODO [OPT] enable multi-GPUs ?
    # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    device = torch.device("cuda:{}".format(args.cuda)
                          if torch.cuda.is_available() and (args.cuda >= 0) else "cpu")

    # ================= Build dataloader =================
    # transform_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                            std=[0.5, 0.5, 0.5])
    transform_normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    # Warning: DO NOT use geometry transforms here (do it in the dataloader instead)
    data_transform = transforms.Compose([
        # transforms.ToPILImage(mode='F'),  # mode='F' for one-channel image
        # transforms.Resize((256, 256))  # NO
        # transforms.RandomResizedCrop(256),  # NO
        # transforms.RandomHorizontalFlip(p=0.5),  # NO
        # WARNING, ISSUE: transforms.ColorJitter doesn't work with ToPILImage(mode='F').
        # Custom data augmentation functions are needed: TODO: DONE.
        # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        # Use OpenCVRotation, OpenCVXXX, ... (our implementation)
        # OpenCVRotation((-10, 10)),  # angles (in degrees)
        transforms.ToTensor(),  # already done in the dataloader
        transform_normalize
    ])

    geo_transform = GeoCompose([
        OpenCVRotation(angles=(-10, 10),
                       scales=(0.9, 1.1),
                       centers=(-0.05, 0.05)),
        # TODO add more data augmentation here
    ])

    def worker_init_fn(worker_id):
        # WARNING: the spawn start method is used, so worker_init_fn cannot be an
        # unpicklable object, e.g. a lambda function.
        # Work-around for issue #5059: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()

    data_loader_train = {'batch_size': args.batch_size,
                         'shuffle': args.shuffle,
                         'num_workers': args.num_cpu,
                         # 'sampler': balanced_sampler,
                         'drop_last': True,  # for GAN-like
                         'pin_memory': False,
                         'worker_init_fn': worker_init_fn,
                         }

    data_loader_valid = {'batch_size': args.batch_size,
                         'shuffle': False,
                         'num_workers': args.num_cpu,
                         'drop_last': False,
                         'pin_memory': False,
                         }

    train_set = LiTSDataset(args.train_filepath,
                            dtype=np.float32,
                            geometry_transform=geo_transform,  # TODO enable data augmentation
                            pixelwise_transform=data_transform,
                            )
    valid_set = LiTSDataset(args.val_filepath,
                            dtype=np.float32,
                            pixelwise_transform=data_transform,
                            )

    dataloader_train = torch.utils.data.DataLoader(train_set, **data_loader_train)
    dataloader_valid = torch.utils.data.DataLoader(valid_set, **data_loader_valid)

    # =================== Build model ===================
    # TODO: control the model by bash command
    if args.load_model:
        # load the entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")
    else:
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=32,  # 64
                     inc_rate=2,
                     kernel_size=5,  # 3
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        if args.load_weights:
            printYellow(f"Loading pretrained weights from: {args.load_weights}...")
            model.load_state_dict(torch.load(args.load_weights))
            printYellow("+ Done.")
        # NOTE: restructured so that without --load-weights/--load-model the model
        # starts from a fresh initialization instead of being left undefined.
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.95))  # TODO
    best_valid_loss = float('inf')

    # TODO TODO: add learning rate decay
    for epoch in range(args.epochs):
        for valid_mode, dataloader in enumerate([dataloader_train, dataloader_valid]):
            n_batch_per_epoch = len(dataloader)
            if args.debug:
                n_batch_per_epoch = 1

            # an infinite dataloader allows several updates per iteration (for special models, e.g. GANs)
            dataloader = infinite_dataloader(dataloader)
            if valid_mode:
                printYellow("Switch to validation mode.")
                model.eval()
                prev_grad_mode = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            else:
                model.train()

            st = time.time()
            cum_loss = 0
            for iter_ind in range(n_batch_per_epoch):
                supplement_logs = ""
                # reset cumulated losses at the beginning of each batch
                # loss_manager.reset_losses()  # TODO: use torch.utils.tensorboard !!
                optimizer.zero_grad()

                img, msk = next(dataloader)
                img, msk = img.to(device), msk.to(device)

                # convert dtype and reshape the mask from (N, 1, 512, 512) to (N, 512, 512)
                msk = msk.to(torch.long).squeeze(1)

                msk_pred = model(img)  # shape (N, 3, 512, 512)

                # label_weights is determined according to the liver_ratio & tumor_ratio
                # loss = CrossEntropyLoss(msk_pred, msk, label_weights=[1., 10., 100.], device=device)
                loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 50.], device=device)
                # loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 500.], device=device)

                if not valid_mode:
                    loss.backward()
                    optimizer.step()

                loss = loss.item()  # release the graph
                cum_loss += loss
                if valid_mode:
                    print("\r--------(valid) {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (iter_ind + 1) / n_batch_per_epoch, cum_loss / (iter_ind + 1),
                        time.time() - st, supplement_logs), end="")
                else:
                    print("\rEpoch: {:3}/{} {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (epoch + 1), args.epochs, (iter_ind + 1) / n_batch_per_epoch,
                        cum_loss / (iter_ind + 1), time.time() - st, supplement_logs), end="")
            print()
            if valid_mode:
                torch.set_grad_enabled(prev_grad_mode)

        valid_mean_loss = cum_loss / (iter_ind + 1)  # validation (mean) loss of the current epoch
        if best_model_path and (valid_mean_loss < best_valid_loss):
            printGreen("Valid loss decreases from {:.5f} to {:.5f}, saving best model.".format(
                best_valid_loss, valid_mean_loss))
            best_valid_loss = valid_mean_loss
            # save the entire model (saving only the weights would also work)
            # torch.save(model.state_dict(), best_model_path)
            torch.save(model, best_model_path)

    return best_valid_loss
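# Hedged sketch: `DiceLoss` above is called as a function with logits of shape
# (N, C, H, W), integer masks of shape (N, H, W), and per-class `label_weights`.
# A soft multi-class Dice loss consistent with that interface might look like this;
# the project's exact smoothing and weighting scheme is an assumption.
import torch
import torch.nn.functional as F

def DiceLoss(logits, target, label_weights, device, smooth=1e-5):
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    # one-hot encode the target: (N, H, W) -> (N, C, H, W)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    weights = torch.tensor(label_weights, device=device)

    dims = (0, 2, 3)  # reduce over batch and spatial dims, keep the class dim
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + smooth) / (cardinality + smooth)
    # weighted mean of per-class Dice losses
    return ((1 - dice) * weights).sum() / weights.sum()

# Example call matching the training loop above:
# loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 50.], device=device)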
    device = torch.device(config['device_num'] if torch.cuda.is_available() else 'cpu')

    if config['cont_model_path'] is None:
        # define the network structure -- UNet
        # note: the output size is not always equal to the input size!
        model = UNet(**config['model']).to(device)
    # load an already trained model
    elif os.path.isdir(config['cont_model_path']):
        cont_model_path = config['cont_model_path']
        current_optimizer = config['optimizer']
        with open(os.path.join(cont_model_path, 'config.json'), 'r') as f:
            model_config = json.load(f)
        # Overwrite the model section so the optimizer can still be configured on subsequent runs.
        config['model'] = model_config['model']
        model = UNet(**config['model']).to(device)
        model.load_state_dict(
            torch.load(os.path.join(cont_model_path, 'checkpoint.pt'), map_location=device))
    else:
        raise Exception(
            f"Model to continue training not found: {config['cont_model_path']}.")

    # NOTE: a mask parameter needs to be added when training the partial-conv UNet variant.
    trainNet(model, train_loader, val_loader, val_loader_ttimes, device)
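# Hedged illustration for the note above: a partial-convolution UNet variant
# typically takes a validity mask alongside the image, so the convolutions can
# ignore masked-out (hole) pixels. Hypothetical call shape only; the real
# trainNet/model signatures may differ:
#
#   mask = (img != 0).float()          # 1 = valid pixel, 0 = hole
#   out, updated_mask = model(img, mask)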
    if not (os.path.exists(CHECKPOINT) and os.path.isfile(CHECKPOINT)):
        print("weights in wrong format or nonexistent:\n\t{}".format(CHECKPOINT))
        exit()

    # loading the model
    m = None
    if args.model == "unet":
        m = UNet(2, 2).to(device)
    else:
        m = simpleConvNet(2, 2).to(device)

    print("loading model weights from \"{}\"".format(CHECKPOINT))
    checkpoint = torch.load(CHECKPOINT)
    m.load_state_dict(checkpoint['model_state_dict'])
    m.to(device)  # pushing weights to gpu
    train_loss = checkpoint['train_loss']
    test_loss = checkpoint['test_loss']
    START = checkpoint['epoch']

    scans = [
        os.path.join(root_dir, i)
        for i in os.listdir(os.path.abspath(root_dir))
        if os.path.isdir(os.path.join(os.path.abspath(root_dir), i)) and "_" in i
    ]
    if len(scans) == 0:
        print("no scan data found (folder name must be in the format mmddhhmmss_x, "
              "with x being the serial number)")
        exit()
def val(cfg, model=None):
    if model:
        # This is for testing during training.
        generator = model
        generator.eval()
    else:
        generator = UNet(input_channels=12, output_channel=3).cuda().eval()
        generator.load_state_dict(torch.load('weights/' + cfg.trained_model)['net_g'])
        print(f'The pre-trained generator has been loaded from \'weights/{cfg.trained_model}\'.\n')

    # video_folders = os.listdir(cfg.test_data)
    # video_folders.sort()
    # video_folders = [os.path.join(cfg.test_data, aa) for aa in video_folders]
    with open(os.path.join(cfg.data_root, 'val_split_with_obj.txt')) as f:
        all_video_names = f.read().splitlines()
    video_folders = [os.path.join(cfg.data_root, 'frames', vid, 'images') for vid in all_video_names]

    fps = 0
    psnr_group = []

    if not model:
        if cfg.show_curve:
            fig = plt.figure("Image")
            manager = plt.get_current_fig_manager()
            manager.window.setGeometry(550, 200, 600, 500)
            # This works for the QT backend; for other backends, see:
            # https://stackoverflow.com/questions/7449585/how-do-you-set-the-absolute-position-of-figure-windows-with-matplotlib
            plt.xlabel('frames')
            plt.ylabel('psnr')
            plt.title('psnr curve')
            plt.grid(ls='--')

            cv2.namedWindow('target frames', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('target frames', 384, 384)
            cv2.moveWindow("target frames", 100, 100)

        if cfg.show_heatmap:
            cv2.namedWindow('difference map', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('difference map', 384, 384)
            cv2.moveWindow('difference map', 100, 550)

    # load gt labels
    gt_loader = Label_loader(cfg, video_folders)  # Get gt labels.
    gt, gt_bboxes = gt_loader()

    with torch.no_grad():
        for i, folder in tqdm(enumerate(video_folders)):
            dataset = Dataset.test_dataset(cfg, folder)
            test_dataloader = DataLoader(dataset=dataset, batch_size=cfg.batch_size,
                                         shuffle=False, num_workers=cfg.batch_size)
            vid = folder.split('/')[-2]

            if not model:
                name = folder.split('/')[-1]
                fourcc = cv2.VideoWriter_fourcc('X', 'V', 'I', 'D')

                if cfg.show_curve:
                    video_writer = cv2.VideoWriter(f'results/{name}_video.avi', fourcc, 30, cfg.img_size)
                    curve_writer = cv2.VideoWriter(f'results/{name}_curve.avi', fourcc, 30, (600, 430))

                    js = []
                    plt.clf()
                    ax = plt.axes(xlim=(0, len(dataset)), ylim=(30, 45))
                    line, = ax.plot([], [], '-b')

                if cfg.show_heatmap:
                    heatmap_writer = cv2.VideoWriter(f'results/{name}_heatmap.avi', fourcc, 30, cfg.img_size)

            psnrs = []
            diff_maps = []
            for clip in test_dataloader:
                input_frames = clip[:, 0:12, :, :].cuda()
                target_frame = clip[:, 12:15, :, :].cuda()

                G_frame = generator(input_frames)
                '''TODO: save predicted frame or difference'''
                test_psnr = psnr_error(G_frame, target_frame, reduce_batch=False).cpu().detach().numpy()
                # NOTE: save the squared difference so it can be reused for different evaluations
                square_diff = (target_frame - G_frame).pow(2).mean(dim=1).cpu().detach().numpy().astype('float16')
                diff_maps.append(square_diff)
                psnrs += list(test_psnr)

                if not model:
                    if cfg.show_curve:
                        # The display code below predates the switch to batched loading;
                        # it now visualizes only the first clip of each batch.
                        target_np = clip[0, 12:15, :, :].numpy()
                        cv2_frame = ((target_np + 1) * 127.5).transpose(1, 2, 0).astype('uint8')
                        js.extend(range(len(psnrs) - len(test_psnr), len(psnrs)))
                        # Keep the existing figure and update only the axis data;
                        # faster than redrawing, but still not perfect.
                        line.set_xdata(js)
                        line.set_ydata(psnrs)
                        plt.pause(0.001)  # show curve

                        cv2.imshow('target frames', cv2_frame)
                        cv2.waitKey(1)  # show video
                        video_writer.write(cv2_frame)  # Write original video frames.

                        buffer = io.BytesIO()  # Write curve frames from buffer.
                        fig.canvas.print_png(buffer)
                        buffer.seek(0)  # rewind so PIL reads the PNG from the start
                        curve_img = np.array(Image.open(buffer))[..., (2, 1, 0)]
                        curve_writer.write(curve_img)

                    if cfg.show_heatmap:
                        # first clip of the batch
                        diff_map = torch.sum(torch.abs(G_frame[0] - target_frame[0]), 0)
                        diff_map -= diff_map.min()  # Normalize to 0 ~ 255.
                        diff_map /= diff_map.max()
                        diff_map *= 255
                        diff_map = diff_map.cpu().detach().numpy().astype('uint8')
                        heat_map = cv2.applyColorMap(diff_map, cv2.COLORMAP_JET)

                        cv2.imshow('difference map', heat_map)
                        cv2.waitKey(1)
                        heatmap_writer.write(heat_map)  # Write heatmap frames.

                torch.cuda.synchronize()
                # end = time.time()
                # if j > 1:  # Compute fps from the time of one completed iteration; more accurate.
                #     fps = 1 / (end - temp)
                #     temp = end
                # print(f'\rDetecting: [{i + 1:02d}] {j + 1}/{len(dataset)}, {fps:.2f} fps.', end='')

            diff_maps = np.concatenate(diff_maps, axis=0)
            os.makedirs('saved_difference_map', exist_ok=True)
            np.save(os.path.join('saved_difference_map', vid + '.npy'), diff_maps)
            if len(psnrs) != len(gt[i]) - 4 or len(psnrs) != len(diff_maps):
                pdb.set_trace()
            psnr_group.append(np.array(psnrs))

            if not model:
                if cfg.show_curve:
                    video_writer.release()
                    curve_writer.release()
                if cfg.show_heatmap:
                    heatmap_writer.release()

    print('\nAll frames were detected, begin to compute AUC.')

    assert len(psnr_group) == len(gt), \
        f'Ground truth has {len(gt)} videos, but got {len(psnr_group)} detected videos.'
    # save psnr
    torch.save(psnr_group, 'results/psnr_group.pth')

    scores = np.array([], dtype=np.float32)
    labels = np.array([], dtype=np.int8)
    for i in range(len(psnr_group)):
        distance = psnr_group[i]

        distance -= min(distance)  # min-max normalize: (distance - min) / (max - min)
        distance /= max(distance)

        scores = np.concatenate((scores, distance), axis=0)
        labels = np.concatenate((labels, gt[i][4:]), axis=0)  # Exclude the first 4 unpredictable frames in gt.
    torch.save(psnr_group, 'results/psnr_normalized.pth')

    assert scores.shape == labels.shape, \
        f'Ground truth has {labels.shape[0]} frames, but got {scores.shape[0]} detected frames.'
    fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=0)
    auc = metrics.auc(fpr, tpr)
    print(f'AUC: {auc}\n')
    return auc
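# Hedged usage example: the AUC block above min-max normalizes the per-video PSNR
# and treats low PSNR (poor frame prediction) as anomalous, hence pos_label=0.
# A self-contained toy reproduction of that scoring step:
import numpy as np
from sklearn import metrics

psnr = np.array([40.0, 39.5, 31.0, 30.5, 41.0], dtype=np.float32)
gt = np.array([0, 0, 1, 1, 0], dtype=np.int8)  # 1 = anomalous frame

score = psnr - psnr.min()
score /= score.max()  # normalized to [0, 1]; high score = normal-looking frame

fpr, tpr, _ = metrics.roc_curve(gt, score, pos_label=0)
print('AUC:', metrics.auc(fpr, tpr))  # 1.0 for this perfectly separable toy case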