def train_alpha(opt, model_x2y, model_y2x, model_g2y, model_g2x, alpha, gt_scm,
                distr, sweep_distr, nll, transfer_metric, mixmode='logmix'):
    """Optimize the direction parameter `alpha` against transfer performance.

    Repeatedly samples a transfer distribution, evaluates how well the
    X->Y and Y->X models adapt to it (via `transfer_metric`), and updates
    `alpha` so that sigmoid(alpha) weighs the better-adapting direction.

    Args:
        opt: options namespace; reads CUDA, ALPHA_LR, ALPHA_NUM_ITER.
        model_x2y, model_y2x: conditional models for the two directions.
        model_g2y, model_g2x: marginal/auxiliary models passed to the metric.
        alpha: scalar tensor (requires_grad) parameterizing the mixture.
        gt_scm: ground-truth mechanism mapping X to Y.
        distr: callable producing X samples given a sweep parameter.
        sweep_distr: callable sampling a transfer-distribution parameter.
        nll: negative-log-likelihood criterion forwarded to the metric.
        transfer_metric: callable scoring a (conditional, marginal) pair.
        mixmode: 'logmix', 'logsigp', or 'sigp' — how the two metrics are
            combined into the loss on alpha.

    Returns:
        List of Namespace frames with per-iteration diagnostics.

    Raises:
        ValueError: if `mixmode` is not one of the three supported modes.
    """
    # Move the conditional models to GPU when requested.
    if opt.CUDA:
        model_x2y.cuda()
        model_y2x.cuda()
    alpha_optim = torch.optim.Adam([alpha], lr=opt.ALPHA_LR)
    frames = []
    iterations = tnrange(opt.ALPHA_NUM_ITER, leave=False)
    start = time.time()
    for iter_num in iterations:
        # Sample a parameter for the transfer distribution, then draw data
        # from it and push through the ground-truth mechanism.
        sweep_param = sweep_distr()
        X_gt = distr(sweep_param)
        Y_gt = gt_scm(X_gt)
        with torch.no_grad():
            if opt.CUDA:
                X_gt, Y_gt = X_gt.cuda(), Y_gt.cuda()
        # Evaluate adaptation performance of each causal direction.
        metric_x2y = transfer_metric(opt, model_x2y, model_g2x, X_gt, Y_gt, nll)
        metric_y2x = transfer_metric(opt, model_y2x, model_g2y, Y_gt, X_gt, nll)
        # Combine the two metrics into a differentiable loss on alpha.
        if mixmode == 'logmix':
            # Convex combination in metric space, weighted by sigmoid(alpha).
            loss_alpha = torch.sigmoid(alpha) * metric_x2y + (1 - torch.sigmoid(alpha)) * metric_y2x
        elif mixmode in ('logsigp', 'sigp'):
            # Mixture in probability space, computed stably in log space.
            log_alpha, log_1_m_alpha = F.logsigmoid(alpha), F.logsigmoid(-alpha)
            as_lse = logsumexp(log_alpha + metric_x2y, log_1_m_alpha + metric_y2x)
            loss_alpha = as_lse if mixmode == 'logsigp' else as_lse.exp()
        else:
            # BUGFIX: previously an unknown mixmode fell through silently,
            # reusing a stale loss from the previous iteration (or raising
            # NameError on the first one). Fail fast instead.
            raise ValueError('unknown mixmode: {!r}'.format(mixmode))
        # Gradient step on alpha only.
        alpha_optim.zero_grad()
        loss_alpha.backward()
        alpha_optim.step()
        end = time.time()

        with torch.no_grad():
            gamma = torch.sigmoid(alpha).item()

        # Hard decision: alpha > 0 <=> sigmoid(alpha) > 0.5 <=> X->Y predicted.
        prediction = 1. if alpha.item() > 0 else 0.

        # Record per-iteration diagnostics (plain Python values where cheap).
        frames.append(Namespace(iter_num=iter_num,
                                alpha=gamma,
                                sig_alpha=prediction,
                                time=end - start,
                                metric_x2y=metric_x2y,
                                metric_y2x=metric_y2x,
                                loss_alpha=loss_alpha.item()))
        iterations.set_postfix(alpha='{0:.4f}'.format(gamma))
    return frames
    def online_loglikelihood(self, logl_A_B, logl_B_A):
        """Log-likelihood of the sigmoid-weighted two-direction mixture.

        Computes log( sigmoid(w) * exp(sum logl_A_B)
                      + (1 - sigmoid(w)) * exp(sum logl_B_A) )
        stably in log space via logsumexp.

        Args:
            logl_A_B: per-sample log-likelihoods under the A->B model.
            logl_B_A: per-sample log-likelihoods under the B->A model.

        Returns:
            Scalar tensor with the mixture log-likelihood.
        """
        # BUGFIX/cleanup: `n = logl_A_B.size(0)` was computed but only used
        # by a commented-out `/ float(n)` normalization; both removed.
        log_alpha, log_1_m_alpha = F.logsigmoid(self.w), F.logsigmoid(-self.w)

        return logsumexp(log_alpha + torch.sum(logl_A_B),
                         log_1_m_alpha + torch.sum(logl_B_A))
    def online_loglikelihood(self, logl_A_B, logl_B_A):
        """Log-likelihood of the two-direction mixture weighted by sigmoid(z).

        Equivalent to log( sigmoid(z) * exp(sum logl_A_B)
                           + (1 - sigmoid(z)) * exp(sum logl_B_A) ),
        evaluated stably in log space.
        """
        # Log mixture weights for each direction.
        weight_fwd = F.logsigmoid(self.z)
        weight_bwd = F.logsigmoid(-self.z)
        # Total log-likelihood of the batch under each direction.
        total_fwd = torch.sum(logl_A_B)
        total_bwd = torch.sum(logl_B_A)
        return logsumexp(weight_fwd + total_fwd, weight_bwd + total_bwd)