import time
from argparse import Namespace
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm

# Repo-local helpers (`mdn_nll`, `gradnan_filter`, `sample_from_normal`,
# `GaussianMixture`, the `tu` training utilities, and the global `opt`) are
# assumed to be provided elsewhere in this codebase.


def process_sample(X, model_x2y, model_y2x, scm, encoder, decoder):
    # Sample the effect from the structural causal model
    Y = scm(X)
    if opt.CUDA:
        X, Y = X.cuda(), Y.cuda()
    # Decode
    with torch.no_grad():
        X, Y = decoder(X, Y)
    # Encode
    X, Y = encoder(X, Y)
    # Evaluate the two conditional NLLs
    loss_x2y = mdn_nll(model_x2y(X), Y)
    loss_y2x = mdn_nll(model_y2x(Y), X)
    if torch.isnan(loss_x2y).item() or torch.isnan(loss_y2x).item():
        raise ValueError('NaN encountered in conditional NLL.')
    return loss_x2y, loss_y2x
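
# `mdn_nll` is used throughout but defined elsewhere in the repo. A minimal
# stand-in is sketched below, ASSUMING the mixture-density heads return a
# (log_pi, mu, sigma) tuple with shapes (N, K), (N, K, D), (N, K, D); the
# repo's actual output format may differ.
import torch.distributions as dist


def mdn_nll_sketch(params, target):
    """Mean NLL of `target` under a mixture of diagonal Gaussians (sketch)."""
    log_pi, mu, sigma = params
    # K D-dimensional Gaussians per sample, mixed by Categorical(log_pi)
    components = dist.Independent(dist.Normal(mu, sigma), 1)
    mixture = dist.MixtureSameFamily(dist.Categorical(logits=log_pi), components)
    return -mixture.log_prob(target).mean()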
def transfer_finetune(args, model_x2y, model_y2x, inputs, targets):
    optim_x2y = optim.Adam(model_x2y.parameters(), lr=args.finetune_lr)
    optim_y2x = optim.Adam(model_y2x.parameters(), lr=args.finetune_lr)
    # The marginal NLLs are constants w.r.t. fine-tuning (computed once, no grad)
    loss_marg_x2y = marginal_nll(args, inputs) if args.train_gmm else 0.
    loss_marg_y2x = marginal_nll(args, targets) if args.train_gmm else 0.
    loss_x2y, loss_y2x = [], []
    is_nan = False
    for _ in range(args.finetune_n_iters):
        loss_cond_x2y = mdn_nll(model_x2y(inputs), targets)
        loss_cond_y2x = mdn_nll(model_y2x(targets), inputs)
        if torch.isnan(loss_cond_x2y).item() or torch.isnan(loss_cond_y2x).item():
            is_nan = True
            break
        optim_x2y.zero_grad()
        optim_y2x.zero_grad()
        loss_cond_x2y.backward(retain_graph=True)
        loss_cond_y2x.backward(retain_graph=True)
        # Zero out NaN gradients; abort the inner loop if any were found
        nan_in_x2y = gradnan_filter(model_x2y)
        nan_in_y2x = gradnan_filter(model_y2x)
        if nan_in_x2y or nan_in_y2x:
            is_nan = True
            break
        optim_x2y.step()
        optim_y2x.step()
        loss_x2y.append(loss_cond_x2y + loss_marg_x2y)
        loss_y2x.append(loss_cond_y2x + loss_marg_y2x)
    return loss_x2y, loss_y2x, is_nan
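
# `gradnan_filter` is a repo helper that is not shown here. Based on how it
# is used above (flag NaN gradients before the optimizer step), a plausible
# stand-in could look like this -- an assumption, not the repo's code.
def gradnan_filter_sketch(model):
    """Zero out NaN gradients in-place; return True if any were found."""
    found_nan = False
    for p in model.parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            p.grad.data.zero_()
            found_nan = True
    return found_nan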
def train_mle_nll(args, model, scm, encoder=None, decoder=None,
                  polarity='X2Y'):
    model = model.cuda()
    if encoder is not None:
        encoder = encoder.cuda()
    if decoder is not None:
        decoder = decoder.cuda()
    optimizer_mle = optim.Adam(model.parameters(), lr=args.mle_lr)
    losses = []
    for iter_num in range(1, args.mle_n_iters + 1):
        # Sample the cause from the training distribution and push it
        # through the SCM to get the effect
        x = sample_from_normal(0, 2, args.mle_nsamples, args.n_features)
        with torch.no_grad():
            y = scm(x)
        x, y = x.cuda(), y.cuda()
        if decoder is not None:
            x, y = decoder(x, y)
        if encoder is not None:
            x, y = encoder(x, y)
        if polarity == 'X2Y':
            inputs, targets = x, y
        elif polarity == 'Y2X':
            inputs, targets = y, x
        else:
            raise ValueError('%s does not match any known polarity.' % polarity)
        loss_conditional = mdn_nll(model(inputs), targets)
        optimizer_mle.zero_grad()
        loss_conditional.backward()
        optimizer_mle.step()
        losses.append(loss_conditional.item())
    return losses
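
# `sample_from_normal` is assumed to draw an (n_samples, n_features) batch
# from N(mean, std**2); whether the second argument is a standard deviation
# or a variance is a guess from the call sites.
def sample_from_normal_sketch(mean, std, n_samples, n_features):
    return torch.randn(n_samples, n_features) * std + mean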
def encoder_train_shared_regret(opt, model_x2y, model_y2x, scm, encoder,
                                decoder, alpha):
    if opt.CUDA:
        model_x2y = model_x2y.cuda()
        model_y2x = model_y2x.cuda()
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    encoder_optim = torch.optim.Adam(encoder.parameters(), opt.ENCODER_LR)
    alpha_optim = torch.optim.Adam([alpha], opt.ALPHA_LR)
    frames = []
    start = time.time()
    for meta_iter in tqdm.trange(opt.NUM_META_ITER):
        # Preheat the models on the training distribution
        _ = tu.train_nll(opt, model_x2y, scm, opt.TRAIN_DISTRY, 'X2Y',
                         mdn_nll, decoder, encoder)
        _ = tu.train_nll(opt, model_y2x, scm, opt.TRAIN_DISTRY, 'Y2X',
                         mdn_nll, decoder, encoder)
        # Sample from the SCM under the transfer distribution
        X = opt.TRANS_DISTRY()
        Y = scm(X)
        if opt.CUDA:
            X, Y = X.cuda(), Y.cuda()
        # Decode
        with torch.no_grad():
            X, Y = decoder(X, Y)
        # Encode
        X, Y = encoder(X, Y)
        with torch.no_grad():
            if opt.USE_BASELINE:
                # This three-argument `marginal_nll` (taking the NLL function)
                # is assumed to come from this script's utilities; it differs
                # from the two-argument version defined below.
                baseline_y = marginal_nll(opt, Y, mdn_nll)
                baseline_x = marginal_nll(opt, X, mdn_nll)
            else:
                baseline_y = 0.
                baseline_x = 0.
        # Save state dicts so the inner-loop adaptation can be undone
        state_x2y = deepcopy(model_x2y.state_dict())
        state_y2x = deepcopy(model_y2x.state_dict())
        # Inner loop
        optim_x2y = torch.optim.Adam(model_x2y.parameters(), lr=opt.FINETUNE_LR)
        optim_y2x = torch.optim.Adam(model_y2x.parameters(), lr=opt.FINETUNE_LR)
        regrets_x2y = []
        regrets_y2x = []
        is_nan = False
        # Evaluate the regret discrepancy while adapting both models
        for t in range(opt.FINETUNE_NUM_ITER):
            loss_x2y = mdn_nll(model_x2y(X), Y)
            loss_y2x = mdn_nll(model_y2x(Y), X)
            if torch.isnan(loss_x2y).item() or torch.isnan(loss_y2x).item():
                is_nan = True
                break
            optim_x2y.zero_grad()
            optim_y2x.zero_grad()
            loss_x2y.backward(retain_graph=True)
            loss_y2x.backward(retain_graph=True)
            # Filter out NaNs that might have sneaked into the gradients
            nan_in_x2y = gradnan_filter(model_x2y)
            nan_in_y2x = gradnan_filter(model_y2x)
            if nan_in_x2y or nan_in_y2x:
                is_nan = True
                break
            optim_x2y.step()
            optim_y2x.step()
            # Store for the encoder update
            regrets_x2y.append(loss_x2y + baseline_x)
            regrets_y2x.append(loss_y2x + baseline_y)
        if not is_nan:
            # Evaluate the total regret of each causal direction
            regret_x2y = torch.stack(regrets_x2y).mean()
            regret_y2x = torch.stack(regrets_y2x).mean()
            # loss = log(sigmoid(alpha) * exp(regret_x2y)
            #            + sigmoid(-alpha) * exp(regret_y2x)),
            # i.e. a soft selection between the two candidate directions
            loss = torch.logsumexp(
                torch.stack([
                    F.logsigmoid(alpha) + regret_x2y,
                    F.logsigmoid(-alpha) + regret_y2x
                ]), 0)
            # Optimize the encoder and the structure parameter alpha
            encoder_optim.zero_grad()
            alpha_optim.zero_grad()
            loss.backward()
            # Make sure no NaNs reach the parameters
            if torch.isnan(encoder.theta.grad.data).any():
                encoder.theta.grad.data.zero_()
            if torch.isnan(alpha.grad.data).any():
                alpha.grad.data.zero_()
            encoder_optim.step()
            alpha_optim.step()
            # Restore the pre-adaptation state dicts
            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)
            # Add info
            end = time.time()
            frames.append(
                Namespace(iter_num=meta_iter,
                          regret_x2y=regret_x2y.item(),
                          regret_y2x=regret_y2x.item(),
                          loss=loss.item(),
                          alpha=alpha.item(),
                          theta=encoder.theta.item(),
                          como_time=end - start))
        else:
            # Restore the pre-adaptation state dicts
            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)
            # Add dummy info
            end = time.time()
            frames.append(
                Namespace(iter_num=meta_iter,
                          regret_x2y=float('nan'),
                          regret_y2x=float('nan'),
                          loss=float('nan'),
                          alpha=float('nan'),
                          theta=float('nan'),
                          como_time=end - start))
    return frames
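
# Illustrative configuration for `encoder_train_shared_regret`. The field
# names are read off the code above; all values and the distribution lambdas
# are hypothetical placeholders, not settings from the original experiments.
opt_example = Namespace(
    CUDA=torch.cuda.is_available(),
    ENCODER_LR=1e-3, ALPHA_LR=1e-2,
    NUM_META_ITER=500,
    FINETUNE_LR=1e-3, FINETUNE_NUM_ITER=10,
    USE_BASELINE=True,
    TRAIN_DISTRY=lambda: sample_from_normal_sketch(0., 2., 1000, 1),
    TRANS_DISTRY=lambda: sample_from_normal_sketch(
        np.random.uniform(-4, 4), 2., 1000, 1),
)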
def marginal_nll(args, inputs):
    # Fit a GMM to the data by EM, then score the data under it; the result
    # is a constant (no-grad) marginal NLL term
    gmm = GaussianMixture(args.gmm_n_gaussians, args.n_features).cuda()
    gmm.fit(inputs, n_iters=args.em_n_iters)
    with torch.no_grad():
        loss_marginal = mdn_nll(gmm(inputs), inputs)
    return loss_marginal
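
# The `GaussianMixture` class is defined elsewhere in the repo. Below is a
# minimal diagonal-covariance GMM fit by EM whose `fit` / `forward` interface
# mirrors the calls in `marginal_nll` above, and whose output format matches
# the (log_pi, mu, sigma) assumption of `mdn_nll_sketch`. It is a sketch under
# those assumptions, not the repo's implementation.
import math


class GaussianMixtureSketch(nn.Module):
    def __init__(self, n_components, n_features):
        super().__init__()
        self.register_buffer('log_pi',
                             torch.full((n_components,), -math.log(n_components)))
        self.register_buffer('mu', torch.randn(n_components, n_features))
        self.register_buffer('var', torch.ones(n_components, n_features))

    def _log_joint(self, x):
        # (N, K): log pi_k + log N(x | mu_k, diag(var_k))
        diff = x.unsqueeze(1) - self.mu                      # (N, K, D)
        log_comp = -0.5 * (diff ** 2 / self.var + self.var.log()
                           + math.log(2 * math.pi)).sum(-1)
        return self.log_pi + log_comp

    def fit(self, x, n_iters=10):
        for _ in range(n_iters):
            resp = torch.softmax(self._log_joint(x), dim=1)  # E-step: responsibilities
            nk = resp.sum(0).clamp_min(1e-6)                 # M-step: weighted stats
            self.log_pi = (nk / nk.sum()).log()
            self.mu = (resp.t() @ x) / nk.unsqueeze(1)
            self.var = ((resp.t() @ x.pow(2)) / nk.unsqueeze(1)
                        - self.mu.pow(2)).clamp_min(1e-6)

    def forward(self, x):
        # Broadcast the fitted parameters to one mixture per sample
        n = x.shape[0]
        return (self.log_pi.expand(n, -1),
                self.mu.expand(n, -1, -1),
                self.var.sqrt().expand(n, -1, -1))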
def train_encoder(args, model_x2y, model_y2x, scm, encoder, decoder):
    # alpha parameterizes sigmoid(alpha) = P(X -> Y is the causal direction)
    alpha = nn.Parameter(torch.tensor(0.).cuda(), requires_grad=True)
    optim_encoder = optim.Adam(encoder.parameters(), args.encoder_lr)
    optim_alpha = optim.Adam([alpha], lr=args.alpha_lr)
    # Pre-train both conditional models on the training distribution
    _ = train_mle_nll(args, model_x2y, scm, encoder, decoder, polarity='X2Y')
    _ = train_mle_nll(args, model_y2x, scm, encoder, decoder, polarity='Y2X')
    results = []
    for meta_iter_num in range(1, args.meta_n_iters + 1):
        _ = train_mle_nll(args, model_x2y, scm, encoder, decoder, polarity='X2Y')
        _ = train_mle_nll(args, model_y2x, scm, encoder, decoder, polarity='Y2X')
        # Transfer distribution: shift the marginal of the cause while the
        # mechanism (conditional distribution) stays fixed
        param = np.random.uniform(-4, 4)
        a_ts = sample_from_normal(param, 2, args.meta_nsamples, args.n_features)
        b_ts = scm(a_ts)
        a_ts, b_ts = a_ts.cuda(), b_ts.cuda()
        with torch.no_grad():
            x_ts, y_ts = decoder(a_ts, b_ts)
        x_ts, y_ts = encoder(x_ts, y_ts)
        loss_marg_x2y = marginal_nll(args, x_ts) if args.train_gmm else 0.
        loss_marg_y2x = marginal_nll(args, y_ts) if args.train_gmm else 0.
        state_x2y = deepcopy(model_x2y.state_dict())
        state_y2x = deepcopy(model_y2x.state_dict())
        # Inner loop: adapt both models to the transfer data
        optim_x2y = optim.Adam(model_x2y.parameters(), lr=args.finetune_lr)
        optim_y2x = optim.Adam(model_y2x.parameters(), lr=args.finetune_lr)
        loss_x2y, loss_y2x = [], []
        is_nan = False
        for _ in range(args.finetune_n_iters):
            loss_cond_x2y = mdn_nll(model_x2y(x_ts), y_ts)
            loss_cond_y2x = mdn_nll(model_y2x(y_ts), x_ts)
            if torch.isnan(loss_cond_x2y).item() or torch.isnan(loss_cond_y2x).item():
                is_nan = True
                break
            optim_x2y.zero_grad()
            optim_y2x.zero_grad()
            loss_cond_x2y.backward(retain_graph=True)
            loss_cond_y2x.backward(retain_graph=True)
            nan_in_x2y = gradnan_filter(model_x2y)
            nan_in_y2x = gradnan_filter(model_y2x)
            if nan_in_x2y or nan_in_y2x:
                is_nan = True
                break
            optim_x2y.step()
            optim_y2x.step()
            loss_x2y.append(loss_cond_x2y + loss_marg_x2y)
            loss_y2x.append(loss_cond_y2x + loss_marg_y2x)
        if not is_nan:
            loss_x2y = torch.stack(loss_x2y).mean()
            loss_y2x = torch.stack(loss_y2x).mean()
            # loss = log(sigmoid(alpha) * exp(loss_x2y)
            #            + sigmoid(-alpha) * exp(loss_y2x))
            log_alpha, log_1_m_alpha = F.logsigmoid(alpha), F.logsigmoid(-alpha)
            loss = torch.logsumexp(
                torch.stack([log_alpha + loss_x2y, log_1_m_alpha + loss_y2x]), 0)
            optim_encoder.zero_grad()
            optim_alpha.zero_grad()
            loss.backward()
            if torch.isnan(encoder.theta.grad.data).any():
                encoder.theta.grad.data.zero_()
            if torch.isnan(alpha.grad.data).any():
                alpha.grad.data.zero_()
            optim_encoder.step()
            optim_alpha.step()
            # Undo the inner-loop adaptation
            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)
            print('| Iteration: %d | Prob: %.3f | X2Y_Loss: %.3f '
                  '| Y2X_Loss: %.3f | Theta: %.3f'
                  % (meta_iter_num, torch.sigmoid(alpha).item(),
                     loss_x2y.item(), loss_y2x.item(), encoder.theta.item()))
            results.append(
                Namespace(iter_num=meta_iter_num,
                          sig_alpha=torch.sigmoid(alpha).item(),
                          loss_x2y=loss_x2y.item(),
                          loss_y2x=loss_y2x.item(),
                          theta=encoder.theta.item()))
        else:
            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)
            results.append(
                Namespace(iter_num=meta_iter_num,
                          sig_alpha=float('nan'),
                          loss_x2y=float('nan'),
                          loss_y2x=float('nan'),
                          theta=float('nan')))
    return results
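
# Illustrative `args` for `train_encoder` / `transfer_finetune` /
# `train_mle_nll`. Field names are taken from the code above; every value is
# a hypothetical placeholder, not a setting from the original experiments.
args_example = Namespace(
    n_features=1,
    mle_lr=1e-3, mle_n_iters=100, mle_nsamples=1000,
    finetune_lr=1e-3, finetune_n_iters=10,
    meta_n_iters=500, meta_nsamples=1000,
    encoder_lr=1e-3, alpha_lr=1e-2,
    train_gmm=False, gmm_n_gaussians=10, em_n_iters=50,
)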