def quasi_newton(params, func, init_values, stop_condition=1e-5):
    """Minimise ``func`` with the BFGS quasi-Newton method.

    Each iteration picks the search direction ``d = -H * g`` (``H`` being the
    running approximation of the *inverse* Hessian), performs an exact line
    search along ``d`` via ``get_stagnation`` and then applies the BFGS
    inverse update to ``H``.

    :param params: symbols of ``func``, in the same order as ``init_values``
    :param func: symbolic expression to minimise (must support ``.subs``)
    :param init_values: starting point, one entry per symbol in ``params``
    :param stop_condition: iteration stops once the gradient norm at the
        current point is ``<=`` this value
    :return: ``(point, value)`` - the final point as a list and ``func``
        evaluated at that same point
    """
    # * BFGS
    values = Matrix(init_values)
    lam = Symbol('lam')
    identity = eye(len(params))
    # H approximates the inverse Hessian; the identity makes the first step
    # a plain steepest-descent step.
    h = identity
    # params and func never change, so the symbolic gradient is computed once.
    grad_expr = get_grad(params, func)
    g = grad_expr.subs(dict(zip(params, list(values))))
    step = 0
    while True:
        # Test convergence before the line search: at a stationary point the
        # direction is zero and the line search has nothing to solve.
        if get_norm(g) <= stop_condition:
            return list(values), func.subs(dict(zip(params, list(values))))

        # The update below maintains the INVERSE Hessian, so the direction is
        # -H * g (inverting H again here would undo the approximation).
        d = -h * g
        lam_func = func.subs(dict(zip(params, list(values + lam * d))))
        lam_value = get_stagnation(lam_func)
        next_values = values + lam_value * d
        next_g = grad_expr.subs(dict(zip(params, list(next_values))))

        s = next_values - values
        y = next_g - g
        sy = (s.T * y)[0]  # curvature term s^T y (a scalar)
        # H+ = (I - s y^T / s^T y) H (I - s y^T / s^T y)^T + s s^T / s^T y
        left = identity - (s * y.T) / sy
        h = left * h * left.T + (s * s.T) / sy

        values = next_values
        g = next_g
        f_value = func.subs(dict(zip(params, list(values))))
        print('step: {}  params: {}  f: {}'.format(step, list(values),
                                                   f_value))
        step += 1
예제 #2
0
 def get_grad(self):
     """
     Gather the gradients of the three parameter groups held by this object.

     Returns a 3-tuple ``(user, item, rating)``; each entry is whatever the
     module-level ``get_grad`` helper returns for the ``parameters()`` of
     ``self.user_embedding``, ``self.item_embedding`` and ``self.rec_model``
     respectively.
     """
     return (
         # user-specific gradients (user embedding parameters)
         get_grad(self.user_embedding.parameters()),
         # item-specific gradients (item embedding parameters)
         get_grad(self.item_embedding.parameters()),
         # rating-specific gradients (recommendation model parameters)
         get_grad(self.rec_model.parameters()),
     )
def conjugate_gradient(params, func, init_values, stop_condition=1e-2):
    """Minimise ``func`` with non-linear conjugate gradient (PRP beta).

    :param params: symbols of ``func``, ordered like ``init_values``
    :param func: symbolic expression to minimise (must support ``.subs``)
    :param init_values: starting point
    :param stop_condition: stop once the gradient norm is ``<=`` this value
    :return: ``(point, value)`` - final point as a list and ``func`` there
    """
    # * PRP

    def bind(vec):
        # Substitution map: symbol -> component of ``vec``.
        return dict(zip(params, list(vec)))

    point = Matrix(init_values)
    lam = Symbol('lam')
    beta = 0
    previous_d = 0
    previous_g = 0  # stays the int 0 until the first iteration has finished
    step = 0
    while True:
        g = get_grad(params, func).subs(bind(point))
        if get_norm(g) <= stop_condition:
            return list(point), func.subs(bind(point))

        # First iteration: plain steepest descent. Afterwards the previous
        # direction is mixed in, weighted by the PRP coefficient.
        d = -g
        if previous_g != 0:
            beta = (g.T * (g - previous_g)) / (get_norm(previous_g)**2)
            d = d + beta[0] * previous_d

        # Exact line search along d.
        lam_value = get_stagnation(func.subs(bind(point + lam * d)))
        point = point + lam_value * d
        previous_d, previous_g = d, g

        f_value = func.subs(bind(point))
        print('step: {}  params: {}  f: {}'.format(step, list(point),
                                                   f_value))
        step += 1
