# Training script for a CNN VAE on binarized images (modes: autoreg / svi / vae / savae).
import time

import numpy as np
import torch
from torch import cuda
from torch.autograd import Variable

import utils  # project-local helpers: log_bernoulli_loss, kl_loss_diag, kl_loss
# CNNVAE and OptimN2N are project-local classes; import paths assumed.
from models import CNNVAE
from optim_n2n import OptimN2N


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    all_data = torch.load(args.data_file)
    x_train, x_val, x_test = all_data
    # The data are unlabeled; dummy zero targets keep TensorDataset happy.
    y_size = 1
    y_train = torch.zeros(x_train.size(0), y_size)
    y_val = torch.zeros(x_val.size(0), y_size)
    y_test = torch.zeros(x_test.size(0), y_size)
    train = torch.utils.data.TensorDataset(x_train, y_train)
    val = torch.utils.data.TensorDataset(x_val, y_val)
    test = torch.utils.data.TensorDataset(x_test, y_test)
    train_loader = torch.utils.data.DataLoader(train, batch_size=args.batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test, batch_size=args.batch_size, shuffle=True)
    print('Train data: %d batches' % len(train_loader))
    print('Val data: %d batches' % len(val_loader))
    print('Test data: %d batches' % len(test_loader))
    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.model == 'autoreg':
        args.latent_feature_map = 0
    if args.train_from == '':
        model = CNNVAE(img_size=args.img_size,
                       latent_dim=args.latent_dim,
                       enc_layers=args.enc_layers,
                       dec_kernel_size=args.dec_kernel_size,
                       dec_layers=args.dec_layers,
                       latent_feature_map=args.latent_feature_map)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    model.cuda()
    model.train()

    def variational_loss(input, img, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(img, z_samples)
        nll = utils.log_bernoulli_loss(preds, img)
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss, model, update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    epoch = 0
    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    loss_stats = []
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1
    if args.test == 1:
        args.beta = 1
        # eval(...) is the project's evaluation routine (defined elsewhere in this file).
        eval(test_loader, model, meta_optimizer)
        exit()
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        num_examples = 0
        for b, datum in enumerate(train_loader):
            if args.warmup > 0:
                # KL warm-up: anneal beta linearly to 1 over args.warmup epochs.
                args.beta = min(1, args.beta + 1. / (args.warmup * len(train_loader)))
            img, _ = datum
            img = torch.bernoulli(img)  # dynamic binarization
            batch_size = img.size(0)
            img = Variable(img.cuda())
            t += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(img, None)
                nll_autoreg = utils.log_bernoulli_loss(preds, img)
                train_nll_autoreg += nll_autoreg.item() * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                    requires_grad=True)
                logvar_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                      requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img,
                                                        t % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(img, z_samples)
                nll_svi = utils.log_bernoulli_loss(preds, img)
                train_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward()
            else:
                mean, logvar = model._enc_forward(img)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(img, z_samples)
                nll_vae = utils.log_bernoulli_loss(preds, img)
                train_nll_vae += nll_vae.item() * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], img,
                                                            t % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model._dec_forward(img, z_samples)
                    nll_svi = utils.log_bernoulli_loss(preds, img)
                    train_nll_svi += nll_svi.item() * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.item() * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(mean, logvar,
                                                          mean_final, logvar_final)
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar],
                                                                  retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads, retain_graph=True)
                    else:
                        # Backprop through the SVI refinement steps.
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            t % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_examples += batch_size
            if t % args.print_every == 0:
                param_norm = sum([p.norm() ** 2 for p in model.parameters()]).item() ** 0.5
                print('Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, '
                      'TrainARNLL: %.2f, TrainVAE_NLL: %.2f, TrainVAE_KL: %.4f, '
                      'TrainVAE_NLLBnd: %.2f, TrainSVI_NLL: %.2f, TrainSVI_KL: %.4f, '
                      'TrainSVI_NLLBnd: %.2f, |Param|: %.4f, BestValPerf: %.2f, '
                      'BestEpoch: %d, Beta: %.3f, Throughput: %.2f examples/sec' %
                      (t, epoch, b + 1, len(train_loader), args.lr,
                       train_nll_autoreg / num_examples,
                       train_nll_vae / num_examples,
                       train_kl_vae / num_examples,
                       (train_nll_vae + train_kl_vae) / num_examples,
                       train_nll_svi / num_examples,
                       train_kl_svi / num_examples,
                       (train_nll_svi + train_kl_svi) / num_examples,
                       param_norm, best_val_nll, best_epoch, args.beta,
                       num_examples / (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(val_loader, model, meta_optimizer)
        loss_stats.append(val_nll)
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'optimizer': optimizer,
                'loss_stats': loss_stats
            }
            print('Saving checkpoint to %s' % args.checkpoint_path)
            torch.save(checkpoint, args.checkpoint_path)
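# The script above relies on utils.log_bernoulli_loss and utils.kl_loss_diag,
# whose definitions are not shown. A minimal sketch of what they are assumed
# to compute (per-example sums, averaged over the batch); the reduction
# conventions are assumptions, not the project's actual implementation:
import torch


def log_bernoulli_loss(preds, target, eps=1e-7):
    """Negative Bernoulli log-likelihood: sum over pixels, mean over batch."""
    preds = preds.clamp(eps, 1.0 - eps)
    nll = -(target * preds.log() + (1.0 - target) * (1.0 - preds).log())
    return nll.view(nll.size(0), -1).sum(1).mean()


def kl_loss_diag(mean, logvar):
    """KL(q || N(0, I)) for a diagonal Gaussian q: sum over latents, mean over batch."""
    kl = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
    return kl.sum(1).mean()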
# Training script for an RNN VAE on text with a cyclical (cosine) learning-rate
# schedule and cyclical KL annealing.
import os
import time

import numpy as np
import torch
from torch import cuda, nn
from torch.autograd import Variable

import utils
# Dataset, RNNVAE, OptimN2N, and logger are project-local; names assumed.


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    test_data = Dataset(args.test_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    logger.info('Train data: %d batches' % len(train_data))
    logger.info('Val data: %d batches' % len(val_data))
    logger.info('Test data: %d batches' % len(test_data))
    logger.info('Word vocab size: %d' % vocab_size)
    checkpoint_dir = args.checkpoint_dir
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    suffix = "%s_%s.pt" % (args.model, 'cyc')
    checkpoint_path = os.path.join(checkpoint_dir, suffix)
    if args.slurm == 0:
        cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        logger.info('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    logger.info("model architecture")
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = 0.1
    criterion = nn.NLLLoss()
    model.cuda()
    criterion.cuda()
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([criterion(preds[:, l], sents[:, l + 1])
                   for l in range(preds.size(1))])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss, model, update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(test_data, model, meta_optimizer)
        exit()
    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        logger.info('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0
        # Cosine learning-rate decay within each cycle of args.cycle epochs.
        tmp = float((epoch - 1) % args.cycle) / args.cycle
        cur_lr = args.lr * 0.5 * (1 + np.cos(tmp * np.pi))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr
        if (epoch - 1) % args.cycle == 0:
            args.beta = 0.1
            logger.info('KL annealing restart')
        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(1, args.beta + 1. / (args.warmup * len(train_data)))
            sents, length, batch_size = train_data[i]
            if args.gpu >= 0:
                sents = sents.cuda()
            b += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                nll_autoreg = sum([criterion(preds[:, l], sents[:, l + 1])
                                   for l in range(length)])
                train_nll_autoreg += nll_autoreg.item() * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                    requires_grad=True)
                logvar_svi = Variable(0.1 * torch.zeros(batch_size, args.latent_dim).cuda(),
                                      requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                        b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                nll_svi = sum([criterion(preds[:, l], sents[:, l + 1])
                               for l in range(length)])
                train_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                nll_vae = sum([criterion(preds[:, l], sents[:, l + 1])
                               for l in range(length)])
                train_nll_vae += nll_vae.item() * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                            b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    nll_svi = sum([criterion(preds[:, l], sents[:, l + 1])
                                   for l in range(length)])
                    train_nll_svi += nll_svi.item() * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.item() * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(mean, logvar,
                                                          mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.item() * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar],
                                                                  retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads, retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            if b % args.print_every == 0:
                param_norm = sum([p.norm() ** 2 for p in model.parameters()]).item() ** 0.5
                logger.info(
                    'Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, '
                    'TrainARNLL: %.4f, TrainARPPL: %.2f, TrainVAE_NLL: %.4f, '
                    'TrainVAE_REC: %.4f, TrainVAE_KL: %.4f, TrainVAE_PPL: %.2f, '
                    'TrainSVI_NLL: %.2f, TrainSVI_REC: %.2f, TrainSVI_KL: %.4f, '
                    'TrainSVI_PPL: %.2f, KLInitFinal: %.2f, |Param|: %.4f, '
                    'BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, '
                    'Throughput: %.2f examples/sec' %
                    (t, epoch, b + 1, len(train_data), cur_lr,
                     train_nll_autoreg / num_sents,
                     np.exp(train_nll_autoreg / num_words),
                     (train_nll_vae + train_kl_vae) / num_sents,
                     train_nll_vae / num_sents,
                     train_kl_vae / num_sents,
                     np.exp((train_nll_vae + train_kl_vae) / num_words),
                     (train_nll_svi + train_kl_svi) / num_sents,
                     train_nll_svi / num_sents,
                     train_kl_svi / num_sents,
                     np.exp((train_nll_svi + train_kl_svi) / num_words),
                     train_kl_init_final / num_sents,
                     param_norm, best_val_nll, best_epoch, args.beta,
                     num_sents / (time.time() - start_time)))
        epoch_train_time = time.time() - start_time
        logger.info('Time Elapsed: %.1fs' % epoch_train_time)
        logger.info('--------------------------------')
        logger.info('Checking validation perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Val')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        val_nll = eval(val_data, model, meta_optimizer)
        val_stats.append(val_nll)
        logger.info('--------------------------------')
        logger.info('Checking test perf...')
        logger.record_tabular('Epoch', epoch)
        logger.record_tabular('Mode', 'Test')
        logger.record_tabular('LR', cur_lr)
        logger.record_tabular('Epoch Train Time', epoch_train_time)
        test_nll = eval(test_data, model, meta_optimizer)
        if val_nll < best_val_nll:
            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            logger.info('Saving checkpoint to %s' % checkpoint_path)
            torch.save(checkpoint, checkpoint_path)
            model.cuda()
        else:
            if epoch >= args.min_epochs:
                args.decay = 1
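# The cyclical schedule used above, factored out for clarity. This restates
# the script's own formulas (cosine-decayed LR within each cycle; beta reset
# to 0.1 at the start of each cycle); the helper name is illustrative only.
import numpy as np


def cyclic_schedule(epoch, cycle, base_lr):
    """Return (lr, restart) for a 1-indexed `epoch` under a cosine cycle."""
    tmp = float((epoch - 1) % cycle) / cycle
    lr = base_lr * 0.5 * (1 + np.cos(tmp * np.pi))
    restart = (epoch - 1) % cycle == 0  # beta is reset to 0.1 on restarts
    return lr, restart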
# Domain-adaptation training (DANN transfer loss + mixup regularizer) on
# ResNet-50 features.
import copy
import os
import os.path as osp

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import trange

# network, preprocess, lr_schedule, loss, utils, ImageList, mixup, and
# image_classification_test are project-local modules/helpers; names assumed
# from the calls below.


def train(config):
    base_network = network.ResNetFc('ResNet50',
                                    use_bottleneck=True,
                                    bottleneck_dim=config["bottleneck_dim"],
                                    new_cls=True,
                                    class_num=config["class_num"])
    ad_net = network.AdversarialNetwork(config["bottleneck_dim"], config["hidden_dim"])
    base_network = base_network.cuda()
    ad_net = ad_net.cuda()
    parameter_list = base_network.get_parameters() + ad_net.get_parameters()
    source_path = ImageList(open(config["s_path"]).readlines(),
                            transform=preprocess.image_train(resize_size=256, crop_size=224))
    target_path = ImageList(open(config["t_path"]).readlines(),
                            transform=preprocess.image_train(resize_size=256, crop_size=224))
    test_path = ImageList(open(config["t_path"]).readlines(),
                          transform=preprocess.image_test(resize_size=256, crop_size=224))
    source_loader = DataLoader(source_path, batch_size=config["train_bs"],
                               shuffle=True, num_workers=0, drop_last=True)
    target_loader = DataLoader(target_path, batch_size=config["train_bs"],
                               shuffle=True, num_workers=0, drop_last=True)
    test_loader = DataLoader(test_path, batch_size=config["test_bs"],
                             shuffle=True, num_workers=0, drop_last=True)
    optimizer_config = config["optimizer"]
    optimizer = optimizer_config["type"](parameter_list,
                                         **(optimizer_config["optim_params"]))
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]
    gpus = config["gpus"].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net, device_ids=[int(i) for i in gpus])
        base_network = nn.DataParallel(base_network, device_ids=[int(i) for i in gpus])
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    best_acc = 0.0
    best_model = None
    best_model_path = None
    for i in trange(config["iterations"], leave=False):
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(test_loader, base_network)
            temp_model = nn.Sequential(base_network)
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(temp_model)
                best_iter = i
                # Keep only the latest best snapshot on disk.
                if best_model_path and osp.exists(best_model_path):
                    try:
                        os.remove(best_model_path)
                    except OSError:
                        pass
                best_model_path = osp.join(config["output_path"],
                                           "iter_{:05d}.pth.tar".format(best_iter))
                torch.save(best_model, best_model_path)
            log_str = "iter: {:05d}, precision: {:.5f}".format(i, temp_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            # print("cut_loss: ", cut_loss.item())
            print("mix_loss: ", mix_loss.item())  # mix_loss from the previous training step
            print(log_str)
        base_network.train(True)
        ad_net.train(True)
        optimizer = lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        if i % len_train_source == 0:
            iter_source = iter(source_loader)
        if i % len_train_target == 0:
            iter_target = iter(target_loader)
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = (inputs_source.cuda(),
                                                       inputs_target.cuda(),
                                                       labels_source.cuda())
        labels_src_one_hot = torch.nn.functional.one_hot(labels_source,
                                                         config["class_num"]).float()
        # inputs_cut, labels_cut = cutmix(base_network, inputs_source, labels_src_one_hot,
        #                                 inputs_target, config["alpha"], config["class_num"])
        inputs_mix, labels_mix = mixup(base_network, inputs_source, labels_src_one_hot,
                                       inputs_target, config["alpha"],
                                       config["class_num"], config["temperature"])
        features_source, outputs_source = base_network(inputs_source)
        features_target, outputs_target = base_network(inputs_target)
        # features_cut, outputs_cut = base_network(inputs_cut)
        features_mix, outputs_mix = base_network(inputs_mix)
        features = torch.cat((features_source, features_target), dim=0)
        outputs = torch.cat((outputs_source, outputs_target), dim=0)
        softmax_out = nn.Softmax(dim=1)(outputs)
        if config["method"] == 'DANN':
            transfer_loss = loss.DANN(features, ad_net)
            # cut_loss = utils.kl_loss(outputs_cut, labels_cut.detach())
            mix_loss = utils.kl_loss(outputs_mix, labels_mix.detach())
        else:
            raise ValueError('Method cannot be recognized.')
        classifier_loss = nn.CrossEntropyLoss()(outputs_source, labels_source)
        total_loss = transfer_loss + classifier_loss + (5 * mix_loss)
        total_loss.backward()
        optimizer.step()
    if best_model is not None:
        torch.save(best_model, osp.join(config["output_path"], "best_model.pth.tar"))
    print("Training Finished! Best Accuracy: ", best_acc)
    return best_acc
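# `mixup` above is project code that is not shown. A plausible sketch under
# stated assumptions: lambda ~ Beta(alpha, alpha); target images receive
# temperature-softened pseudo-labels from the current network; inputs and
# labels are mixed convexly. This illustrates the idea only and is NOT the
# repository's actual implementation (class_num is kept just to match the
# call signature).
import torch
import torch.nn.functional as F


def mixup(net, x_src, y_src_one_hot, x_tgt, alpha, class_num, temperature):
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    with torch.no_grad():
        _, logits_tgt = net(x_tgt)  # base_network returns (features, logits)
        y_tgt = F.softmax(logits_tgt / temperature, dim=1)  # soft pseudo-labels
    x_mix = lam * x_src + (1 - lam) * x_tgt
    y_mix = lam * y_src_one_hot + (1 - lam) * y_tgt
    return x_mix, y_mix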
# Standalone training loop for a jet VAE; `net`, `dataloader`, `beta`,
# `kl_loss`, and `target_loss` are assumed to be defined earlier in the script.
import numpy as np
import torch
import torch.optim as optim

SAVE_PATH = "trained_models/VAE_jet_L12_BE2.dat"
N_EPOCHS = 10000

optimizer = optim.Adam(net.parameters())
rms_loss = []
kldiv_loss = []
for epoch in range(N_EPOCHS):
    epoch_rms_loss = []
    epoch_kldiv_loss = []
    for minibatch in dataloader:
        inputs, outputs = minibatch
        optimizer.zero_grad()
        pred = net(inputs)
        kl = beta * kl_loss(net.mu, net.log_sigma)
        rms = target_loss(pred, outputs)
        loss = rms + kl
        loss.backward()
        optimizer.step()
        epoch_rms_loss.append(rms.detach().mean().item())
        epoch_kldiv_loss.append(kl.detach().mean().item())
    kldiv_loss.append(np.mean(epoch_kldiv_loss))
    rms_loss.append(np.mean(epoch_rms_loss))
    print("Epoch %d -- rms error %f -- kl loss %f" %
          (epoch + 1, rms_loss[-1], kldiv_loss[-1]))
torch.save(net.state_dict(), SAVE_PATH)
print("Model saved to %s" % SAVE_PATH)
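# Reloading the saved weights later follows the standard PyTorch state_dict
# pattern; this assumes the same VAE architecture is instantiated as `net`
# beforehand:
#
#     net.load_state_dict(torch.load(SAVE_PATH))
#     net.eval()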
# CIFAR-10/100 training for a Bayesian classifier with a KL regularizer
# weighted by `label_smoothing`.
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
from tqdm import tqdm

# BayesianNet, xavier_normal_init, kl_loss, and accuracy are project-local helpers.


def train(batch_size, epochs, model, dataset, valid_size=5000,
          label_smoothing=0.0, gpu='cuda:0'):
    device = torch.device(gpu if torch.cuda.is_available() else 'cpu')
    assert model in ['resnet18', 'resnet101', 'densenet121', 'densenet169']
    assert dataset in ['cifar10', 'cifar100']
    print("batch_size =", batch_size)
    print("epochs =", epochs)
    print("model =", model)
    print("data set =", dataset)
    print("label_smoothing =", label_smoothing)
    if dataset == 'cifar100':
        num_classes = 100
        mean = [0.5071, 0.4867, 0.4408]
        std = [0.2675, 0.2565, 0.2761]
        train_set = datasets.CIFAR100(
            '../data', train=True, download=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.RandomErasing(p=0.5),
                transforms.Normalize(mean=mean, std=std)
            ]))
        valid_set = datasets.CIFAR100(
            '../data', train=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ]))
        train_indices = torch.load('./train_indices_cifar100.pth')
        valid_indices = torch.load('./valid_indices_cifar100.pth')
    elif dataset == 'cifar10':
        num_classes = 10
        mean = [0.4914, 0.48216, 0.44653]
        std = [0.2470, 0.2435, 0.26159]
        train_set = datasets.CIFAR10(
            '../data', train=True, download=True,
            transform=transforms.Compose([
                transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
                transforms.RandomHorizontalFlip(),
                # transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.RandomErasing(p=0.5),
                transforms.Normalize(mean=mean, std=std)
            ]))
        valid_set = datasets.CIFAR10(
            '../data', train=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)
            ]))
        train_indices = torch.load('./train_indices_cifar10.pth')
        valid_indices = torch.load('./valid_indices_cifar10.pth')
    # To regenerate the train/valid split:
    # indices = torch.randperm(len(train_set))
    # train_indices = indices[:len(indices) - valid_size]
    # valid_indices = indices[len(indices) - valid_size:]
    # torch.save(train_indices, './train_indices_' + dataset + '.pth')
    # torch.save(valid_indices, './valid_indices_' + dataset + '.pth')
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
    valid_loader = torch.utils.data.DataLoader(
        valid_set, batch_size=batch_size, sampler=SubsetRandomSampler(valid_indices))
    net = BayesianNet(num_classes=num_classes, model=model).to(device)
    net.apply(xavier_normal_init)
    # optimizer_net = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-6)
    optimizer_net = optim.AdamW(net.parameters(), lr=0.01)
    lr_scheduler_net = optim.lr_scheduler.ReduceLROnPlateau(optimizer_net,
                                                            patience=10, factor=0.1)
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    for e in range(epochs):
        net.train()
        epoch_train_loss = []
        epoch_train_acc = []
        is_best = False
        print("lr =", optimizer_net.param_groups[0]['lr'])
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer_net.zero_grad()
            logits = net(data)
            xent = F.cross_entropy(logits, target)
            kll = kl_loss(logits)
            loss = xent + label_smoothing * kll
            loss.backward()
            epoch_train_loss.append(loss.item())
            epoch_train_acc.append(accuracy(logits, target))
            optimizer_net.step()
        epoch_train_loss = np.mean(epoch_train_loss)
        epoch_train_acc = np.mean(epoch_train_acc)
        lr_scheduler_net.step(epoch_train_loss)
        net.eval()
        epoch_valid_loss = []
        epoch_valid_acc = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(tqdm(valid_loader)):
                data, target = data.to(device), target.to(device)
                logits = net(data)
                loss = F.cross_entropy(logits, target)
                epoch_valid_loss.append(loss.item())
                epoch_valid_acc.append(accuracy(logits, target))
        epoch_valid_loss = np.mean(epoch_valid_loss)
        epoch_valid_acc = np.mean(epoch_valid_acc)
        print("Epoch {:d}: loss: {:4f}, acc: {:4f}, val_loss: {:4f}, val_acc: {:4f}"
              .format(e, epoch_train_loss, epoch_train_acc,
                      epoch_valid_loss, epoch_valid_acc))
        # save epoch losses
        train_losses.append(epoch_train_loss)
        train_accuracies.append(epoch_train_acc)
        valid_losses.append(epoch_valid_loss)
        valid_accuracies.append(epoch_valid_acc)
        if valid_losses[-1] <= np.min(valid_losses):
            is_best = True
        if is_best:
            filename = f"../snapshots/{model}_best.pth.tar"
            print("Saving best weights so far with val_loss: {:4f}".format(valid_losses[-1]))
            torch.save({
                'epoch': e,
                'state_dict': net.state_dict(),
                'optimizer': optimizer_net.state_dict(),
                'train_losses': train_losses,
                'train_accs': train_accuracies,
                'val_losses': valid_losses,
                'val_accs': valid_accuracies,
            }, filename)
        if e == epochs - 1:
            filename = f"../snapshots/{model}_{e}.pth.tar"
            print("Saving weights at epoch {:d}".format(e))
            torch.save({
                'epoch': e,
                'state_dict': net.state_dict(),
                'optimizer': optimizer_net.state_dict(),
                'train_losses': train_losses,
                'train_accs': train_accuracies,
                'val_losses': valid_losses,
                'val_accs': valid_accuracies,
            }, filename)
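# `accuracy` above is a project helper that is not shown; a minimal sketch of
# the conventional definition the loop assumes (fraction of argmax predictions
# matching the targets):
def accuracy(logits, target):
    return (logits.argmax(dim=1) == target).float().mean().item()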
# Device-based variant of the RNN VAE training script (SGD with LR decay on
# validation plateaus).
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

import utils
# Dataset, RNNVAE, and OptimN2N are project-local; names assumed.


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    train_sents = train_data.batch_size.sum()
    vocab_size = int(train_data.vocab_size)
    print('Train data: %d batches' % len(train_data))
    print('Val data: %d batches' % len(val_data))
    print('Word vocab size: %d' % vocab_size)
    if args.slurm == 0:
        # cuda.set_device(args.gpu)
        gpu_id = 0
        device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
    if args.train_from == '':
        model = RNNVAE(vocab_size=vocab_size,
                       enc_word_dim=args.enc_word_dim,
                       enc_h_dim=args.enc_h_dim,
                       enc_num_layers=args.enc_num_layers,
                       dec_word_dim=args.dec_word_dim,
                       dec_h_dim=args.dec_h_dim,
                       dec_num_layers=args.dec_num_layers,
                       dec_dropout=args.dec_dropout,
                       latent_dim=args.latent_dim,
                       mode=args.model)
        for param in model.parameters():
            param.data.uniform_(-0.1, 0.1)
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
    print("model architecture")
    print(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    if args.warmup == 0:
        args.beta = 1.
    else:
        args.beta = args.kl_start
    criterion = nn.NLLLoss(reduction='none')  # per-token losses; reduced manually below
    model.to(device)
    criterion.to(device)
    model.train()

    def variational_loss(input, sents, model, z=None):
        mean, logvar = input
        z_samples = model._reparameterize(mean, logvar, z)
        preds = model._dec_forward(sents, z_samples)
        nll = sum([criterion(preds[:, l], sents[:, l + 1])
                   for l in range(preds.size(1))])
        kl = utils.kl_loss_diag(mean, logvar)
        return nll + args.beta * kl

    update_params = list(model.dec.parameters())
    meta_optimizer = OptimN2N(variational_loss, model, update_params,
                              eps=args.eps,
                              lr=[args.svi_lr1, args.svi_lr2],
                              iters=args.svi_steps,
                              momentum=args.momentum,
                              acc_param_grads=args.train_n2n == 1,
                              max_grad_norm=args.svi_max_grad_norm)
    if args.test == 1:
        args.beta = 1
        test_data = Dataset(args.test_file)
        eval(args, test_data, model, meta_optimizer, device)
        exit()
    t = 0
    best_val_nll = 1e5
    best_epoch = 0
    not_improved = 0
    val_stats = []
    epoch = 0
    while epoch < args.num_epochs:
        start_time = time.time()
        epoch += 1
        print('Starting epoch %d' % epoch)
        train_nll_vae = 0.
        train_nll_autoreg = 0.
        train_kl_vae = 0.
        train_nll_svi = 0.
        train_kl_svi = 0.
        train_kl_init_final = 0.
        num_sents = 0
        num_words = 0
        b = 0
        for i in np.random.permutation(len(train_data)):
            if args.warmup > 0:
                args.beta = min(1, args.beta + 1. / (args.warmup * len(train_data)))
            sents, length, batch_size = train_data[i]
            length = length.item()
            batch_size = batch_size.item()
            if args.gpu >= 0:
                sents = sents.to(device)
            b += 1
            optimizer.zero_grad()
            if args.model == 'autoreg':
                preds = model._dec_forward(sents, None, True)
                tgt = sents[:, 1:].contiguous()
                # Per-sentence NLL: sum over time steps, mean over the batch.
                nll_autoreg = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                train_nll_autoreg += nll_autoreg.item() * batch_size
                nll_autoreg.backward()
            elif args.model == 'svi':
                mean_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                    requires_grad=True)
                logvar_svi = Variable(
                    0.1 * torch.zeros(batch_size, args.latent_dim).to(device),
                    requires_grad=True)
                var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                        b % args.print_every == 0)
                mean_svi_final, logvar_svi_final = var_params_svi
                z_samples = model._reparameterize(mean_svi_final.detach(),
                                                  logvar_svi_final.detach())
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_svi = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                train_nll_svi += nll_svi.item() * batch_size
                kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                train_kl_svi += kl_svi.item() * batch_size
                var_loss = nll_svi + args.beta * kl_svi
                var_loss.backward(retain_graph=True)
            else:
                mean, logvar = model._enc_forward(sents)
                z_samples = model._reparameterize(mean, logvar)
                preds = model._dec_forward(sents, z_samples)
                tgt = sents[:, 1:].contiguous()
                nll_vae = criterion(preds.view(-1, preds.size(2)),
                                    tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                train_nll_vae += nll_vae.item() * batch_size
                kl_vae = utils.kl_loss_diag(mean, logvar)
                train_kl_vae += kl_vae.item() * batch_size
                if args.model == 'vae':
                    vae_loss = nll_vae + args.beta * kl_vae
                    vae_loss.backward(retain_graph=True)
                if args.model == 'savae':
                    var_params = torch.cat([mean, logvar], 1)
                    mean_svi = Variable(mean.data, requires_grad=True)
                    logvar_svi = Variable(logvar.data, requires_grad=True)
                    var_params_svi = meta_optimizer.forward([mean_svi, logvar_svi], sents,
                                                            b % args.print_every == 0)
                    mean_svi_final, logvar_svi_final = var_params_svi
                    z_samples = model._reparameterize(mean_svi_final, logvar_svi_final)
                    preds = model._dec_forward(sents, z_samples)
                    tgt = sents[:, 1:].contiguous()
                    nll_svi = criterion(preds.view(-1, preds.size(2)),
                                        tgt.view(-1)).view(preds.size(0), -1).sum(-1).mean(0)
                    train_nll_svi += nll_svi.item() * batch_size
                    kl_svi = utils.kl_loss_diag(mean_svi_final, logvar_svi_final)
                    train_kl_svi += kl_svi.item() * batch_size
                    var_loss = nll_svi + args.beta * kl_svi
                    var_loss.backward(retain_graph=True)
                    if args.train_n2n == 0:
                        if args.train_kl == 1:
                            mean_final = mean_svi_final.detach()
                            logvar_final = logvar_svi_final.detach()
                            kl_init_final = utils.kl_loss(mean, logvar,
                                                          mean_final, logvar_final)
                            train_kl_init_final += kl_init_final.item() * batch_size
                            kl_init_final.backward(retain_graph=True)
                        else:
                            vae_loss = nll_vae + args.beta * kl_vae
                            var_param_grads = torch.autograd.grad(vae_loss, [mean, logvar],
                                                                  retain_graph=True)
                            var_param_grads = torch.cat(var_param_grads, 1)
                            var_params.backward(var_param_grads, retain_graph=True)
                    else:
                        var_param_grads = meta_optimizer.backward(
                            [mean_svi_final.grad, logvar_svi_final.grad],
                            b % args.print_every == 0)
                        var_param_grads = torch.cat(var_param_grads, 1)
                        var_params.backward(var_param_grads)
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            if b % args.print_every == 0:
                param_norm = sum([p.norm() ** 2 for p in model.parameters()]).item() ** 0.5
                print('Iters: %d, Epoch: %d, Batch: %d/%d, LR: %.4f, '
                      'TrainARPPL: %.2f, TrainVAE_PPL: %.2f, TrainVAE_KL: %.4f, '
                      'TrainVAE_PPLBnd: %.2f, TrainSVI_PPL: %.2f, TrainSVI_KL: %.4f, '
                      'TrainSVI_PPLBnd: %.2f, KLInitFinal: %.2f, |Param|: %.4f, '
                      'BestValPerf: %.2f, BestEpoch: %d, Beta: %.4f, '
                      'Throughput: %.2f examples/sec' %
                      (t, epoch, b + 1, len(train_data), args.lr,
                       np.exp(train_nll_autoreg / num_words),
                       np.exp(train_nll_vae / num_words),
                       train_kl_vae / num_sents,
                       np.exp((train_nll_vae + train_kl_vae) / num_words),
                       np.exp(train_nll_svi / num_words),
                       train_kl_svi / num_sents,
                       np.exp((train_nll_svi + train_kl_svi) / num_words),
                       train_kl_init_final / num_sents,
                       param_norm, best_val_nll, best_epoch, args.beta,
                       num_sents / (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')
        val_nll = eval(args, val_data, model, meta_optimizer, device)
        val_stats.append(val_nll)
        if val_nll < best_val_nll:
            not_improved = 0
            # Remove the previous best checkpoint before writing the new one.
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            if os.path.exists(best_save):
                os.remove(best_save)
            best_val_nll = val_nll
            best_epoch = epoch
            model.cpu()
            checkpoint = {
                'args': args.__dict__,
                'model': model,
                'val_stats': val_stats
            }
            print('Saving checkpoint to %s' % args.checkpoint_path)
            best_save = '{}_{}.pt'.format(args.checkpoint_path, best_val_nll)
            torch.save(checkpoint, best_save)
            model.to(device)
        else:
            # Decay the LR after every 5 consecutive epochs without improvement.
            not_improved += 1
            if not_improved % 5 == 0:
                not_improved = 0
                args.lr = args.lr * args.lr_decay
                print(f'New LR: {args.lr}')
                for param_group in optimizer.param_groups:
                    param_group['lr'] = args.lr
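# utils.kl_loss(mean1, logvar1, mean2, logvar2), used above to pull the
# amortized posterior toward the SVI-refined one, is assumed to be the
# analytic KL between two diagonal Gaussians, KL(q1 || q2); a minimal sketch:
def kl_loss(mean1, logvar1, mean2, logvar2):
    kl = 0.5 * (logvar2 - logvar1
                + (logvar1.exp() + (mean1 - mean2).pow(2)) / logvar2.exp() - 1)
    return kl.sum(dim=1).mean()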