def build_loader_model_grapher(args):
    """Builds a dataloader, a model and a grapher.

    :param args: argparse namespace
    :returns: a dataloader, a model and a grapher
    :rtype: tuple
    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform, **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # test set isn't currently split across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network
    network = models.__dict__[args.arch](pretrained=args.pretrained,
                                         num_classes=loader.output_size)
    network = nn.SyncBatchNorm.convert_sync_batchnorm(network) if args.convert_to_sync_bn else network
    network = torch.jit.script(network) if args.jit else network
    network = network.cuda() if args.cuda else network
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/ cuda environ var
                                                            output_device=0,  # set w/ cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # build the grapher object
    grapher = None
    if args.visdom_url is not None and args.distributed_rank == 0:
        grapher = Grapher('visdom',
                          env=utils.get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port,
                          log_folder=args.log_dir)
    elif args.distributed_rank == 0:
        grapher = Grapher(
            'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
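# -- Usage sketch (illustrative, not part of the repo): how the builder above is
#    typically wired into an entrypoint. `parser`, `build_optimizer`, `train` and
#    `test` are assumed to exist elsewhere in the codebase, as they do in run() below.
if __name__ == "__main__":
    args = parser.parse_args()
    loader, model, grapher = build_loader_model_grapher(args)
    optimizer = build_optimizer(model)
    for epoch in range(1, args.epochs + 1):
        train(epoch, model, optimizer, loader.train_loader, grapher)
        test(epoch, model, loader.test_loader, grapher)
    grapher.close()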
def build_encoder(self):
    """Helper to build the encoder.

    :returns: an encoder
    :rtype: nn.Module
    """
    encoder = layers.get_encoder(**self.config)(
        output_size=self.reparameterizer.input_size)
    print('encoder has {} million parameters\n'.format(
        utils.number_of_parameters(encoder) / 1e6))
    return torch.jit.script(encoder) if self.config['jit'] else encoder
def build_decoder(self, reupsample=True):
    """Helper to build a convolutional or dense decoder.

    :returns: a decoder
    :rtype: nn.Module
    """
    dec_conf = deepcopy(self.config)
    if dec_conf['nll_type'] == 'pixel_wise':
        # pixel-wise NLL expands the channel dimension by 256 (one logit per intensity value)
        dec_conf['input_shape'][0] *= 256

    decoder = layers.get_decoder(output_shape=dec_conf['input_shape'], **dec_conf)(
        input_size=self.reparameterizer.output_size)
    print('decoder has {} million parameters\n'.format(
        utils.number_of_parameters(decoder) / 1e6))

    # append the variance projection as necessary
    decoder = self._append_variance_projection(decoder)
    return torch.jit.script(decoder) if self.config['jit'] else decoder
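# -- Context sketch (assumption, for illustration only): build_encoder / build_decoder
#    read from a `config` dict and a `reparameterizer` exposing `input_size` and
#    `output_size`, so they are written as methods of a VAE-style nn.Module. The class
#    name, constructor and reparameterizer call signature below are illustrative
#    assumptions, not the repo's implementation.
class _AbstractVAESketch(nn.Module):
    def __init__(self, config, reparameterizer):
        super().__init__()
        self.config = config
        self.reparameterizer = reparameterizer
        self.encoder = self.build_encoder()  # method defined above
        self.decoder = self.build_decoder()  # method defined above

    def forward(self, x):
        logits = self.encoder(x)
        z, params = self.reparameterizer(logits)  # assumed (sample, dist-params) return
        return self.decoder(z), params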
def run(args):
    # collect our model and data loader
    model, loader, grapher = get_model_and_loader()
    print("model has {} params".format(number_of_parameters(model)))

    # collect our optimizer
    optimizer = build_optimizer(model)

    # train the VAE on the same distributions as the model pool
    if args.restore is None:
        print("training current distribution for {} epochs".format(args.epochs))
        early = EarlyStopping(model, burn_in_interval=100,
                              max_steps=80) if args.early_stop else None

        test_map = {}
        for epoch in range(1, args.epochs + 1):
            generate(epoch, model, grapher)
            train(epoch, model, optimizer, loader.train_loader, grapher)
            test_map = test(epoch, model, loader.test_loader, grapher)

            if args.early_stop and early(test_map['pred_loss_mean']):
                early.restore()  # restore the best model and test again
                test_map = test(epoch, model, loader.test_loader, grapher)
                break

            # adjust the LR if using momentum sgd
            if args.optimizer == 'sgd_momentum':
                decay_lr_every(optimizer, args.lr, epoch)

        grapher.save()  # save to endpoint after training
    else:
        assert model.load(args.restore), "Failed to load model"
        test_map = test(args.epochs, model, loader.test_loader, grapher)  # no epoch counter here; report against the final epoch

    # evaluate one-time metrics
    scalar_map_to_csvs(test_map)

    # cleanups
    grapher.close()
def run(args):
    # collect our model and data loader
    model, loader, grapher = get_model_and_loader()
    print("model has {} params".format(number_of_parameters(model)))

    # collect our optimizer
    optimizer = build_optimizer(model)

    # train the VAE on the same distributions as the model pool
    if args.restore is None:
        print("training current distribution for {} epochs".format(args.epochs))
        early = EarlyStopping(model, burn_in_interval=100,
                              max_steps=80) if args.early_stop else None

        test_loss = {}
        for epoch in range(1, args.epochs + 1):
            train(epoch, model, optimizer, loader.train_loader, grapher)
            test_loss = test(epoch, model, loader.test_loader, grapher)

            if args.early_stop and early(test_loss['loss_mean']):
                early.restore()  # restore the best model and test again
                test_loss = test(epoch, model, loader.test_loader, grapher)
                break

            # adjust the LR if using momentum sgd
            if args.optimizer == 'sgd_momentum':
                decay_lr_every(optimizer, args.lr, epoch)

        grapher.save()  # save to endpoint after training
    else:
        model = torch.load(args.restore)
        test_loss = test(args.epochs, model, loader.test_loader, grapher)  # no epoch counter here; report against the final epoch

    # evaluate one-time metrics
    append_to_csv([test_loss['acc_mean']], "{}_test_acc.csv".format(args.uid))

    # cleanups
    grapher.close()
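# -- Sketch of a step-decay LR schedule matching the decay_lr_every(optimizer, lr, epoch)
#    call used in both run() variants above. This is an illustrative stand-in, not the
#    repo's implementation; the `interval` and `gamma` defaults are assumptions.
def decay_lr_every_sketch(optimizer, base_lr, epoch, interval=30, gamma=0.1):
    """Multiply the base learning rate by `gamma` once every `interval` epochs."""
    lr = base_lr * (gamma ** (epoch // interval))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr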
def train_loop(data_loaders, model, fid_model, grapher, args):
    '''Simple helper to run the entire train loop; not needed for eval modes.'''
    optimizer = build_optimizer(model.student)  # collect our optimizer
    print("there are {} params with {} elems in the st-model and {} params in the student with {} elems"
          .format(len(list(model.parameters())), number_of_parameters(model),
                  len(list(model.student.parameters())), number_of_parameters(model.student)))

    # main training loop
    fisher = None
    for j, loader in enumerate(data_loaders):
        num_epochs = args.epochs  # TODO: randomize epochs by something like: + np.random.randint(0, 13)
        print("training current distribution for {} epochs".format(num_epochs))
        early = EarlyStopping(model, max_steps=50,
                              burn_in_interval=None) if args.early_stop else None
        # alternative: burn_in_interval=int(num_epochs * 0.2)

        test_loss = None
        for epoch in range(1, num_epochs + 1):
            train(epoch, model, fisher, optimizer, loader.train_loader, grapher)
            test_loss = test(epoch, model, fisher, loader.test_loader, grapher)

            if args.early_stop and early(test_loss['loss_mean']):
                early.restore()  # restore and test+generate again
                test_loss = test_and_generate(epoch, model, fisher, loader, grapher)
                break

        generate(model, grapher, 'student')  # generate student samples
        generate(model, grapher, 'teacher')  # generate teacher samples

        # evaluate and save away one-time metrics, these include:
        #   1. test elbo
        #   2. FID
        #   3. consistency
        #   4. number of synthetic + number of true samples
        #   5. dump config to visdom
        check_or_create_dir(os.path.join(args.output_dir))
        append_to_csv([test_loss['elbo_mean']],
                      os.path.join(args.output_dir, "{}_test_elbo.csv".format(args.uid)))
        num_synth_samples = np.ceil(epoch * args.batch_size * model.ratio)
        num_true_samples = np.ceil(epoch * (args.batch_size - (args.batch_size * model.ratio)))
        append_to_csv([num_synth_samples],
                      os.path.join(args.output_dir, "{}_numsynth.csv".format(args.uid)))
        append_to_csv([num_true_samples],
                      os.path.join(args.output_dir, "{}_numtrue.csv".format(args.uid)))
        append_to_csv([epoch],
                      os.path.join(args.output_dir, "{}_epochs.csv".format(args.uid)))
        grapher.vis.text(str(num_synth_samples), opts=dict(title="num_synthetic_samples"))  # visdom text expects strings
        grapher.vis.text(str(num_true_samples), opts=dict(title="num_true_samples"))
        grapher.vis.text(pprint.PrettyPrinter(indent=4).pformat(model.student.config),
                         opts=dict(title="config"))

        # calculate the consistency using the **PREVIOUS** loader
        if j > 0:
            append_to_csv(
                calculate_consistency(model, data_loaders[j - 1],
                                      args.reparam_type, args.vae_type, args.cuda),
                os.path.join(args.output_dir, "{}_consistency.csv".format(args.uid)))

        if args.calculate_fid_with is not None:
            # TODO: parameterize num fid samples; currently use fewer for inceptionv3 as it's costly
            num_fid_samples = 4000 if args.calculate_fid_with != 'inceptionv3' else 1000
            append_to_csv(
                calculate_fid(fid_model=fid_model, model=model, loader=loader,
                              grapher=grapher, num_samples=num_fid_samples, cuda=args.cuda),
                os.path.join(args.output_dir, "{}_fid.csv".format(args.uid)))

        grapher.save()  # save the remote visdom graphs

        if j != len(data_loaders) - 1:
            if args.ewc_gamma > 0:
                # calculate the fisher from the previous data loader
                print("computing fisher info matrix....")
                fisher_tmp = estimate_fisher(model.student,  # this is pre-fork
                                             loader, args.batch_size, cuda=args.cuda)
                if fisher is not None:
                    assert len(fisher) == len(fisher_tmp), "#fisher params != #new fisher params"
                    for (kf, vf), (kft, vft) in zip(fisher.items(), fisher_tmp.items()):
                        fisher[kf] += fisher_tmp[kft]
                else:
                    fisher = fisher_tmp

            # spawn a new student & rebuild grapher; we also pass
            # the new model's parameters through a new optimizer.
            if not args.disable_student_teacher:
                model.fork()
                lazy_generate_modules(model, data_loaders[0].img_shp)
                optimizer = build_optimizer(model.student)
                print("there are {} params with {} elems in the st-model and {} params in the student with {} elems"
                      .format(len(list(model.parameters())), number_of_parameters(model),
                              len(list(model.student.parameters())), number_of_parameters(model.student)))
            else:
                # increment anyway for vanilla models so that we get a separate visdom env
                model.current_model += 1

            grapher = Grapher(env=model.get_name(),
                              server=args.visdom_url,
                              port=args.visdom_port)
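# -- Interface sketch for the EarlyStopping object used in the loops above (illustration
#    only, not the repo's implementation): it is called with the current test loss,
#    returns True once no improvement has been seen for `max_steps` evaluations (after an
#    optional `burn_in_interval`), and restore() reloads the best-seen weights.
import copy

class EarlyStoppingSketch:
    def __init__(self, model, max_steps=50, burn_in_interval=None):
        self.model = model
        self.max_steps = max_steps
        self.burn_in_interval = burn_in_interval
        self.best_loss = float('inf')
        self.best_state = None
        self.bad_evals = 0
        self.num_evals = 0

    def __call__(self, loss):
        self.num_evals += 1
        if self.burn_in_interval is not None and self.num_evals <= self.burn_in_interval:
            return False  # ignore evaluations during the burn-in period

        if loss < self.best_loss:
            self.best_loss = loss
            self.best_state = copy.deepcopy(self.model.state_dict())
            self.bad_evals = 0
        else:
            self.bad_evals += 1

        return self.bad_evals >= self.max_steps

    def restore(self):
        """Reload the best-seen weights into the model."""
        if self.best_state is not None:
            self.model.load_state_dict(self.best_state)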
def build_loader_model_grapher(args):
    """Builds a dataloader, a model and a grapher.

    :param args: argparse namespace
    :returns: a dataloader, a model and a grapher
    :rtype: tuple
    """
    train_transform, test_transform = build_train_and_test_transforms()
    loader_dict = {'train_transform': train_transform,
                   'test_transform': test_transform, **vars(args)}
    loader = get_loader(**loader_dict)

    # set the input tensor shape (ignoring batch dimension) and related dataset sizing
    args.input_shape = loader.input_shape
    args.output_size = loader.output_size
    args.num_train_samples = loader.num_train_samples // args.num_replicas
    args.num_test_samples = loader.num_test_samples  # test set isn't currently split across devices
    args.num_valid_samples = loader.num_valid_samples // args.num_replicas
    args.steps_per_train_epoch = args.num_train_samples // args.batch_size  # drop-remainder
    args.total_train_steps = args.epochs * args.steps_per_train_epoch

    # build the network
    network = build_vae(args.vae_type)(loader.input_shape, kwargs=deepcopy(vars(args)))
    network = network.cuda() if args.cuda else network
    lazy_generate_modules(network, loader.train_loader)
    network = layers.init_weights(network, init=args.weight_initialization)

    if args.num_replicas > 1:
        print("wrapping model with DDP...")
        network = layers.DistributedDataParallelPassthrough(network,
                                                            device_ids=[0],   # set w/ cuda environ var
                                                            output_device=0,  # set w/ cuda environ var
                                                            find_unused_parameters=True)

    # Get some info about the structure and number of params.
    print(network)
    print("model has {} million parameters.".format(
        utils.number_of_parameters(network) / 1e6
    ))

    # add the test set as a np array for metrics calc
    if args.metrics_server is not None:
        network.test_images = get_numpy_dataset(task=args.task,
                                                data_dir=args.data_dir,
                                                test_transform=test_transform,
                                                split='test',
                                                image_size=args.image_size_override,
                                                cuda=args.cuda)
        print("Metrics test images: ", network.test_images.shape)

    # build the grapher object
    grapher = None
    if args.visdom_url is not None and args.distributed_rank == 0:
        grapher = Grapher('visdom',
                          env=utils.get_name(args),
                          server=args.visdom_url,
                          port=args.visdom_port,
                          log_folder=args.log_dir)
    elif args.distributed_rank == 0:
        grapher = Grapher(
            'tensorboard', logdir=os.path.join(args.log_dir, utils.get_name(args)))

    return loader, network, grapher
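# -- Sketch of the lazy-initialization step used by both builders above (illustration
#    only, not the repo's lazy_generate_modules): a single no-grad forward pass over one
#    training batch materializes any shape-dependent submodules before init_weights()
#    and DDP wrapping. Assumes the loader yields (image, label) mini-batches.
def lazy_generate_modules_sketch(network, train_loader, cuda=False):
    network.eval()
    with torch.no_grad():
        images, _ = next(iter(train_loader))
        network(images.cuda() if cuda else images)
    network.train()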