예제 #4
0
    def test_sgd_logreg_2(self):
        """Compare binary logistic-regression loss and gradients with PyTorch.

        The same random data is pushed through the graph built with
        ``val`` / ``rd.build_*`` and through an equivalent PyTorch graph
        (``BCEWithLogitsLoss`` with sum reduction); loss, predictions and
        the gradients w.r.t. the logits, the weights and the inputs must
        agree according to ``self.ck_fequals``.
        """
        n_rows, n_cols = 46, 7
        X = np.random.randn(n_rows, n_cols).astype(np.float32)
        w = np.random.randn(n_cols).astype(np.float32)
        y_true = np.random.randint(0, 2, n_rows).astype(np.float32)

        def as_np(tensor):
            # Plain numpy view of a torch tensor's data.
            return tensor.data.numpy()

        # Graph under test.
        dX, dw, dy_true = val(X), val(w), val(y_true)
        dy_out = rd.build_dot_mv(dX, dw)
        dy_pred = rd.build_vsigmoid(dy_out)
        dloss = rd.build_bce_loss(dy_out, dy_true)

        # PyTorch reference.
        tX = torch.tensor(X, requires_grad=True)
        tw = torch.tensor(w, requires_grad=True)
        ty_true = torch.tensor(y_true, requires_grad=False)
        ty_out = torch.matmul(tX, tw)
        utils.save_grad(ty_out)  # lets utils.get_grad(ty_out) be read below
        ty_pred = torch.sigmoid(ty_out)
        tloss = torch.nn.BCEWithLogitsLoss(reduction='sum')(ty_out, ty_true)
        tloss.backward()

        # Forward values.
        self.ck_fequals(dloss.eval(), as_np(tloss), feps=1e-3)
        self.ck_fequals(dy_pred.eval(), as_np(ty_pred))

        # Gradients w.r.t. logits, weights and inputs.
        logits_grad = get_grad(dloss, dy_out).eval()
        self.ck_fequals(logits_grad, as_np(utils.get_grad(ty_out)))
        self.ck_fequals(get_grad(dloss, dw).eval(), as_np(tw.grad))
        self.ck_fequals(get_grad(dloss, dX).eval(), as_np(tX.grad))
예제 #5
0
    def test_sgd_logreg_k(self):
        """Compare multi-class logistic-regression results with PyTorch.

        Random inputs with one-hot targets go through the graph built with
        ``val`` / ``rd.build_*`` and through an equivalent PyTorch graph
        (``CrossEntropyLoss`` with sum reduction); loss, softmax output and
        the gradients w.r.t. the logits, the weights and the inputs must
        agree according to ``self.ck_fequals``.
        """
        n_rows, n_cols, n_classes = 46, 7, 4
        X = np.random.randn(n_rows, n_cols).astype(np.float32)
        w = np.random.randn(n_cols, n_classes).astype(np.float32)
        y_true = np.zeros((n_rows, n_classes)).astype(np.float32)
        # One-hot targets: each row gets a single randomly placed 1.
        for row in y_true:
            row[np.random.randint(0, y_true.shape[1])] = 1

        def as_np(tensor):
            # Plain numpy view of a torch tensor's data.
            return tensor.data.numpy()

        # Graph under test.
        dX, dw, dy_true = val(X), val(w), val(y_true)
        dy_out = rd.build_dot_mm(dX, dw)
        dy_pred = rd.build_softmax(dy_out)
        dloss = rd.build_cross_entropy_loss(dy_out, dy_true)

        # PyTorch reference; CrossEntropyLoss is fed class indices here, so
        # the one-hot rows are turned into indices with argmax.
        tX = torch.tensor(X, requires_grad=True)
        tw = torch.tensor(w, requires_grad=True)
        ty_true = torch.argmax(torch.tensor(y_true, requires_grad=False),
                               dim=1)
        ty_out = torch.matmul(tX, tw)
        ty_pred = torch.nn.functional.softmax(ty_out, dim=1)
        utils.save_grad(ty_out)  # lets utils.get_grad(ty_out) be read below
        tloss = torch.nn.CrossEntropyLoss(reduction='sum')(ty_out, ty_true)
        tloss.backward()

        # Forward values.
        self.ck_fequals(dloss.eval(), as_np(tloss), feps=1e-3)
        self.ck_fequals(dy_pred.eval(), as_np(ty_pred))

        # Gradients w.r.t. logits, weights and inputs.
        logits_grad = get_grad(dloss, dy_out).eval()
        self.ck_fequals(logits_grad, as_np(utils.get_grad(ty_out)))
        self.ck_fequals(get_grad(dloss, dw).eval(), as_np(tw.grad))
        self.ck_fequals(get_grad(dloss, dX).eval(), as_np(tX.grad))
