def main(args):
    # ensures that weight initializations are all the same
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    logging = utils.Logger(args.global_rank, args.save)
    writer = utils.Writer(args.global_rank, args.save)

    # Get data loaders.
    train_queue, valid_queue, num_classes, _ = datasets.get_loaders(args)
    args.num_total_iter = len(train_queue) * args.epochs
    warmup_iters = len(train_queue) * args.warmup_epochs
    swa_start = len(train_queue) * (args.epochs - 1)

    arch_instance = utils.get_arch_cells(args.arch_instance)

    model = AutoEncoder(args, writer, arch_instance)
    model = model.cuda()

    logging.info('args = %s', args)
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))
    logging.info('groups per scale: %s, total_groups: %d',
                 model.groups_per_scale, sum(model.groups_per_scale))

    if args.fast_adamax:
        # Fast Adamax has the same functionality as torch.optim.Adamax, except it is faster.
        cnn_optimizer = Adamax(model.parameters(), args.learning_rate,
                               weight_decay=args.weight_decay, eps=1e-3)
    else:
        cnn_optimizer = torch.optim.Adamax(model.parameters(), args.learning_rate,
                                           weight_decay=args.weight_decay, eps=1e-3)

    cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        cnn_optimizer, float(args.epochs - args.warmup_epochs - 1), eta_min=args.learning_rate_min)
    grad_scalar = GradScaler(2**10)

    num_output = utils.num_output(args.dataset, args)
    bpd_coeff = 1. / np.log(2.) / num_output

    # Load a checkpoint if we are continuing a previous run.
    checkpoint_file = os.path.join(args.save, 'checkpoint.pt')
    if args.cont_training:
        logging.info('loading the model.')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        init_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()
        cnn_optimizer.load_state_dict(checkpoint['optimizer'])
        grad_scalar.load_state_dict(checkpoint['grad_scalar'])
        cnn_scheduler.load_state_dict(checkpoint['scheduler'])
        global_step = checkpoint['global_step']
    else:
        global_step, init_epoch = 0, 0

    for epoch in range(init_epoch, args.epochs):
        # Update learning rates.
        if args.distributed:
            train_queue.sampler.set_epoch(global_step + args.seed)
            valid_queue.sampler.set_epoch(0)

        if epoch > args.warmup_epochs:
            cnn_scheduler.step()

        # Logging.
        logging.info('epoch %d', epoch)

        # Training.
        train_nelbo, global_step = train(train_queue, model, cnn_optimizer, grad_scalar,
                                         global_step, warmup_iters, writer, logging)
        logging.info('train_nelbo %f', train_nelbo)
        writer.add_scalar('train/nelbo', train_nelbo, global_step)

        model.eval()
        # Generate samples less frequently.
        eval_freq = 1 if args.epochs <= 50 else 20
        if epoch % eval_freq == 0 or epoch == (args.epochs - 1):
            with torch.no_grad():
                num_samples = 16
                n = int(np.floor(np.sqrt(num_samples)))
                for t in [0.7, 0.8, 0.9, 1.0]:
                    logits = model.sample(num_samples, t)
                    output = model.decoder_output(logits)
                    output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \
                        else output.sample(t)
                    output_tiled = utils.tile_image(output_img, n)
                    writer.add_image('generated_%0.1f' % t, output_tiled, global_step)

            valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=10, args=args, logging=logging)
            logging.info('valid_nelbo %f', valid_nelbo)
            logging.info('valid neg log p %f', valid_neg_log_p)
            logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff)
            logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff)
            writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch)
            writer.add_scalar('val/nelbo', valid_nelbo, epoch)
            writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch)
            writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch)

        save_freq = int(np.ceil(args.epochs / 100))
        if epoch % save_freq == 0 or epoch == (args.epochs - 1):
            if args.global_rank == 0:
                logging.info('saving the model.')
                torch.save({'epoch': epoch + 1,
                            'state_dict': model.state_dict(),
                            'optimizer': cnn_optimizer.state_dict(),
                            'global_step': global_step,
                            'args': args,
                            'arch_instance': arch_instance,
                            'scheduler': cnn_scheduler.state_dict(),
                            'grad_scalar': grad_scalar.state_dict()},
                           checkpoint_file)

    # Final validation.
    valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=1000, args=args, logging=logging)
    logging.info('final valid nelbo %f', valid_nelbo)
    logging.info('final valid neg log p %f', valid_neg_log_p)
    writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch + 1)
    writer.add_scalar('val/nelbo', valid_nelbo, epoch + 1)
    writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, epoch + 1)
    writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch + 1)
    writer.close()
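# --- Assumed command-line entry point (sketch, not part of the original listing) ---
# The parser below only mirrors the `args` attributes referenced in main() above;
# flag defaults, the full option set, and any distributed-launch wrapper are assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('VAE training (sketch)')
    parser.add_argument('--save', type=str, default='/tmp/expr', help='checkpoint/log directory')
    parser.add_argument('--dataset', type=str, default='mnist')
    parser.add_argument('--arch_instance', type=str, default='res_mbconv')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--warmup_epochs', type=int, default=5)
    parser.add_argument('--learning_rate', type=float, default=1e-2)
    parser.add_argument('--learning_rate_min', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=3e-4)
    parser.add_argument('--fast_adamax', action='store_true')
    parser.add_argument('--cont_training', action='store_true')
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--global_rank', type=int, default=0)
    args = parser.parse_args()

    main(args)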
# Model
model = networks.CamelyonClassifier()
model.to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=1e-8)

# Loss function
criterion = utils.loss

# Visdom writer
writer = utils.Writer()


def train():
    model.train()
    losses = []
    for idx in range(1, args.n_iters + 1):
        # Zero gradient
        optimizer.zero_grad()

        # Load data to GPU
        sample = dataset_train[idx]
        images = sample['images'].to(device)
        labels = sample['labels'].to(device)
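        # --- Assumed continuation of the iteration (sketch; the original listing
        # breaks off here). It presumes `utils.loss` takes (predictions, labels)
        # and that CamelyonClassifier returns one prediction per image in the batch. ---
        predictions = model(images)
        loss = criterion(predictions, labels)

        # Backpropagate and update the weights.
        loss.backward()
        optimizer.step()

        # Keep the running loss so it can be reported, e.g. through the Visdom writer.
        losses.append(loss.item())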