def train(gen, device, train_loader, optimizer, epoch, rff_mmd_loss, log_interval, do_gen_labels, uniform_labels):
  for batch_idx, (data, labels) in enumerate(train_loader):
    # print(pt.max(data), pt.min(data))
    data, labels = data.to(device), labels.to(device)
    data = flat_data(data, labels, device, n_labels=10, add_label=False)

    bs = labels.shape[0]
    if not do_gen_labels:
      # unlabeled case: generator models the data distribution only
      loss = rff_mmd_loss(data, gen(gen.get_code(bs, device)))
    elif uniform_labels:
      # labels are sampled by the code generator, data labels are one-hot encoded
      one_hots = pt.zeros(bs, 10, device=device)
      one_hots.scatter_(1, labels[:, None], 1)
      gen_code, gen_labels = gen.get_code(bs, device)
      loss = rff_mmd_loss(data, one_hots, gen(gen_code), gen_labels)
    else:
      # generator outputs labels jointly with the data encoding
      one_hots = pt.zeros(bs, 10, device=device)
      one_hots.scatter_(1, labels[:, None], 1)
      gen_enc, gen_labels = gen(gen.get_code(bs, device))
      loss = rff_mmd_loss(data, one_hots, gen_enc, gen_labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
def test(gen, device, test_loader, rff_mmd_loss, epoch, batch_size, do_gen_labels, uniform_labels, log_dir):
  test_loss = 0
  gen_labels, ordered_labels = None, None
  with pt.no_grad():
    for data, labels in test_loader:
      data, labels = data.to(device), labels.to(device)
      data = flat_data(data, labels, device, n_labels=10, add_label=False)

      bs = labels.shape[0]
      if not do_gen_labels:
        gen_samples = gen(gen.get_code(bs, device))
        gen_labels = None
        loss = rff_mmd_loss(data, gen_samples)
      elif uniform_labels:
        one_hots = pt.zeros(bs, 10, device=device)
        one_hots.scatter_(1, labels[:, None], 1)
        gen_code, gen_labels = gen.get_code(bs, device)
        gen_samples = gen(gen_code)
        loss = rff_mmd_loss(data, one_hots, gen_samples, gen_labels)
      else:
        one_hots = pt.zeros(bs, 10, device=device)
        one_hots.scatter_(1, labels[:, None], 1)
        gen_samples, gen_labels = gen(gen.get_code(bs, device))
        loss = rff_mmd_loss(data, one_hots, gen_samples, gen_labels)
      test_loss += loss.item()  # sum up batch loss
  test_loss /= (len(test_loader.dataset) / batch_size)

  # median heuristic on the last batch of encodings, as a sanity check for the kernel width
  data_enc_batch = data.cpu().numpy()
  med_dist = meddistance(data_enc_batch)
  print(f'med distance for encodings is {med_dist}, heuristic suggests sigma={med_dist ** 2}')

  if uniform_labels:
    # plot a class-ordered batch: ten samples per digit
    ordered_labels = pt.repeat_interleave(pt.arange(10), 10)[:, None].to(device)
    gen_code, gen_labels = gen.get_code(100, device, labels=ordered_labels)
    gen_samples = gen(gen_code).detach()

  plot_samples = gen_samples[:100, ...].cpu().numpy()
  plot_mnist_batch(plot_samples, 10, 10, log_dir + f'samples_ep{epoch}', denorm=False)
  if gen_labels is not None and ordered_labels is None:
    save_gen_labels(gen_labels[:100, ...].cpu().numpy(), 10, 10, log_dir + f'labels_ep{epoch}')
  print('Test set: Average loss: {:.4f}'.format(test_loss))
def train_multi_release(gen, device, train_loader, optimizer, epoch, rff_mmd_loss, log_interval):
  for batch_idx, (data, labels) in enumerate(train_loader):
    data, labels = data.to(device), labels.to(device)
    data = flat_data(data, labels, device, n_labels=10, add_label=False)

    loss = compute_rff_loss(gen, data, labels, rff_mmd_loss, device)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if batch_idx % log_interval == 0:
      n_data = len(train_loader.dataset)
      print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), n_data, loss.item()))
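# Illustrative sketch only: `compute_rff_loss` is defined elsewhere in the repo and is not shown here.
# Assuming it mirrors the labeled branch of `train` above (one-hot encode the true labels, sample a code
# and generator labels, then evaluate the labeled random-feature MMD loss), it might look roughly like:
#
# def compute_rff_loss(gen, data, labels, rff_mmd_loss, device):
#   bs = labels.shape[0]
#   one_hots = pt.zeros(bs, 10, device=device)
#   one_hots.scatter_(1, labels[:, None], 1)
#   gen_code, gen_labels = gen.get_code(bs, device)
#   return rff_mmd_loss(data, one_hots, gen(gen_code), gen_labels)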
def test(gen, device, test_loader, rff_mmd_loss, epoch, batch_size, log_dir):
  test_loss = 0
  with pt.no_grad():
    for data, labels in test_loader:
      data, labels = data.to(device), labels.to(device)
      data = flat_data(data, labels, device, n_labels=10, add_label=False)

      loss = compute_rff_loss(gen, data, labels, rff_mmd_loss, device)
      test_loss += loss.item()  # sum up batch loss
  test_loss /= (len(test_loader.dataset) / batch_size)

  data_enc_batch = data.cpu().numpy()
  med_dist = meddistance(data_enc_batch)
  print(f'med distance for encodings is {med_dist}, heuristic suggests sigma={med_dist ** 2}')

  ordered_labels = pt.repeat_interleave(pt.arange(10), 10)[:, None].to(device)
  gen_code, gen_labels = gen.get_code(100, device, labels=ordered_labels)
  gen_samples = gen(gen_code).detach()

  plot_samples = gen_samples[:100, ...].cpu().numpy()
  plot_mnist_batch(plot_samples, 10, 10, log_dir + f'samples_ep{epoch}', denorm=False)
  print('Test set: Average loss: {:.4f}'.format(test_loss))
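# Hedged usage sketch (the names `n_epochs`, `gen`, `optimizer`, `rff_mmd_loss`, etc. are assumptions and
# come from the calling script, not from this file): the two functions above are typically driven by a
# simple epoch loop, e.g.
#
# for epoch in range(1, n_epochs + 1):
#   train_multi_release(gen, device, train_loader, optimizer, epoch, rff_mmd_loss, log_interval)
#   test(gen, device, test_loader, rff_mmd_loss, epoch, batch_size, log_dir)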
def noisy_dataset_embedding(train_loader, w_freq, d_rff, device, n_labels, noise_factor, mmd_type,
                            sum_frequency=25, pca_vecs=None):
  emb_acc = []
  n_data = 0

  for data, labels in train_loader:
    data, labels = data.to(device), labels.to(device)
    data = flat_data(data, labels, device, n_labels=10, add_label=False)
    data = data if pca_vecs is None else apply_pca(pca_vecs, data)

    emb_acc.append(data_label_embedding(data, labels, w_freq, mmd_type, labels_to_one_hot=True,
                                        n_labels=n_labels, device=device, reduce='sum'))
    # emb_acc.append(pt.sum(pt.einsum('ki,kj->kij', [rff_gauss(data, w_freq), one_hots]), 0))
    n_data += data.shape[0]

    # periodically collapse the accumulated batch sums to keep memory bounded
    if len(emb_acc) > sum_frequency:
      emb_acc = [pt.sum(pt.stack(emb_acc), 0)]

  print('done collecting batches, n_data', n_data)
  emb_acc = pt.sum(pt.stack(emb_acc), 0) / n_data
  print(pt.norm(emb_acc), emb_acc.shape)

  # gaussian noise scaled by the mean embedding's sensitivity (2 / n_data) times the noise factor
  noise = pt.randn(d_rff, n_labels, device=device) * (2 * noise_factor / n_data)
  noisy_emb = emb_acc + noise
  return noisy_emb
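# Hedged sketch of how the released embedding could be used: `noisy_emb` is a single differentially
# private release of the data term, so a generator loss can compare it to the embedding of a generated
# batch. The call below assumes `data_label_embedding` also accepts one-hot labels via
# `labels_to_one_hot=False` and a 'mean' reduction; adjust to the actual signature in this repo.
#
# def single_release_loss(gen_samples, gen_one_hots, noisy_emb, w_freq, mmd_type, n_labels, device):
#   gen_emb = data_label_embedding(gen_samples, gen_one_hots, w_freq, mmd_type, labels_to_one_hot=False,
#                                  n_labels=n_labels, device=device, reduce='mean')
#   return pt.sum((noisy_emb - gen_emb) ** 2)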
def train(enc, dec, device, train_loader, optimizer, epoch, losses, dp_spec, label_ae, conv_ae,
          log_interval, summary_writer, verbose):
  enc.train()
  dec.train()
  for batch_idx, (data, labels) in enumerate(train_loader):
    data = data.to(device)
    labels = labels.to(device)
    if not conv_ae:
      data = flat_data(data, labels, device, add_label=label_ae)

    optimizer.zero_grad()

    data_enc = enc(data)
    reconstruction = dec(data_enc)
    l_enc = bin_ce_loss(reconstruction, data) if losses.do_ce else mse_loss(reconstruction, data)
    if losses.wsiam > 0.:
      l_enc = l_enc + losses.wsiam * siamese_loss(data_enc, labels, losses.msiam)

    if dp_spec.clip is None:
      l_enc.backward()
      squared_param_norms, bp_global_norms, rec_loss, siam_loss = None, None, None, None
    else:
      l_enc.backward(retain_graph=True)  # get grads for encoder

      # wipe grads from decoder:
      for param in dec.parameters():
        param.grad = None

      reconstruction = dec(data_enc.detach())
      rec_loss = bin_ce_loss(reconstruction, data) if losses.do_ce else mse_loss(reconstruction, data)
      if losses.wsiam > 0.:
        siam_loss = losses.wsiam * siamese_loss(data_enc, labels, losses.msiam)
        full_loss = rec_loss + siam_loss
      else:
        siam_loss = None
        full_loss = rec_loss

      l_dec = full_loss
      with backpack(BatchGrad(), BatchL2Grad()):
        l_dec.backward()

      # compute global gradient norm from parameter gradient norms
      squared_param_norms = [p.batch_l2 for p in dec.parameters()]
      bp_global_norms = pt.sqrt(pt.sum(pt.stack(squared_param_norms), dim=0))
      global_clips = pt.clamp_max(dp_spec.clip / bp_global_norms, 1.)

      # aggregate samplewise grads, replace normal grad
      for idx, param in enumerate(dec.parameters()):
        if dp_spec.per_layer:
          # clip each param by C/sqrt(m), then total sensitivity is still C
          if dp_spec.layer_clip:
            local_clips = pt.clamp_max(dp_spec.layer_clip[idx] / pt.sqrt(param.batch_l2), 1.)
          else:
            local_clips = pt.clamp_max(dp_spec.clip / np.sqrt(len(squared_param_norms)) / pt.sqrt(param.batch_l2), 1.)
          clipped_sample_grads = param.grad_batch * expand_vector(local_clips, param.grad_batch)
        else:
          clipped_sample_grads = param.grad_batch * expand_vector(global_clips, param.grad_batch)
        clipped_grad = pt.mean(clipped_sample_grads, dim=0)

        if dp_spec.noise is not None:
          # noise scale is based on the batch size (dim 0 of the per-sample grads), not the param shape
          bs = clipped_sample_grads.shape[0]
          noise_sdev = (2 * dp_spec.noise * dp_spec.clip / bs)
          # gaussian mechanism: perturb the clipped mean gradient with gaussian noise
          clipped_grad = clipped_grad + pt.randn_like(clipped_grad) * noise_sdev

        param.grad = clipped_grad

    optimizer.step()

    if batch_idx % log_interval == 0 and verbose:
      n_data = len(train_loader.dataset)
      n_done = batch_idx * len(data)
      frac_done = 100. * batch_idx / len(train_loader)
      iter_idx = batch_idx + epoch * (n_data / len(data))

      if dp_spec.clip is not None:
        print(f'max_norm:{pt.max(bp_global_norms).item()}, mean_norm:{pt.mean(bp_global_norms).item()}')
        summary_writer.add_histogram('grad_norm_global', bp_global_norms.clone().cpu().numpy(), iter_idx)
        for idx, sq_norm in enumerate(squared_param_norms):
          # print(f'param {idx} shape: {list(dec.parameters())[idx].shape}')
          summary_writer.add_histogram(f'grad_norm_param_{idx}', pt.sqrt(sq_norm).clone().cpu().numpy(), iter_idx)

      if siam_loss is None:
        loss_str = 'Loss: {:.6f}'.format(l_enc.item())
      else:
        loss_str = 'Loss: full {:.6f}, rec {:.6f}, siam {:.6f}'.format(l_enc.item(), rec_loss.item(), siam_loss.item())
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\t{}'.format(epoch, n_done, n_data, frac_done, loss_str))
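# `expand_vector` is used above but defined elsewhere in the repo. A minimal sketch of the assumed
# behavior: reshape a per-sample scaling vector of length B so it broadcasts over per-sample gradients
# of shape (B, ...).
#
# def expand_vector(vec, tgt_tensor):
#   tgt_shape = [vec.shape[0]] + [1] * (tgt_tensor.dim() - 1)
#   return vec.view(*tgt_shape)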
def test(enc, dec, device, test_loader, epoch, losses, label_ae, conv_ae, log_spec, last_epoch, data_is_normed):
  enc.eval()
  dec.eval()

  rec_loss_agg = 0
  siam_loss_agg = 0
  with pt.no_grad():
    for data, labels in test_loader:
      bs = data.shape[0]
      data = data.to(device)
      labels = labels.to(device)
      if not conv_ae:
        data = flat_data(data, labels, device, add_label=label_ae)

      data_enc = enc(data)
      reconstruction = dec(data_enc)
      rec_loss = bin_ce_loss(reconstruction, data) if losses.do_ce else mse_loss(reconstruction, data)
      rec_loss_agg += rec_loss.item() * bs
      if losses.wsiam > 0.:
        siam_loss = losses.wsiam * siamese_loss(data_enc, labels, losses.msiam)
        siam_loss_agg += siam_loss.item() * bs

  n_data = len(test_loader.dataset)
  rec_loss_agg /= n_data
  siam_loss_agg /= n_data
  full_loss = rec_loss_agg + siam_loss_agg

  # plot a class-balanced selection of reconstructions from the last batch
  reconstruction = reconstruction.cpu().numpy()
  labels = labels.cpu().numpy()
  reconstruction, labels = select_balaned_plot_batch(reconstruction, labels, n_classes=10, n_samples_per_class=10)

  if label_ae:
    # split off the reconstructed label part before plotting the image part
    rec_labels = reconstruction[:, 784:]
    save_gen_labels(rec_labels, 10, 10, log_spec.log_dir + f'rec_ep{epoch}_labels', save_raw=False)
    reconstruction = reconstruction[:, :784].reshape(-1, 28, 28)

  plot_mnist_batch(reconstruction, 10, 10, log_spec.log_dir + f'rec_ep{epoch}', denorm=data_is_normed)
  if last_epoch:
    save_dir = log_spec.base_dir + '/overview/'
    if not os.path.exists(save_dir):
      os.makedirs(save_dir)
    save_path = save_dir + log_spec.log_name + f'_rec_ep{epoch}'
    plot_mnist_batch(reconstruction, 10, 10, save_path, denorm=data_is_normed, save_raw=False)

  print('Test ep {}: Average loss: full {:.4f}, rec {:.4f}, siam {:.4f}'.format(
    epoch, full_loss, rec_loss_agg, siam_loss_agg))