예제 #6
0
def online_newton_step(batch_data, T, init, G, D, alpha=1):
    '''
    Run the Online Newton Step update for ``T - 1`` rounds.

    :param batch_data: 2-D array; column ``t`` is the data observed at round t
    :param T: number of rounds (the returned list has ``T`` iterates)
    :param init: initial point x_1
    :param G: gradient bound used to set the step parameters
    :param D: diameter, also forwarded to the projection helper
    :param alpha: Strong convexity parameter
    :return: list of iterates, starting with ``init``
    '''

    dim = init.shape[0]   # dimensionality
    gamma = 0.5 * min(1 / (4 * G * D), alpha)
    epsilon = 1 / ((gamma * D) ** 2)
    eta = 1 / gamma  # fixed step size
    A = epsilon * np.identity(dim)   # initial matrix

    xs = [init]
    rt = batch_data[:, 0]  # initial ratios
    for t in range(1, T):
        current = xs[-1]
        # online gradient at the current iterate
        grad = utl.get_grad(current, rt)
        # rank-1 update of A (in place)
        A += np.outer(grad, grad)
        # gradient step preconditioned by the inverse of A
        y = current - eta * np.matmul(np.linalg.inv(A), grad)
        # project back w.r.t. A
        xs.append(utl.simplex_projection_wrt_matrix(y, A, D))
        # observe next data point
        rt = batch_data[:, t]
    return xs
예제 #7
0
    def test_sgd_mse(self):
        """Compare linear-regression MSE loss and gradients with PyTorch.

        Random data goes through the graph built with ``val``,
        ``rd.build_dot_mv`` and ``mse`` and through an equivalent PyTorch
        graph (``MSELoss``); the loss and the gradients w.r.t. predictions,
        targets, weights and inputs must agree per ``self.ck_fequals``.
        """
        n_rows, n_cols = 46, 7
        X = np.random.randn(n_rows, n_cols)
        w = np.random.randn(n_cols)
        y_true = np.random.randn(n_rows)

        def as_np(tensor):
            # Plain numpy view of a torch tensor's data.
            return tensor.data.numpy()

        # Graph under test.
        dX, dw, dy_true = val(X), val(w), val(y_true)
        dy_pred = rd.build_dot_mv(dX, dw)
        dloss = mse(dy_pred, dy_true)

        # PyTorch reference; the targets also track gradients because the
        # gradient w.r.t. y_true is compared below.
        tX = torch.tensor(X, requires_grad=True)
        tw = torch.tensor(w, requires_grad=True)
        ty_true = torch.tensor(y_true, requires_grad=True)
        ty_pred = torch.matmul(tX, tw)
        utils.save_grad(ty_pred)  # lets utils.get_grad(ty_pred) be read below
        tloss = torch.nn.MSELoss()(ty_pred, ty_true)
        tloss.backward()

        # Forward value.
        self.ck_fequals(dloss.eval(), as_np(tloss), feps=1e-3)

        # Gradients w.r.t. predictions, targets, weights and inputs.
        pred_grad = get_grad(dloss, dy_pred).eval()
        self.ck_fequals(pred_grad, as_np(utils.get_grad(ty_pred)))
        self.ck_fequals(get_grad(dloss, dy_true).eval(), as_np(ty_true.grad))
        self.ck_fequals(get_grad(dloss, dw).eval(), as_np(tw.grad), feps=1e-4)
        self.ck_fequals(get_grad(dloss, dX).eval(), as_np(tX.grad))
예제 #8
0
def newton(params, func, init_values, stop_condition=1e-2):
    """Minimise ``func`` with the pure Newton method (full step, no search).

    :param params: symbols of ``func``, ordered like ``init_values``
    :param func: symbolic expression to minimise (must support ``.subs``)
    :param init_values: starting point
    :param stop_condition: stop once the gradient norm is ``<=`` this value
    :return: ``(point, value)`` - final point as a list and ``func`` there
    """

    def bind(vec):
        # Substitution map: symbol -> component of ``vec``.
        return dict(zip(params, list(vec)))

    point = Matrix(init_values)
    step = 0
    while True:
        g = get_grad(params, func).subs(bind(point))
        if get_norm(g) <= stop_condition:
            return list(point), func.subs(bind(point))

        # Newton step: x <- x - H^{-1} g, with H evaluated at the current x.
        h = get_hessian(params, func).subs(bind(point))
        point = point - h**(-1) * g

        f_value = func.subs(bind(point))
        print('step: {}  params: {}  f: {}'.format(step, list(point),
                                                   f_value))
        step += 1
