Example #1
def process_sample(X, model_x2y, model_y2x, scm, encoder, decoder):
    Y = scm(X)
    if opt.CUDA:
        X, Y = X.cuda(), Y.cuda()
    # Decode
    with torch.no_grad():
        X, Y = decoder(X, Y)
    # Encode
    X, Y = encoder(X, Y)

    # Evaluate total regret
    loss_x2y = mdn_nll(model_x2y(X), Y)
    loss_y2x = mdn_nll(model_y2x(Y), X)

    if torch.isnan(loss_x2y).item() or torch.isnan(loss_y2x).item():
        raise ValueError('NaN encountered in loss_x2y or loss_y2x.')

    return loss_x2y, loss_y2x
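
`mdn_nll` is used throughout these examples but not defined in them. Below is a minimal sketch of a mixture-density-network NLL consistent with this usage, assuming the model's forward pass returns a tuple of mixture weights, component means, and component standard deviations (the actual parametrization in the original codebase may differ):

import torch

def mdn_nll(mdn_out, target):
    # mdn_out is assumed to be (pi, mu, sigma) with shapes
    #   pi:    (batch, n_components)         normalized mixture weights
    #   mu:    (batch, n_components, dim)    component means
    #   sigma: (batch, n_components, dim)    component std deviations (> 0)
    pi, mu, sigma = mdn_out
    target = target.unsqueeze(1)                       # (batch, 1, dim)
    comp = torch.distributions.Normal(mu, sigma)
    log_prob = comp.log_prob(target).sum(-1)           # (batch, n_components)
    # Log mixture density via logsumexp for numerical stability.
    log_mix = torch.logsumexp(torch.log(pi) + log_prob, dim=-1)
    return -log_mix.mean()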
Example #2
def transfer_finetune(args, model_x2y, model_y2x, inputs, targets):
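    # Fine-tune both directional models on a batch from the transfer
    # distribution, collecting per-iteration losses (conditional NLL plus an
    # optional GMM marginal term); abort early if any loss or gradient is NaN.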
    
    optim_x2y = optim.Adam(model_x2y.parameters(), lr=args.finetune_lr)
    optim_y2x = optim.Adam(model_y2x.parameters(), lr=args.finetune_lr)

    loss_marg_x2y = marginal_nll(args, inputs) if args.train_gmm else 0.
    loss_marg_y2x = marginal_nll(args, targets) if args.train_gmm else 0.

    loss_x2y, loss_y2x = [], []
    is_nan = False

    for _ in range(args.finetune_n_iters):
        
        loss_cond_x2y = mdn_nll(model_x2y(inputs), targets)
        loss_cond_y2x = mdn_nll(model_y2x(targets), inputs)
        
        if torch.isnan(loss_cond_x2y).item() or torch.isnan(loss_cond_y2x).item():
            is_nan = True
            break
        
        optim_x2y.zero_grad()
        optim_y2x.zero_grad()
        
        loss_cond_x2y.backward(retain_graph=True)
        loss_cond_y2x.backward(retain_graph=True)

        nan_in_x2y = gradnan_filter(model_x2y)
        nan_in_y2x = gradnan_filter(model_y2x)
        
        if nan_in_x2y or nan_in_y2x: 
            is_nan = True
            break
        
        optim_x2y.step()
        optim_y2x.step()

        loss_x2y.append(loss_cond_x2y + loss_marg_x2y)
        loss_y2x.append(loss_cond_y2x + loss_marg_y2x)

    return loss_x2y, loss_y2x, is_nan
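
`gradnan_filter` is likewise not defined in these snippets. Judging from how its return value is used above, it reports whether any parameter gradient contains a NaN (and presumably neutralizes such gradients); a sketch under that assumption:

import torch

def gradnan_filter(model, zero_out=True):
    # Return True if any parameter gradient contains a NaN; optionally zero
    # the offending gradients so a subsequent optimizer step stays safe.
    found_nan = False
    for p in model.parameters():
        if p.grad is None:
            continue
        if torch.isnan(p.grad).any():
            found_nan = True
            if zero_out:
                p.grad.zero_()
    return found_nan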
Example #3
def train_mle_nll(args, model, scm, encoder=None, decoder=None, polarity='X2Y'):

    model = model.cuda()
    if encoder is not None: 
        encoder = encoder.cuda()
    if decoder is not None: 
        decoder = decoder.cuda()
    
    optimizer_mle = optim.Adam(model.parameters(), lr=args.mle_lr)

    losses = []

    for iter_num in range(1, args.mle_n_iters + 1):

        x = sample_from_normal(0, 2, args.mle_nsamples, args.n_features)
        
        with torch.no_grad():
            y = scm(x)
            x, y = x.cuda(), y.cuda()

            if decoder is not None:
                x, y = decoder(x, y)
            if encoder is not None:
                x, y = encoder(x, y)

        if polarity == 'X2Y':
            inputs, targets = x, y
        elif polarity == 'Y2X':
            inputs, targets = y, x
        else:
            raise ValueError('%s does not match any known polarity.' % polarity)

        inputs, targets = inputs.cuda(), targets.cuda()

        loss_conditional = mdn_nll(model(inputs), targets)

        optimizer_mle.zero_grad()
        loss_conditional.backward()
        optimizer_mle.step()

        losses.append(loss_conditional.item())

    return losses
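
`sample_from_normal(mean, std, n_samples, n_features)` is not shown either. A minimal sketch matching that call signature (an assumption, not the original implementation):

import torch

def sample_from_normal(mean, std, n_samples, n_features):
    # Draw an (n_samples, n_features) batch from N(mean, std**2) on the CPU;
    # callers move it to the GPU themselves.
    return torch.randn(n_samples, n_features) * std + mean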
Example #4
def encoder_train_shared_regret(opt, model_x2y, model_y2x, scm, encoder,
                                decoder, alpha):
    if opt.CUDA:
        model_x2y = model_x2y.cuda()
        model_y2x = model_y2x.cuda()
        encoder = encoder.cuda()
        decoder = decoder.cuda()
    encoder_optim = torch.optim.Adam(encoder.parameters(), opt.ENCODER_LR)
    alpha_optim = torch.optim.Adam([alpha], opt.ALPHA_LR)
    frames = []
    start = time.time()
    for meta_iter in tqdm.trange(opt.NUM_META_ITER):
        # Preheat the models
        _ = tu.train_nll(opt, model_x2y, scm, opt.TRAIN_DISTRY, 'X2Y', mdn_nll,
                         decoder, encoder)
        _ = tu.train_nll(opt, model_y2x, scm, opt.TRAIN_DISTRY, 'Y2X', mdn_nll,
                         decoder, encoder)
        # Sample from SCM
        X = opt.TRANS_DISTRY()
        Y = scm(X)
        if opt.CUDA:
            X, Y = X.cuda(), Y.cuda()
        # Decode
        with torch.no_grad():
            X, Y = decoder(X, Y)
        # Encode
        X, Y = encoder(X, Y)
        with torch.no_grad():
            if opt.USE_BASELINE:
                baseline_y = marginal_nll(opt, Y, mdn_nll)
                baseline_x = marginal_nll(opt, X, mdn_nll)
            else:
                baseline_y = 0.
                baseline_x = 0.
        # Save state dicts
        state_x2y = deepcopy(model_x2y.state_dict())
        state_y2x = deepcopy(model_y2x.state_dict())
        # Inner loop
        optim_x2y = torch.optim.Adam(model_x2y.parameters(),
                                     lr=opt.FINETUNE_LR)
        optim_y2x = torch.optim.Adam(model_y2x.parameters(),
                                     lr=opt.FINETUNE_LR)
        regrets_x2y = []
        regrets_y2x = []
        is_nan = False
        # Evaluate regret discrepancy
        for t in range(opt.FINETUNE_NUM_ITER):
            loss_x2y = mdn_nll(model_x2y(X), Y)
            loss_y2x = mdn_nll(model_y2x(Y), X)
            if torch.isnan(loss_x2y).item() or torch.isnan(loss_y2x).item():
                is_nan = True
                break
            optim_x2y.zero_grad()
            optim_y2x.zero_grad()
            loss_x2y.backward(retain_graph=True)
            loss_y2x.backward(retain_graph=True)
            # Filter out NaNs that might have sneaked in
            nan_in_x2y = gradnan_filter(model_x2y)
            nan_in_y2x = gradnan_filter(model_y2x)
            if nan_in_x2y or nan_in_y2x:
                is_nan = True
                break
            optim_x2y.step()
            optim_y2x.step()
            # Store for encoder
            regrets_x2y.append(loss_x2y + baseline_x)
            regrets_y2x.append(loss_y2x + baseline_y)
        if not is_nan:
            # Evaluate total regret
            regret_x2y = torch.stack(regrets_x2y).mean()
            regret_y2x = torch.stack(regrets_y2x).mean()
            # Evaluate losses
            loss = torch.logsumexp(
                torch.stack([
                    F.logsigmoid(alpha) + regret_x2y,
                    F.logsigmoid(-alpha) + regret_y2x
                ]), 0)
            # Optimize
            encoder_optim.zero_grad()
            alpha_optim.zero_grad()
            loss.backward()
            # Make sure no nans
            if torch.isnan(encoder.theta.grad.data).any():
                encoder.theta.grad.data.zero_()
            if torch.isnan(alpha.grad.data).any():
                alpha.grad.data.zero_()
            encoder_optim.step()
            alpha_optim.step()
            # Load original state dicts
            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)
            # Add info
            end = time.time()
            frames.append(
                Namespace(iter_num=meta_iter,
                          regret_x2y=regret_x2y.item(),
                          regret_y2x=regret_y2x.item(),
                          loss=loss.item(),
                          alpha=alpha.item(),
                          theta=encoder.theta.item(),
                          como_time=end - start))
        else:
            # Load original state dicts
            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)
            # Add dummy info
            end = time.time()
            frames.append(
                Namespace(iter_num=meta_iter,
                          regret_x2y=float('nan'),
                          regret_y2x=float('nan'),
                          loss=float('nan'),
                          alpha=float('nan'),
                          theta=float('nan'),
                          como_time=end - start))

    return frames
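
The meta-objective above is the log of a mixture of the two exponentiated regrets, weighted by sigmoid(alpha) and 1 - sigmoid(alpha); writing it with logsumexp and logsigmoid avoids overflow. A small self-contained check of that equivalence (alpha and the regret values below are illustrative only):

import torch
import torch.nn.functional as F

alpha = torch.tensor(0.3)
regret_x2y, regret_y2x = torch.tensor(1.2), torch.tensor(2.5)

loss = torch.logsumexp(torch.stack([F.logsigmoid(alpha) + regret_x2y,
                                    F.logsigmoid(-alpha) + regret_y2x]), 0)
# Identical (up to floating point) to the direct but less stable form:
direct = torch.log(torch.sigmoid(alpha) * torch.exp(regret_x2y)
                   + (1 - torch.sigmoid(alpha)) * torch.exp(regret_y2x))
assert torch.allclose(loss, direct)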
Example #5
def marginal_nll(args, inputs):
    gmm = GaussianMixture(args.gmm_n_gaussians, args.n_features).cuda()
    gmm.fit(inputs, n_iters=args.em_n_iters)
    with torch.no_grad():
        loss_marginal = mdn_nll(gmm(inputs), inputs)
    return loss_marginal
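
`GaussianMixture` here is the codebase's own torch module rather than sklearn's: it is built with a component count and feature dimension, fit by EM on the batch, and its forward pass returns mixture parameters in the format `mdn_nll` consumes. A rough sketch under those assumptions (diagonal covariances; the original may differ):

import torch
import torch.nn as nn

class GaussianMixture(nn.Module):
    # Minimal diagonal-covariance GMM fit by EM; a sketch only.
    def __init__(self, n_components, n_features):
        super().__init__()
        self.pi = nn.Parameter(torch.full((n_components,), 1.0 / n_components))
        self.mu = nn.Parameter(torch.randn(n_components, n_features))
        self.sigma = nn.Parameter(torch.ones(n_components, n_features))

    def _log_joint(self, x):
        # log pi_k + log N(x | mu_k, sigma_k) for every component, shape (N, K).
        comp = torch.distributions.Normal(self.mu, self.sigma)
        return torch.log(self.pi) + comp.log_prob(x.unsqueeze(1)).sum(-1)

    @torch.no_grad()
    def fit(self, x, n_iters=50):
        for _ in range(n_iters):
            resp = torch.softmax(self._log_joint(x), dim=1)      # E-step, (N, K)
            nk = resp.sum(0) + 1e-8                              # M-step
            self.pi.data = nk / x.shape[0]
            self.mu.data = (resp.t() @ x) / nk.unsqueeze(1)
            diff2 = (x.unsqueeze(1) - self.mu) ** 2              # (N, K, D)
            var = (resp.unsqueeze(-1) * diff2).sum(0) / nk.unsqueeze(1)
            self.sigma.data = (var + 1e-6).sqrt()

    def forward(self, x):
        # Return (pi, mu, sigma) broadcast over the batch, matching mdn_nll.
        n = x.shape[0]
        return (self.pi.expand(n, -1),
                self.mu.unsqueeze(0).expand(n, -1, -1),
                self.sigma.unsqueeze(0).expand(n, -1, -1))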
Example #6
def train_encoder(args, model_x2y, model_y2x, scm, encoder, decoder):
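    # Meta-training loop: repeatedly pre-train both directional models, adapt
    # them to a shifted transfer distribution, and update the encoder and the
    # structural parameter alpha from the resulting adaptation losses.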
    alpha = nn.Parameter(torch.tensor(0.).cuda(), requires_grad=True)

    optim_encoder = optim.Adam(encoder.parameters(), args.encoder_lr)
    optim_alpha = optim.Adam([alpha], lr=args.alpha_lr)

    # print('Start pre-training model_x2y.')
    _ = train_mle_nll(args, model_x2y, scm, encoder, decoder, polarity='X2Y')

    # print('Start pre-training model_y2x.')
    _ = train_mle_nll(args, model_y2x, scm, encoder, decoder, polarity='Y2X')

    results = []

    for meta_iter_num in range(1, args.meta_n_iters + 1):

        _ = train_mle_nll(args,
                          model_x2y,
                          scm,
                          encoder,
                          decoder,
                          polarity='X2Y')
        _ = train_mle_nll(args,
                          model_y2x,
                          scm,
                          encoder,
                          decoder,
                          polarity='Y2X')

        # same mechanism (conditional distribution)
        param = np.random.uniform(-4, 4)
        a_ts = sample_from_normal(param, 2, args.meta_nsamples,
                                  args.n_features)
        b_ts = scm(a_ts)
        a_ts, b_ts = a_ts.cuda(), b_ts.cuda()
        with torch.no_grad():
            x_ts, y_ts = decoder(a_ts, b_ts)

        x_ts, y_ts = encoder(x_ts, y_ts)

        loss_marg_x2y = marginal_nll(args, x_ts) if args.train_gmm else 0.
        loss_marg_y2x = marginal_nll(args, y_ts) if args.train_gmm else 0.

        state_x2y = deepcopy(model_x2y.state_dict())
        state_y2x = deepcopy(model_y2x.state_dict())

        # Inner loop
        optim_x2y = optim.Adam(model_x2y.parameters(), lr=args.finetune_lr)
        optim_y2x = optim.Adam(model_y2x.parameters(), lr=args.finetune_lr)

        loss_x2y, loss_y2x = [], []
        is_nan = False

        for _ in range(args.finetune_n_iters):

            loss_cond_x2y = mdn_nll(model_x2y(x_ts), y_ts)
            loss_cond_y2x = mdn_nll(model_y2x(y_ts), x_ts)

            if torch.isnan(loss_cond_x2y).item() or torch.isnan(
                    loss_cond_y2x).item():
                is_nan = True
                break

            optim_x2y.zero_grad()
            optim_y2x.zero_grad()

            loss_cond_x2y.backward(retain_graph=True)
            loss_cond_y2x.backward(retain_graph=True)

            nan_in_x2y = gradnan_filter(model_x2y)
            nan_in_y2x = gradnan_filter(model_y2x)

            if nan_in_x2y or nan_in_y2x:
                is_nan = True
                break

            optim_x2y.step()
            optim_y2x.step()

            loss_x2y.append(loss_cond_x2y + loss_marg_x2y)
            loss_y2x.append(loss_cond_y2x + loss_marg_y2x)

        if not is_nan:
            loss_x2y = torch.stack(loss_x2y).mean()
            loss_y2x = torch.stack(loss_y2x).mean()

            log_alpha, log_1_m_alpha = F.logsigmoid(alpha), F.logsigmoid(
                -alpha)
            loss = logsumexp(log_alpha + loss_x2y, log_1_m_alpha + loss_y2x)

            optim_encoder.zero_grad()
            optim_alpha.zero_grad()

            loss.backward()

            if torch.isnan(encoder.theta.grad.data).any():
                encoder.theta.grad.data.zero_()

            if torch.isnan(alpha.grad.data).any():
                alpha.grad.data.zero_()

            optim_encoder.step()
            optim_alpha.step()

            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)

            print(
                '| Iteration: %d | Prob: %.3f | X2Y_Loss: %.3f | Y2X_Loss: %.3f | Theta: %.3f'
                % (meta_iter_num, torch.sigmoid(alpha).item(), loss_x2y,
                   loss_y2x, encoder.theta.item()))

            with torch.no_grad():
                results.append(
                    Namespace(iter_num=meta_iter_num,
                              sig_alpha=torch.sigmoid(alpha).item(),
                              loss_x2y=loss_x2y.item(),
                              loss_y2x=loss_y2x.item(),
                              theta=encoder.theta.item()))

        else:

            model_x2y.load_state_dict(state_x2y)
            model_y2x.load_state_dict(state_y2x)

            with torch.no_grad():
                results.append(
                    Namespace(iter_num=meta_iter_num,
                              sig_alpha=float('nan'),
                              loss_x2y=float('nan'),
                              loss_y2x=float('nan'),
                              theta=float('nan')))

    return results
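
The two-argument `logsumexp` called in Example #6 is a small helper distinct from `torch.logsumexp`; a sketch consistent with that usage:

import torch

def logsumexp(a, b):
    # Numerically stable log(exp(a) + exp(b)) for two scalar tensors.
    return torch.logsumexp(torch.stack([a, b]), dim=0)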