def build_loader_model_grapher(args):
    """Construct the dataloader, the classifier network and the grapher.

    :param args: argparse.Namespace with the experiment configuration; it is
                 mutated in-place with input-shape and dataset-sizing fields.
    :returns: (loader, network, grapher); grapher is None on non-zero ranks.
    :rtype: tuple
    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader = get_loader(**{'train_transform': train_transform,
                           'test_transform': test_transform,
                           **vars(args)})

    # Stash the (batch-free) input shape and per-replica dataset sizing on args.
    args.input_shape = loader.input_shape
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # test set is not sharded across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # Build the classifier, then apply the optional wrappers in the original
    # order: sync-BN conversion -> TorchScript -> CUDA.
    model = models.__dict__[args.arch](pretrained=args.pretrained,
                                       num_classes=loader.output_size)
    if args.convert_to_sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.jit:
        model = torch.jit.script(model)
    if args.cuda:
        model = model.cuda()

    # Materialize any lazily-constructed submodules with a real batch,
    # then run weight initialization.
    lazy_generate_modules(model, loader.train_loader)
    model = layers.init_weights(model, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        model = layers.DistributedDataParallelPassthrough(
            model,
            device_ids=[0],   # set w/cuda environ var
            output_device=0,  # set w/cuda environ var
            find_unused_parameters=True)

    # Print a structural summary and the parameter count.
    print(model)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(model) / 1e6))

    # Only rank zero builds a grapher; visdom takes precedence when configured.
    grapher = None
    if args.distributed_rank == 0:
        if args.visdom_url is not None:
            grapher = Grapher('visdom', env=utils.get_name(args),
                              server=args.visdom_url,
                              port=args.visdom_port,
                              log_folder=args.log_dir)
        else:
            grapher = Grapher(
                'tensorboard',
                logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, model, grapher
def build_loader_model_grapher(args):
    """Construct the dataloader, the requested VAE and the grapher.

    :param args: argparse.Namespace holding the experiment configuration;
                 args.input_shape is set in-place from the loader.
    :returns: (loader, network, grapher)
    :rtype: tuple
    """
    # Optional square resize; with no override the dataset's native size is
    # used and no extra transform is installed.
    side = args.image_size_override
    transform = [torchvision.transforms.Resize((side, side))] if side else None
    loader = get_loader(args, transform=transform, **vars(args))

    # Input shape excludes the batch dimension: [channels, H, W].
    # NOTE(review): the transform branch tests truthiness while this one tests
    # `is None` — preserved as-is; confirm override is never 0.
    if args.image_size_override is None:
        args.input_shape = loader.img_shp
    else:
        args.input_shape = [loader.img_shp[0], side, side]

    # Dispatch on the requested VAE flavor.
    constructors = {
        'simple': SimpleVAE,
        'msg': MSGVAE,
        'parallel': ParallellyReparameterizedVAE,
        'sequential': SequentiallyReparameterizedVAE,
        'vrnn': VRNN
    }
    network = constructors[args.vae_type](loader.img_shp,
                                          kwargs=deepcopy(vars(args)))

    # Materialize lazily-built submodules with a real batch before moving
    # the model to the GPU.
    lazy_generate_modules(network, loader.train_loader)
    if args.cuda:
        network = network.cuda()

    network = append_save_and_load_fns(network, prefix="VAE_")
    if args.ngpu > 1:
        print("data-paralleling...")
        network.parallel()

    # Grapher backend: visdom when a URL is configured, else tensorboard.
    if args.visdom_url:
        grapher = Grapher('visdom', env=get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name(args))

    return loader, network, grapher
def build_loader_model_grapher(args):
    """Construct the dataloader, a resnet18 classifier and the grapher.

    To use a different model, import and construct it where resnet18 is
    built below.

    :param args: argparse.Namespace holding the experiment configuration;
                 args.input_shape is set in-place from the loader.
    :returns: (loader, network, grapher)
    :rtype: tuple
    """
    # Optional square resize; with no override the dataset's native size is
    # used and no extra transform is installed.
    side = args.image_size_override
    transform = [transforms.Resize((side, side))] if side else None
    loader = get_loader(args, transform=transform, **vars(args))

    # Input shape excludes the batch dimension: [channels, H, W].
    if args.image_size_override is None:
        args.input_shape = loader.img_shp
    else:
        args.input_shape = [loader.img_shp[0], side, side]

    # Build the network; swap in your own model here if desired.
    network = resnet18(num_classes=loader.output_size)

    # Materialize lazily-built submodules with a real batch before moving
    # the model to the GPU.
    lazy_generate_modules(network, loader.train_loader)
    if args.cuda:
        network = network.cuda()

    # NOTE(review): "VAE_" prefix looks copy-pasted from the VAE template;
    # it affects checkpoint naming, so it is preserved — confirm intent.
    network = append_save_and_load_fns(network, prefix="VAE_")
    if args.ngpu > 1:
        print("data-paralleling...")
        network.parallel()

    # Grapher backend: visdom when a URL is configured, else tensorboard.
    if args.visdom_url:
        grapher = Grapher('visdom', env=get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port)
    else:
        grapher = Grapher('tensorboard', comment=get_name(args))

    return loader, network, grapher
def build_loader_model_grapher(args):
    """Construct the dataloader, the VAE and the grapher.

    :param args: argparse.Namespace with the experiment configuration; it is
                 mutated in-place with input-shape and dataset-sizing fields.
    :returns: (loader, network, grapher); grapher is None on non-zero ranks.
    :rtype: tuple
    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader = get_loader(**{'train_transform': train_transform,
                           'test_transform': test_transform,
                           **vars(args)})

    # Stash the (batch-free) input shape, output size and per-replica dataset
    # sizing on args.
    args.input_shape = loader.input_shape
    args.output_size = loader.output_size
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # test set is not sharded across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # Build the requested VAE, move it to the GPU, then materialize any
    # lazily-constructed submodules and initialize weights.
    model = build_vae(args.vae_type)(loader.input_shape,
                                     kwargs=deepcopy(vars(args)))
    if args.cuda:
        model = model.cuda()
    lazy_generate_modules(model, loader.train_loader)
    model = layers.init_weights(model, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        model = layers.DistributedDataParallelPassthrough(
            model,
            device_ids=[0],   # set w/cuda environ var
            output_device=0,  # set w/cuda environ var
            find_unused_parameters=True)

    # Print a structural summary and the parameter count.
    print(model)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(model) / 1e6))

    # When a metrics server is configured, attach the test split as an array
    # so remote metric computation can reach it.
    if args.metrics_server is not None:
        model.test_images = get_numpy_dataset(task=args.task,
                                              data_dir=args.data_dir,
                                              test_transform=test_transform,
                                              split='test',
                                              image_size=args.image_size_override,
                                              cuda=args.cuda)
        print("Metrics test images: ", model.test_images.shape)

    # Only rank zero builds a grapher; visdom takes precedence when configured.
    grapher = None
    if args.distributed_rank == 0:
        if args.visdom_url is not None:
            grapher = Grapher('visdom', env=utils.get_name(args),
                              server=args.visdom_url,
                              port=args.visdom_port,
                              log_folder=args.log_dir)
        else:
            grapher = Grapher(
                'tensorboard',
                logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, model, grapher
def execute_graph(epoch, model, loader, grapher, optimizer=None, prefix='test'): """ execute the graph; when 'train' is in the name the model runs the optimizer :param epoch: the current epoch number :param model: the torch model :param loader: the train or **TEST** loader :param grapher: the graph writing helper (eg: visdom / tf wrapper) :param optimizer: the optimizer :param prefix: 'train', 'test' or 'valid' :returns: dictionary with scalars :rtype: dict """ start_time = time.time() is_eval = 'train' not in prefix model.eval() if is_eval else model.train() loss_map, num_samples = {}, 0 # iterate over data and labels for num_minibatches, (minibatch, labels) in enumerate(loader): minibatch = minibatch.cuda( non_blocking=True) if args.cuda else minibatch labels = labels.cuda(non_blocking=True) if args.cuda else labels with torch.no_grad(): if is_eval and args.polyak_ema > 0: # use the Polyak model for predictions pred_logits, reparam_map = layers.get_polyak_prediction( model, pred_fn=functools.partial(model, minibatch, labels=labels)) else: pred_logits, reparam_map = model( minibatch, labels=labels) # get normal predictions # compute loss loss_t = model.loss_function(recon_x=pred_logits, x=minibatch, params=reparam_map, K=args.monte_carlo_posterior_samples) loss_map = loss_t if not loss_map else tree.map_structure( # aggregate loss _extract_sum_scalars, loss_map, loss_t) num_samples += minibatch.size(0) # count minibatch samples del loss_t if args.debug_step and num_minibatches > 1: # for testing purposes break # compute the mean of the dict loss_map = tree.map_structure( lambda v: v / (num_minibatches + 1), loss_map) # reduce the map to get actual means # log some stuff def tensor2item(t): return t.detach().item() if isinstance(t, torch.Tensor) else t to_log = '{}-{}[Epoch {}][{} samples][{:.2f} sec]:\t Loss: {:.4f}\t-ELBO: {:.4f}\tNLL: {:.4f}\tKLD: {:.4f}\tMI: {:.4f}' print( to_log.format(prefix, args.distributed_rank, epoch, num_samples, time.time() - 
start_time, tensor2item(loss_map['loss_mean']), tensor2item(loss_map['elbo_mean']), tensor2item(loss_map['nll_mean']), tensor2item(loss_map['kld_mean']), tensor2item(loss_map['mut_info_mean']))) # build the image map image_map = {'input_imgs': minibatch} # activate the logits of the reconstruction and get the dict image_map = { **image_map, **model.get_activated_reconstructions(pred_logits) } # tack on remote metrics information if requested, do it in-frequently. if args.metrics_server is not None: request_remote_metrics_calc(epoch, model, grapher, prefix) # Add generations to our image dict with torch.no_grad(): prior_generated = model.generate_synthetic_samples( 10000, reset_state=True, use_aggregate_posterior=False) ema_generated = model.generate_synthetic_samples( 10000, reset_state=True, use_aggregate_posterior=True) image_map['prior'] = prior_generated image_map['ema'] = ema_generated # tack on MSSIM information if requested if args.calculate_msssim: loss_map['prior_gen_msssim_mean'] = metrics.calculate_mssim( minibatch, prior_generated[0:minibatch.shape[0]]) loss_map['ema_gen_msssim_mean'] = metrics.calculate_mssim( minibatch, ema_generated[0:minibatch.shape[0]]) # save all the images image_dir = os.path.join(args.log_dir, utils.get_name(args), 'images') os.makedirs(image_dir, exist_ok=True) for k, v in image_map.items(): grid = torchvision.utils.make_grid(v, normalize=True, scale_each=True) grid_filename = os.path.join(image_dir, "{}.png".format(k)) transforms.ToPILImage()(grid.cpu()).save(grid_filename) for idx, sample in enumerate(v): current_filename = os.path.join(image_dir, "{}_{}.png".format(idx, k)) transforms.ToPILImage()(sample.cpu()).save(current_filename) # cleanups (see https://tinyurl.com/ycjre67m) + return ELBO for early stopping for d in [loss_map, image_map, reparam_map]: d.clear() del minibatch del labels