def steepest_descent(params, func, init_values, stop_condition=1e-10):
    """Minimise ``func`` by steepest descent with an exact line search.

    :param params: symbols of ``func``, ordered like ``init_values``
    :param func: symbolic expression to minimise (must support ``.subs``)
    :param init_values: starting point
    :param stop_condition: stop once the gradient norm is ``<=`` this value
    :return: ``(point, value)`` - final point as a list and ``func`` there
    """

    def bind(vec):
        # Substitution map: symbol -> component of ``vec``.
        return dict(zip(params, list(vec)))

    point = Matrix(init_values)
    lam = Symbol('lam')
    step = 0
    while True:
        g = get_grad(params, func).subs(bind(point))
        if get_norm(g) <= stop_condition:
            return list(point), func.subs(bind(point))

        # Step length: stationary point of func along the ray x - lam * g.
        lam_value = get_stagnation(func.subs(bind(point - lam * g)))
        point = point - lam_value * g

        f_value = func.subs(bind(point))
        print('step: {}  params: {}  f: {}'.format(step, list(point),
                                                   f_value))
        step += 1
예제 #10
0
def online_exponential_gradient(batch_data, T, init, G, D):
    '''
    Run the Online Exponentiated Gradient update for ``T - 1`` rounds.

    :param batch_data: 2-D array; column ``t`` is the data observed at round t
    :param T: number of rounds (the returned list has ``T`` iterates)
    :param init: initial point x_1
    :param G: gradient bound used in the step size
    :param D: diameter used in the step size
    :return: list of iterates, starting with ``init``
    '''

    eta = D / (G * np.sqrt(2 * T))  # fixed step size
    xs = [init]
    rt = batch_data[:, 0]  # initial ratios

    for t in range(1, T):
        # online gradient at the current iterate
        grad = utl.get_grad(xs[-1], rt)
        # multiplicative update followed by normalisation to sum 1, so no
        # separate projection step is applied
        unnormalised = xs[-1] * np.exp(-eta * grad)
        xs.append(unnormalised / np.sum(unnormalised))
        # observe next data point
        rt = batch_data[:, t]
    return xs
예제 #11
0
def bisection(params, func, a, b, stop_condition=1e-2):
    """Bracket a stationary point of a one-variable ``func`` by bisection.

    The interval ``[a, b]`` is halved repeatedly, keeping the half on which
    the derivative changes sign from negative to positive.

    :param params: symbols of ``func`` (only the first component is bisected)
    :param func: symbolic expression whose derivative is bisected
    :param a: left end of the bracket, as a one-element sequence; the
        derivative there must be negative
    :param b: right end of the bracket, as a one-element sequence; the
        derivative there must be positive
    :param stop_condition: stop once ``|a - b|`` is ``<=`` this value
    :return: ``(a, b)`` - the final bracket ends as lists (both equal the
        midpoint when the derivative vanishes there exactly)
    :raises AssertionError: if the derivative signs at the ends do not
        bracket a minimum
    """
    a = Matrix(a)
    b = Matrix(b)
    # The symbolic gradient does not depend on the bracket: compute it once.
    g = get_grad(params, func)
    step = 0
    while True:
        g_a = g.subs(dict(zip(params, list(a))))
        g_b = g.subs(dict(zip(params, list(b))))
        # Index the entry directly: Matrix.values() drops zero entries, so
        # values()[0] raises IndexError whenever the derivative is exactly 0.
        assert g_a[0] < 0
        assert g_b[0] > 0
        bi = Matrix([(a[0] + b[0]) / 2])
        slope = g.subs(dict(zip(params, list(bi))))[0]
        found_exact = slope == 0
        if found_exact:
            # The midpoint is itself stationary: collapse the bracket on it.
            a = b = bi
        elif slope > 0:
            b = bi
        else:
            a = bi
        print('step: {}  a: {}  b: {}'.format(step, list(a), list(b)))
        if found_exact or np.abs(a[0] - b[0]) <= stop_condition:
            break
        step += 1
    return list(a), list(b)
