def main(args):
    dataset = get_data_loader(args.files_pattern,
                              batch_size=args.batch_size,
                              num_workers=args.num_data_workers,
                              crop_size=args.crop_size if args.crop_size != 0 else None)
    device = torch.cuda.current_device()

    if args.dataloader_only:
        if dist.get_rank() == 0:
            print("Running in dataloader_only mode ...")
    else:
        model = UNet().to(device)
        if dist.is_initialized():
            model = DistributedDataParallel(model,
                                            device_ids=[args.local_rank],
                                            output_device=args.local_rank)
        if args.forward_only:
            if dist.get_rank() == 0:
                print("Running in inference (forward only) mode ...")
        else:
            if dist.get_rank() == 0:
                print("Running in training (forward/backward) mode ...")
            optimizer = torch.optim.Adam(model.parameters())

    total_time = 0
    for epoch in range(args.epochs + 1):
        if dist.get_rank() == 0:
            print("epoch", epoch)
        t_start = time.time()

        for idx, data in enumerate(dataset):
            if idx > args.max_batches_per_epoch:
                break
            if args.dataloader_only:
                continue

            inp, tar = map(lambda x: x.to(device), data)
            if args.forward_only:
                model.eval()
                gen = model(inp)
            else:
                model.zero_grad()
                model.train()
                gen = model(inp)
                loss = torch.nn.functional.l1_loss(gen, tar)
                loss.backward()
                optimizer.step()

        # Epoch 0 is treated as warm-up and excluded from the timing statistics
        if epoch > 0:
            total_time += time.time() - t_start

    n_batches = min(args.max_batches_per_epoch + 1, len(dataset))
    if dist.get_rank() == 0:
        print("Timing:",
              float(args.batch_size * n_batches * args.epochs) / total_time,
              "samples/s")
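
# --- Hedged usage sketch (not part of the original script) -------------------
# A minimal entry point that could drive main(args) above. The flag names simply
# mirror the attributes read from `args`; the defaults and the torchrun-style
# LOCAL_RANK/WORLD_SIZE environment-variable wiring are assumptions about how the
# benchmark is launched, not the project's confirmed launcher. Note that main()
# calls dist.get_rank() unconditionally, so it expects an initialized process
# group (e.g. a launch via torchrun).
if __name__ == '__main__':
    import argparse
    import os

    import torch
    import torch.distributed as dist

    parser = argparse.ArgumentParser(description='UNet throughput benchmark (sketch)')
    parser.add_argument('--files_pattern', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_data_workers', type=int, default=4)
    parser.add_argument('--crop_size', type=int, default=0)
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--max_batches_per_epoch', type=int, default=100)
    parser.add_argument('--local_rank', type=int,
                        default=int(os.environ.get('LOCAL_RANK', 0)))
    parser.add_argument('--dataloader_only', action='store_true')
    parser.add_argument('--forward_only', action='store_true')
    args = parser.parse_args()

    # Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are provided by the launcher.
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(args.local_rank)

    main(args)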
def main():
    parser = argparse.ArgumentParser(
        description='Visualize segmentations obtained from trained UNet')
    parser.add_argument('checkpoint', type=str,
                        help='Path to UNet model checkpoint')
    parser.add_argument('img_path', type=str,
                        help='Path to image or directory containing images to segment')
    parser.add_argument('-r', '--resize', type=int, default=0,
                        help='Size to which images are resized prior to segmentation')
    parser.add_argument('-o', '--out', type=str, default='./',
                        help='Path to write segmentation visualizations to')
    args = parser.parse_args()

    if not os.path.exists(args.checkpoint):
        sys.exit('Specified checkpoint cannot be found')
    if not os.path.exists(args.img_path):
        sys.exit('Images for segmentation could not be found')

    # Collect the image(s) to segment
    imgs = []
    if os.path.isdir(args.img_path):
        for file in os.listdir(args.img_path):
            if os.path.isfile(os.path.join(args.img_path, file)):
                imgs.append(os.path.join(args.img_path, file))
    else:
        imgs.append(args.img_path)

    # Load the trained model onto the CPU and switch to evaluation mode
    checkpoint = torch.load(args.checkpoint,
                            map_location=lambda storage, loc: storage)
    model = UNet(num_classes=len(datasets.Cityscapes.classes))
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    visualizer = CityscapeSegmentationVis(model, input_image_transform(args.resize))

    # Save one segmentation visualization per input image
    for img in imgs:
        image_name = 'segmentation_' + os.path.basename(img)
        out_location = os.path.join(args.out, image_name)
        class_tensor = visualizer.get_predicted_segmentation(img)
        visualizer.save_segmentation(class_tensor, out_location)
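
# --- Hedged usage sketch (not part of the original script) -------------------
# Entry point guard plus an example invocation. The script name
# `visualize_segmentation.py` and the checkpoint/image paths are placeholders
# for illustration only; CityscapeSegmentationVis and input_image_transform are
# project-local helpers assumed to be imported elsewhere in this file.
#
#   python visualize_segmentation.py checkpoints/unet_best.pth data/images/ \
#       --resize 512 --out segmentations/
#
if __name__ == '__main__':
    main()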
def channel_vis_driver(model_name, checkpoint_path, data_path, dataset, conv_layer,
                       channels, init_img_size, upscale_steps, upscale_factor, lr,
                       update_steps, grid, out_path, verbose):
    # Load the requested model: either a UNet encoder from a checkpoint or a
    # pretrained VGG11 from torchvision
    if model_name == 'unet':
        if not os.path.exists(checkpoint_path):
            sys.exit('Specified checkpoint cannot be found')
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
        model = UNet(num_classes=len(datasets.Cityscapes.classes), encoder_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
    elif model_name == 'vggmod':
        model = models.vgg11(pretrained=True)
    else:
        sys.exit('No model provided, please specify --unet or --vgg11 to analyze '
                 'the UNet or VGG11 encoder, respectively')

    # Set model to evaluation mode and fix the parameter values
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    layer = get_conv_layer(model, conv_layer)
    analyzer = LayerActivationAnalysis(model, layer)

    # Save a grid of channel activation visualizations
    if grid:
        if not channels:
            # Get a random sample of 9 activated channels
            channels = analyzer.get_activated_filter_indices()
            np.random.shuffle(channels)
            channels = channels[:9]

        imgs = []
        for i, channel in enumerate(channels):
            if verbose:
                print('Generating image {} of {}...'.format(i + 1, len(channels)))
            img = analyzer.get_max_activating_image(channel,
                                                    initial_img_size=init_img_size,
                                                    upscaling_steps=upscale_steps,
                                                    upscaling_factor=upscale_factor,
                                                    lr=lr,
                                                    update_steps=update_steps,
                                                    verbose=verbose)
            imgs.append(img)

        channel_string = '-'.join(str(channel_id) for channel_id in channels)
        output_dest = os.path.join(
            out_path,
            '{}_layer{}_channels{}.png'.format(model_name, conv_layer, channel_string))
        save_image_grid(imgs, output_dest)

    # Save a channel activation visualization for each specified channel
    elif channels is not None:
        for channel in channels:
            img = analyzer.get_max_activating_image(channel,
                                                    initial_img_size=init_img_size,
                                                    upscaling_steps=upscale_steps,
                                                    upscaling_factor=upscale_factor,
                                                    lr=lr,
                                                    update_steps=update_steps,
                                                    verbose=verbose)
            output_dest = os.path.join(
                out_path,
                '{}_layer{}_channel{}.png'.format(model_name, conv_layer, channel))
            save_image(img, output_dest)

    else:
        # Compute the average number of channels activated in each layer
        if data_path and dataset:
            layers = [get_conv_layer(model, i) for i in [1, 2, 3, 4, 5, 6, 7, 8]]
            avg = analyzer.get_avg_activated_channels(layers, data_path, dataset, 100)
            print('Average number of channels activated per convolutional layer: {}'.format(avg))

        # Output the channels activated by a randomly initialized image
        else:
            activated_channels = analyzer.get_activated_filter_indices(initial_img_size=init_img_size)
            print('Output channels in conv layer {} activated by random image input:'.format(conv_layer))
            print(activated_channels)
            print()
            print('(Total of {} activated channels)'.format(len(activated_channels)))
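
# --- Hedged CLI sketch (not part of the original script) ---------------------
# One plausible way to wire channel_vis_driver() to a command line. The error
# message above mentions --unet / --vgg11, so those flags are kept; every other
# flag name and default below is an assumption for illustration, not the
# project's actual CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Channel activation visualization (sketch)')
    parser.add_argument('--unet', action='store_true', help='Analyze the UNet encoder')
    parser.add_argument('--vgg11', action='store_true', help='Analyze the VGG11 encoder')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--data_path', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--conv_layer', type=int, default=1)
    parser.add_argument('--channels', type=int, nargs='+', default=None)
    parser.add_argument('--init_img_size', type=int, default=56)
    parser.add_argument('--upscale_steps', type=int, default=12)
    parser.add_argument('--upscale_factor', type=float, default=1.2)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--update_steps', type=int, default=20)
    parser.add_argument('--grid', action='store_true')
    parser.add_argument('--out_path', type=str, default='./')
    parser.add_argument('--verbose', action='store_true')
    args = parser.parse_args()

    # channel_vis_driver() itself rejects anything other than 'unet' or 'vggmod'.
    model_name = 'unet' if args.unet else ('vggmod' if args.vgg11 else None)

    channel_vis_driver(model_name, args.checkpoint, args.data_path, args.dataset,
                       args.conv_layer, args.channels, args.init_img_size,
                       args.upscale_steps, args.upscale_factor, args.lr,
                       args.update_steps, args.grid, args.out_path, args.verbose)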