예제 #12
0
def calcOrientation(img, kp):
    """Assign a dominant orientation to ``kp`` and spawn secondary keypoints.

    A magnitude-weighted orientation histogram with ``bins`` entries is
    accumulated over a square window of half-width ``radius`` around the
    keypoint. The strongest bin becomes the keypoint's direction (through
    ``kp.setDir``); every other bin whose value exceeds 0.85 of the maximum
    yields an additional keypoint at the same position and scale.

    NOTE(review): both ``isOut`` calls pass the literal 1 for the coordinate
    they are not checking, and a bin index is converted to a direction with
    a hard-coded factor of 10 - confirm both against the helpers.

    :param img: image handed to ``isOut`` and ``get_grad``
    :param kp: keypoint exposing ``x``, ``y``, ``scale`` and ``setDir``
    :return: list of the extra keypoints (may be empty)
    """
    sigma = sigma_c * kp.scale
    radius = int(2 * np.ceil(sigma) + 1)
    hist = np.zeros(bins, dtype=np.float32)
    kernel = gaussian_filter(sigma)

    offsets = range(-radius, radius + 1)
    for dy in offsets:
        y = kp.y + dy
        if isOut(img, 1, y):
            continue
        for dx in offsets:
            x = kp.x + dx
            if not isOut(img, x, 1):
                mag, theta = get_grad(img, x, y)
                # gradient magnitude weighted by the window kernel
                weight = kernel[dy + radius, dx + radius] * mag
                hist[quantize_orientation(theta, bins) - 1] += weight

    maxBin = np.argmax(hist)
    threshold = .85 * np.max(hist)

    kp.setDir(maxBin * 10)

    # any other bin above 0.85 of the peak becomes a keypoint of its own
    return [
        handleKeypoints.KeyPoint(kp.x, kp.y, kp.scale, idx * 10)
        for idx, value in enumerate(hist)
        if idx != maxBin and value > threshold
    ]
예제 #13
0
def online_gradient_descent(batch_data, T, init, G, D):
    '''
    Run projected Online Gradient Descent for ``T - 1`` rounds.

    :param batch_data: ratios data of size (n_ratios, T)
    :param T: Horizon of repeated game
    :param init: Initial point x_1
    :param G: Lipschitz coefficient
    :param D: Diameter
    :return: list of iterates, starting with ``init``
    '''

    # Fixed step size; a decaying D / (G * sqrt(t + 1)) schedule is not used.
    eta = D / (G * np.sqrt(T))
    xs = [init]
    rt = batch_data[:, 0]   # initial ratios

    for t in range(1, T):
        current = xs[-1]
        # online gradient at the current iterate
        grad = utl.get_grad(current, rt)
        # gradient step followed by projection
        xs.append(utl.simplex_projection(current - eta * grad))
        # observe next data point
        rt = batch_data[:, t]
    return xs
예제 #14
0
def forward_grad(model, batch, compute_loss, args, compute_grad=True):
    """Run a forward (and optionally backward) pass over ``batch``.

    The batch is processed in microbatches so gradients can be accumulated
    when memory is constrained. Loss and metrics are averaged over the data
    points of the whole batch.

    :param model: model handed to ``compute_loss``; its parameters are
        clipped and read when ``compute_grad`` is true
    :param batch: sequence of tensors sharing their first dimension
    :param compute_loss: callable ``(model, microbatch, args)`` returning
        ``loss, *metrics``; each must provide ``.item()``
    :param args: configuration object; reads ``microbatch_size``,
        ``max_grad_norm``, ``mode``, ``do_dp`` and, depending on those,
        the DP and sketch settings
    :param compute_grad: when false, only ``[loss, *metrics]`` is returned
    :return: ``results`` if ``compute_grad`` is false, else ``(g, results)``
        where ``g`` is the sketch table in ``"sketch"`` mode and the flat
        gradient otherwise
    :raises ValueError: if ``args.mode`` is not a known mode
    """
    batch_size = batch[0].size()[0]

    # divide up batch (for gradient accumulation when memory constrained);
    # min() covers a final batch that is smaller than the microbatch size
    if args.microbatch_size > 0:
        microbatch_size = min(batch_size, args.microbatch_size)
    else:
        microbatch_size = batch_size

    # accumulators for the loss & metric values
    accum_loss = 0
    accum_metrics = None

    num_iters = math.ceil(batch_size / microbatch_size)
    for shard in range(num_iters):
        # extract current microbatch
        start = shard * microbatch_size
        end = start + microbatch_size
        microbatch = [t[start:end] for t in batch]

        # forward pass
        loss, *metrics = compute_loss(model, microbatch, args)

        # if first time through, we find out how many metrics there are
        if accum_metrics is None:
            accum_metrics = [0 for _ in metrics]

        # accumulate loss & metrics, weighted by how many data points
        # were actually used
        n_used = microbatch[0].size()[0]
        accum_loss += loss.item() * n_used
        for j, m in enumerate(metrics):
            accum_metrics[j] += m.item() * n_used

        # backward pass
        if compute_grad:
            loss.backward()

    # gradient clipping (sketch mode clips the sketch itself further down);
    # the bound is scaled by the number of accumulated microbatches
    if (compute_grad and args.max_grad_norm is not None
            and args.mode != "sketch"):
        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                       args.max_grad_norm * num_iters)

    # "average" here is over the data in the batch
    average_loss = accum_loss / batch_size
    average_metrics = [m / batch_size for m in accum_metrics]

    results = [average_loss] + average_metrics

    if not compute_grad:
        return results

    grad = get_grad(model, args)
    if args.do_dp:
        grad = clip_grad(args.l2_norm_clip, grad)
        if args.dp_mode == "worker":
            noise = torch.normal(mean=0, std=args.noise_multiplier,
                                 size=grad.size()).to(args.device)
            noise *= np.sqrt(args.num_workers)
            grad += noise

    # compress the gradient if needed
    if args.mode == "sketch":
        sketch = CSVec(d=args.grad_size, c=args.num_cols,
                       r=args.num_rows, device=args.device,
                       numBlocks=args.num_blocks)
        sketch.accumulateVec(grad)
        # gradient clipping
        if args.max_grad_norm is not None:
            sketch = clip_grad(args.max_grad_norm, sketch)
        g = sketch.table
    elif args.mode in ("true_topk", "local_topk", "fedavg", "uncompressed"):
        # local_topk: ideally we'd return the compressed version of the
        # gradient, i.e. _topk(grad, k=args.k). However, for sketching we do
        # momentum in the sketch, whereas for topk we do momentum before
        # taking topk, so we have to return an inconsistent quantity here.
        # fedavg: the fedavg logic happens in process_batch.
        g = grad
    else:
        # Fail with a clear message instead of an unbound-local error.
        raise ValueError("unknown mode: {!r}".format(args.mode))

    return g, results
예제 #15
0
# Sizes of the overlap sample (Ok) and the non-overlap sample (Nk); two
# overlaps plus one non-overlap add up to at most batch_size samples.
Ok_size = int(overlap_ratio * batch_size)
Nk_size = int((1 - 2 * overlap_ratio) * batch_size)

# sample previous overlap gradient
end = 0

random_index = np.random.permutation(range(len(train_dataset)))

# Time how long fetching the overlap samples takes; the elapsed seconds
# are printed.
begin = time.time()
Ok_prev = random_index[0:Ok_size]
X_trains, y_trains = train_dataset.getItems(Ok_prev)
end = time.time() - begin
print(end)

# Time the gradient/objective evaluation on the overlap samples.
begin = time.time()
g_Ok_prev, obj_Ok_prev = get_grad(optimizer, X_trains, y_trains, opfun)
end = time.time() - begin
print(end)

# main loop

for n_iter in range(max_iter):

    # training mode
    model.train()

    # sample current non-overlap and next overlap gradient
    begin = time.time()
    # A fresh permutation each iteration; Ok and Nk are disjoint slices of it.
    random_index = np.random.permutation(range(len(train_dataset)))
    Ok = random_index[0:Ok_size]
    Nk = random_index[Ok_size:(Ok_size + Nk_size)]
# L-BFGS optimizer over the model parameters, configured with the
# line_search option 'None' and a history of 10.
optimizer = LBFGS(model.parameters(),
                  lr=lr,
                  history_size=10,
                  line_search='None',
                  debug=True)

#%% Main training loop

# Sizes of the overlap sample (Ok) and the non-overlap sample (Nk); two
# overlaps plus one non-overlap add up to at most batch_size samples.
Ok_size = int(overlap_ratio * batch_size)
Nk_size = int((1 - 2 * overlap_ratio) * batch_size)

# sample previous overlap gradient
random_index = np.random.permutation(range(X_train.shape[0]))
Ok_prev = random_index[0:Ok_size]
g_Ok_prev, obj_Ok_prev = get_grad(optimizer, X_train[Ok_prev],
                                  y_train[Ok_prev], opfun)

# main loop
for n_iter in range(max_iter):

    # training mode
    model.train()

    # sample current non-overlap and next overlap gradient
    # (a fresh permutation each iteration; Ok and Nk are disjoint slices)
    random_index = np.random.permutation(range(X_train.shape[0]))
    Ok = random_index[0:Ok_size]
    Nk = random_index[Ok_size:(Ok_size + Nk_size)]

    # compute overlap gradient and objective
    g_Ok, obj_Ok = get_grad(optimizer, X_train[Ok], y_train[Ok], opfun)
예제 #17
0
# Define optimizer
optimizer = LBFGS(model.parameters(), lr=1., history_size=10, line_search='Wolfe', debug=True)

# Main training loop
for n_iter in range(max_iter):
    
    # training mode
    model.train()
    
    # sample batch
    random_index = np.random.permutation(range(X_train.shape[0]))
    Sk = random_index[0:batch_size]
    
    # compute initial gradient and objective
    grad, obj = get_grad(optimizer, X_train[Sk], y_train[Sk], opfun)
    
    # two-loop recursion to compute search direction
    p = optimizer.two_loop_recursion(-grad)
            
    # define closure for line search
    def closure():              
        
        optimizer.zero_grad()
        
        if cuda:
            loss_fn = torch.tensor(0, dtype=torch.float).cuda()
        else:
            loss_fn = torch.tensor(0, dtype=torch.float)
        
        for subsmpl in np.array_split(Sk, max(int(batch_size / ghost_batch), 1)):
예제 #18
0
파일: main.py 프로젝트: Ykmoon/SVM-pegasos
def main():
    """Train a linear model with mini-batch SGD, then with a Newton-CG loop.

    Reads the training file from ``sys.argv[1]`` and the test file from
    ``sys.argv[2]`` (both in svmlight format), appends a constant column of
    ones to the features, and runs two independent training loops that each
    start from ``w = 0``. After every epoch each loop records the elapsed
    time, a gradient norm, the loss on the test split relative to
    ``optimal_loss``, and the test accuracy; a summary line is printed at
    the end of each loop.
    """
    # read the train file from first argument
    train_file = sys.argv[1]
    #train_file='../data/covtype.scale.trn.libsvm'
    # read the test file from second argument
    test_file = sys.argv[2]
    #test_file = '../data/covtype.scale.tst.libsvm'

    # You can use load_svmlight_file to load data from train_file and test_file
    X_train, y_train = load_svmlight_file(train_file)
    X_test, y_test = load_svmlight_file(test_file)

    # You can use cg.ConjugateGradient(X, I, grad, lambda_)
    # Main entry point to the program
    # Append a column of ones to both feature matrices (presumably the bias
    # term - confirm against get_grad / get_loss).
    X_train = sparse.hstack([X_train, np.ones((X_train.shape[0], 1))])
    X_test = sparse.hstack([X_test, np.ones((X_test.shape[0], 1))])

    # CSR format so the row slicing / fancy indexing below works.
    X = sparse.csr_matrix(X_train)
    X_test = sparse.csr_matrix(X_test)

    # Labels as sparse column vectors.
    y = sparse.csr_matrix(y_train).transpose()
    y_test = sparse.csr_matrix(y_test).transpose()

    # set global hyper parameters
    # NOTE(review): the comparison is against the bare file name, so a path
    # such as '../data/covtype.scale.trn.libsvm' selects the else branch.
    # five_fold_CV and optimal_function_value are not read in the code below.
    if sys.argv[1] == "covtype.scale.trn.libsvm":
        lambda_ = 3631.3203125
        optimal_loss = 2541.664519
        five_fold_CV = 75.6661
        optimal_function_value = 2541.664519

    else:
        lambda_ = 7230.875
        optimal_loss = 669.664812
        five_fold_CV = 97.3655
        optimal_function_value = 669.664812

    #SGD
    # set local SGD hyper parameters
    print('starting SGD...')
    n_batch = 1000
    beta = 0
    lr = 0.001
    w = np.zeros((X_train.shape[1]))
    n = X_train.shape[0]
    sgd_grad = []
    sgd_time = []
    sgd_rel = []
    sgd_test_acc = []
    epoch = 180
    start = time.time()
    # per-epoch learning rate; with beta = 0 it stays equal to lr
    for i in range(epoch):
        gamma_t = lr / (1 + beta * i)
        batch_ = np.random.permutation(n)  # shuffle
        # trailing samples that do not fill a whole batch are skipped
        for j in range(n // n_batch):
            # make batch
            idx = batch_[j * n_batch:(j + 1) * n_batch]
            X_bc = X[idx]
            y_bc = y[idx]

            grad = get_grad(w, lambda_, n, X_bc, y_bc,
                            n_batch)  # compute gradient

            w = w - gamma_t * grad  # gradient step on w

        t = time.time() - start
        sgd_time.append(t)  # append to time list

        # norm of the LAST mini-batch gradient of this epoch
        grad_ = np.linalg.norm(grad)  # get gradient value
        sgd_grad.append(grad_)

        # NOTE: the loss is evaluated on the test split
        rel = (get_loss(w, lambda_, X_test, y_test, n_batch) -
               optimal_loss) / optimal_loss  # get relative func value
        sgd_rel.append(rel)

        test_acc = get_acc(w, lambda_, X_test, y_test,
                           n_batch)  # get test accuracy
        sgd_test_acc.append(test_acc)
    print("SGD : final_time: {}, fina_test_acc: {}".format(
        time.time() - start, sgd_test_acc[-1]))

    # plot SGD (disabled: kept as a bare string literal)
    '''
    plt.plot(sgd_time, sgd_grad)
    plt.xlabel("time")
    plt.ylabel("grad")
    plt.title("SGD")
    plt.show()

    plt.plot(sgd_time, sgd_rel)
    plt.xlabel("time")
    plt.ylabel("relative function")
    plt.title("SGD")
    plt.show()


    plt.plot(sgd_time, sgd_test_acc)
    plt.xlabel("time")
    plt.ylabel("test_acc")
    plt.title("SGD")
    plt.show()

    '''
    print('starting Newton...')
    #Newton
    # set local Newton hyper parameters
    epoch = 50
    n_batch = 1000
    beta = 0.0001
    lr = 0.001
    w = np.zeros((X_train.shape[1]))
    n = X_train.shape[0]
    nt_grad = []
    nt_time = []
    nt_rel = []
    newton_time = time.time()

    nt_test_acc = []
    # (w and n are initialised a second time with the same values)
    w = np.zeros((X_train.shape[1]))
    n = X_train.shape[0]

    for i in range(epoch):
        # NOTE(review): gamma_t is computed but not used in this loop
        gamma_t = lr / (1 + beta * i)
        # accumulator shaped like w; presumably a gradient-like vector
        # despite its name - confirm against get_hessian
        hessian_total = np.zeros(w.shape)
        I_ = []  # per-batch I values, concatenated for conjugate gradient
        # batches are taken in file order here (no shuffling)
        for j in range(n // n_batch):
            X_bc = X[j * n_batch:(j + 1) * n_batch]  #make X_batch
            y_bc = y[j * n_batch:(j + 1) * n_batch]  #make y_batch

            hessian, I = get_hessian(w, lambda_, n, X_bc, y_bc)  # get hessian
            hessian_total += hessian
            I_.append(I)
        I_ = np.concatenate(I_)
        hessian_total += w

        delta, _ = cg.conjugateGradient(
            X, I_, hessian_total,
            lambda_)  # get update value from conjugateGradient

        w = w + delta  #update w

        t = time.time() - newton_time
        nt_time.append(t)  # append to time list

        grad_ = np.linalg.norm(hessian_total)  # get gradient value
        nt_grad.append(grad_)

        # NOTE: the loss is evaluated on the test split
        rel = (get_loss(w, lambda_, X_test, y_test, n_batch) -
               optimal_loss) / optimal_loss  # get relative func value
        nt_rel.append(rel)

        test_acc = get_acc(w, lambda_, X_test, y_test,
                           n_batch)  # get test accuracy
        nt_test_acc.append(test_acc)
    final_time = time.time() - newton_time
    print("final_time: {}, fina_test_acc: {}".format(final_time,
                                                     nt_test_acc[-1]))

    #plot
    '''
accfun = lambda op, y: np.mean(np.equal(predsfun(op), y.squeeze())) * 100

#%% Define optimizer

optimizer = FullBatchLBFGS(model.parameters(),
                           lr=1,
                           history_size=10,
                           line_search='Wolfe',
                           debug=True)

#%% Main training loop

no_samples = X_train.shape[0]

# compute initial gradient and objective
grad, obj = get_grad(optimizer, X_train, y_train, opfun)

# main loop
for n_iter in range(max_iter):

    # training mode
    model.train()

    # define closure for line search
    def closure():

        optimizer.zero_grad()

        if (torch.cuda.is_available()):
            loss_fn = torch.tensor(0, dtype=torch.float).cuda()
        else:
예제 #20
0
        # Critic (discriminator) updates: CRITIC_ITERS optimisation steps on
        # the current real batch; the per-step losses are collected so their
        # mean can be logged. (Looks like a WGAN-GP style critic loop -
        # confirm against get_grad / gradient_penalty.)
        crit_losses = []
        for _ in range(CRITIC_ITERS):
            # score a freshly generated batch and the real batch; the fake
            # batch is detached so no gradient reaches the generator
            noise = gen_noise(cur_batch_size, NOISE_DIM, device=device)
            fake = gen(noise)
            disc_fake = critic(fake.detach())
            disc_real = critic(real)

            # calculate gradient penalty
            # epsilon: one random number per sample, shape
            # (cur_batch_size, 1, 1, 1); presumably the real/fake mixing
            # weight used inside get_grad - confirm
            epsilon = torch.rand(cur_batch_size,
                                 1,
                                 1,
                                 1,
                                 device=device,
                                 requires_grad=True)
            grads = get_grad(critic, real, fake.detach(), epsilon)
            gp = gradient_penalty(grads)

            # calculate discriminator loss:
            # mean fake score minus mean real score, plus the weighted penalty
            disc_loss = -(torch.mean(disc_real) -
                          torch.mean(disc_fake)) + C_LAMBDA * gp

            # update discriminator
            opt_disc.zero_grad()
            disc_loss.backward(retain_graph=True)
            opt_disc.step()
            crit_losses.append(disc_loss.item())

        # monitor running loss
        lossD.update(np.mean(crit_losses), cur_batch_size)