Example #1
def train(args):
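    # Fine-tune HEARNet (H) only: G and E are loaded pretrained and frozen,
    # and the ArcFace net F is used purely as an identity feature extractor.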
    # gpu init
    multi_gpu = False
    if len(args.gpus.split(',')) > 1:
        multi_gpu = True
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    G = AAD_Gen()
    F = ArcFace_Net(args.backbone, args.test_model_path)  # no need to train
    E = Att_Encoder()
    H = HEARNet()

    G = load_model(G, 'path_to_G_model')
    E = load_model(E, 'path_to_E_model')

    G.eval()
    E.eval()

    optimizer = torch.optim.Adam(H.parameters(),
                                 lr=0.0004,
                                 betas=(0.0, 0.999))

    if multi_gpu:
        H = DataParallel(H).to(device)
    else:
        H = H.to(device)

    for epoch in range(1, args.total_epoch + 1):
        H.train()
        # F is frozen and used only to extract features (input: 3x256x256, output: 256-d).
        for step, data in enumerate(aug_data_loader):
            try:
                img, label = data
            except Exception:
                continue
            source = img[:4, :, :, :].to(device)
            target = img[[0, 1, 2, 4], :, :, :].to(device)

            Y_tt = G(F(target), E(target))
            error = target - Y_tt
            Yst0 = G(F(source), E(target))
            Yst = H(torch.cat((Yst0, error), dim=1))

            optimizer.zero_grad()

            L_id = IdLoss()(F(Yst), F(source))
            L_chg = ChgLoss()(Yst0, Yst)
            L_rec = RecLoss()(Yst0[:-1, :, :, :], target[:-1, :, :, :], label)

            Loss = (L_id + L_chg + L_rec).to(device)
            Loss.backward()
            optimizer.step()
Example #2
    def predict(self, input_ids, attention_mask, chosen_sub, threshold):
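        # Binarize the four start/end score maps at `threshold` into 0/1 longs.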
        pred_sub_start, pred_sub_end, pred_obj_start, pred_obj_end = self.forward(
            input_ids=input_ids, attention_mask=attention_mask, chosen_sub=chosen_sub)

        one = self.one
        zero = self.zero
        # each prediction has shape [batch_size, max_seq_len]

        F = lambda x: torch.where(x > threshold, one, zero).long()
        pred_sub_start = F(pred_sub_start)
        pred_sub_end = F(pred_sub_end)
        pred_obj_start = F(pred_obj_start)
        pred_obj_end = F(pred_obj_end)

        return pred_sub_start, pred_sub_end, pred_obj_start, pred_obj_end
Example #3
def test_forward(input, kernel_size=3, padding=1, stride=2, dilation=1):
    # input = (Variable(torch.FloatTensor(torch.randn(1, 1, 5, 5)), requires_grad=True),
    #          Variable(torch.FloatTensor(torch.randn(1, 9)), requires_grad=True),)

    F = RRSVM.RRSVM_F(kernel_size,
                      padding,
                      stride,
                      dilation=dilation,
                      return_indices=True)
    analytical, analytical_indices = F(*input)
    analytical = analytical.data.numpy()
    analytical_indices = analytical_indices.data.numpy()
    numerical, numerical_indices = get_numerical_output(
        *input,
        kernel_size=kernel_size,
        padding=padding,
        stride=stride,
        dilation=dilation)

    atol = 1e-5
    rtol = 1e-3
    if not (np.absolute(numerical - analytical) <=
            (atol + rtol * np.absolute(numerical))).all():
        print("Output Failed Forward Test")
    else:
        print("Output Passed Forward Test")

    if not (np.absolute(analytical_indices - numerical_indices) <=
            (atol + rtol * np.absolute(numerical_indices))).all():
        print("Indices Failed Forward Test")
    else:
        print("Indices Passed Forward Test")
Example #4
    def __init__(self, *args, **kwargs):
        tk.Tk.__init__(self, *args, **kwargs)

        self.title_font = tkfont.Font(family='Helvetica',
                                      size=18,
                                      weight="bold",
                                      slant="italic")
        #self.filepath= tk.StringVar()
        # the container is where we'll stack a bunch of frames
        # on top of each other, then the one we want visible
        # will be raised above the others
        container = tk.Frame(self)
        container.pack(side="top", fill="both", expand=True)
        container.grid_rowconfigure(0, weight=1)
        container.grid_columnconfigure(0, weight=1)

        self.frames = {}
        for F in (StartPage, PageSubmit, PageSeeVal):
            page_name = F.__name__
            frame = F(parent=container, controller=self)
            self.frames[page_name] = frame

            # put all of the pages in the same location;
            # the one on the top of the stacking order
            # will be the one that is visible.
            frame.grid(row=0, column=0, sticky="nsew")

        self.show_frame("StartPage")
Example #5
def test(step, pred_save_pth):
    F.eval()
    C.eval()
    image_id = []
    pred_label = []
    with torch.no_grad():
        # corrects = torch.zeros(1).to("cuda")
        # for idx, (src, labels) in enumerate(eval_loader):
        #     src, labels = src.to("cuda"), labels.to("cuda")
        #     c = C(F(src))
        #     _, preds = torch.max(c, 1)
        #     corrects += (preds == labels).sum()
        # acc = corrects.item() / len(eval_loader.dataset)
        # print('***** Eval Result: {:.4f}, Step: {}'.format(acc, step))

        corrects = torch.zeros(1).to("cuda")
        for tgt, labels in test_loader:
            tgt = tgt.to("cuda")
            # print(type(c.cpu()))
            h, c = C(F(tgt))

            _, preds = torch.max(c, 1)
            # pred = c[0].max(1,keepdim = True)[1]
            # corrects += (preds == labels).sum()
            for img_id, im_pd in zip(labels, preds):
                # print(im_pd)
                image_id.append(img_id)
                pred_label.append(int(im_pd.cpu().numpy()))
        # acc = corrects.item() / len(test_loader.dataset)
        # print('***** Test Result: {:.4f}, Step: {}'.format(acc, step))
        dicts = {"image_name": image_id, "label": pred_label}
        DF = pd.DataFrame(dicts)
        DF.to_csv(pred_save_pth, index=False)
    F.train()
    C.train()
Example #6
def ampSolver(hBatch, yBatch, Symb, noise_sigma):
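	# Approximate message passing (AMP) detection for y = Hx + n with symbols
	# drawn from the constellation `Symb`: F is the posterior-mean (MMSE)
	# denoiser under a Gaussian approximation and G its posterior variance.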
	def F(x_in, tau_l, Symb):
		arg = -(x_in - Symb.reshape((1,1,-1))) ** 2 / 2. / tau_l
		exp_arg = np.exp(arg - np.max(arg, axis=2, keepdims=True))
		prob = exp_arg / np.sum(exp_arg, axis=2, keepdims=True)
		f = np.matmul(prob, Symb.reshape((1,-1,1)))
		return f

	def G(x_in, tau_l, Symb):
		arg = -(x_in - Symb.reshape((1,1,-1))) ** 2 / 2. / tau_l
		exp_arg = np.exp(arg - np.max(arg, axis=2, keepdims=True))
		prob = exp_arg / np.sum(exp_arg, axis=2, keepdims=True)
		g = np.matmul(prob, Symb.reshape((1,-1,1)) ** 2) - F(x_in, tau_l, Symb) ** 2
		return g

	numIterations = 50
	NT = hBatch.shape[2]
	NR = hBatch.shape[1]
	N0 = noise_sigma ** 2 / 2.
	xhat = np.zeros((numIterations, hBatch.shape[0], hBatch.shape[2], 1))
	z = np.zeros((numIterations, hBatch.shape[0], hBatch.shape[2], 1))
	r = np.zeros((numIterations, hBatch.shape[0], hBatch.shape[1], 1))
	tau = np.zeros((numIterations, hBatch.shape[0], 1, 1))
	r[0] = yBatch
	for l in range(numIterations-1):
		z[l] = xhat[l] + np.matmul(hBatch.transpose((0,2,1)), r[l])
		xhat[l+1] = F(z[l], N0 * (1.+tau[l]), Symb)
		tau[l+1] = float(NT) / NR / N0 * np.mean(G(z[l], N0 * (1. + tau[l]), Symb), axis=1, keepdims=True)
		r[l+1] = yBatch - np.matmul(hBatch, xhat[l+1]) + tau[l+1]/(1.+tau[l]) * r[l]

	return xhat[numIterations - 1]
Example #7
    def forward(self, x):
        # sigmoid = nn.Sigmoid()
        # F = MySigmoid.apply
        # F = MyReLU.apply
        # F = nn.Sigmoid()
        F = self.af
        x = F(self.fc1(x))
        x = F(self.fc2(x))
        # x = F(self.fc3(x))
        # x = F(self.fc4(x))
        # x = F(self.fc5(x))
        # x = F(self.fc6(x))
        # x = F(self.fc7(x))
        # x = F(self.fc8(x))
        x = F(self.fc9(x))
        x = self.fc10(x)
        return x
Example #8
def cost(z0, z1, z2):
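    # Physics-informed loss: z0 are interior collocation points (PDE residual
    # via F), z1 initial-condition points, and z2 boundary points.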
    # impose structure of PDE
    s0 = torch.sum(torch.square(torch.abs(F(z0))))
    # impose initial condition
    s1 = torch.sum(torch.square(torch.abs(model(z1) - u0(z1))))

    #impose boundary condition
    s2 = torch.sum(torch.square(torch.abs(model(z2) - g1(z2))))
    return s0 + s1 + s2
Example #9
def validate(G, F, loader):
    acc = AverageMeter()
    G.eval()
    F.eval()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        features, _ = G.forward_features(x)
        y_pred = F(features)
        acc.update(accuracy(y_pred.data, y, topk=(1,))[0].item(), x.size(0))
    return acc.avg

def CycleConsistencyLoss(G, F, fake_X, fake_Y, X, Y):
    """
    Compute the cycle-consistency loss L_cyc(G, F).
    """
    x_cycled = F(fake_Y)
    y_cycled = G(fake_X)
    loss1 = nn.functional.l1_loss(x_cycled, X)
    loss2 = nn.functional.l1_loss(y_cycled, Y)
    loss = loss1 + loss2
    return loss / 2
Example #11
    def cal_loss(self, pred_sub_vec, sub_vec, pred_obj_vec, obj_vec):

        pred_sub_start, pred_sub_end = pred_sub_vec
        target_sub_start, target_sub_end = sub_vec

        pred_obj_start, pred_obj_end = pred_obj_vec
        target_obj_start, target_obj_end = obj_vec

        #import pdb;pdb.set_trace()
        F = lambda x, y, weights: weighted_binary_cross_entropy(
            x, y, weights=weights)
        sub_start_loss = F(pred_sub_start, target_sub_start, None)
        sub_end_loss = F(pred_sub_end, target_sub_end, None)

        obj_start_loss = F(pred_obj_start, target_obj_start, self.weight)
        obj_end_loss = F(pred_obj_end, target_obj_end, self.weight)

        print("sub_start:{}\t sub_end:{}\t obj_start:{}\t obj_end:{}"\
            .format(sub_start_loss.item(),sub_end_loss.item(),obj_start_loss.item(),obj_end_loss.item()))

        return sub_start_loss + sub_end_loss + obj_start_loss + obj_end_loss
Example #12
    def forward(self, output1, output2, label):
        # FIXME: need to check if this distance calculation is right
        F = nn.PairwiseDistance(p=2)
        # we want to add an empty last dimension
        output1 = output1.unsqueeze(output1.dim())
        output2 = output2.unsqueeze(output2.dim())
        euclidean_distance = F(output1, output2)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2) +
            label * torch.pow(torch.clamp(self.margin - euclidean_distance,
                                          min=0.0), 2))
        return loss_contrastive
Example #13
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        z, spect, speaker_ids, audio_out = ctx.saved_tensors
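        # Memory-efficient invertible backward: reconstruct the layer inputs
        # from the saved outputs (instead of storing them), then obtain all
        # gradients from a single autograd.grad call.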

        audio_0_out, audio_1_out = audio_out.chunk(2, 1)
        audio_0_out, audio_1_out = audio_0_out.contiguous(), audio_1_out.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        with set_grad_enabled(True):
            audio_0 = audio_0_out
            audio_0.requires_grad = True
            log_s, t = F(audio_0, spect, speaker_ids)

        with torch.no_grad():
            # exp is not implemented for fp16, so this is computed in fp32 by
            # Nvidia/Apex and cast back to half.
            s = torch.exp(log_s).half()
            audio_1 = (audio_1_out - t) / s  # s is fp32, therefore audio_1 is cast to fp32
            z.storage().resize_(reduce(mul, audio_1.shape) * 2)  # z is fp16
            if z.dtype == torch.float16:  # if z is fp16, cast audio_0 and audio_1 back to fp16
                torch.cat((audio_0.half(), audio_1.half()), 1, out=z)  # fp16
            else:
                torch.cat((audio_0, audio_1), 1, out=z)  # fp32
            #z.copy_(xout)  # .detach()

        with set_grad_enabled(True):
            param_list = [audio_0] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [spect]
            if ctx.needs_input_grad[2]:
                param_list += [speaker_ids]
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dzb * audio_1 * s + log_s_grad, dzb), 1))

            dxa = dza + dtsdxa
            dxb = dzb * s
            dx = torch.cat((dxa, dxb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None
            if ctx.needs_input_grad[2]:
                *dw, ds = dw
            else:
                ds = None

        return (dx, dy, ds, None) + tuple(dw)
Example #14
    def forward(ctx, audio_out, spect, speaker_ids, F, *F_weights):
        # Inverse affine coupling: only the first half is fed to F, so the
        # second half can be recovered exactly as (audio_1_out - t) / exp(log_s).
        ctx.F = F
        with torch.no_grad():
            audio_0_out, audio_1_out = audio_out.chunk(2, 1)
            audio_0_out, audio_1_out = audio_0_out.contiguous(), audio_1_out.contiguous()

            log_s, t = F(audio_0_out, spect, speaker_ids)
            audio_1 = (audio_1_out - t) / log_s.exp()
            audio_0 = audio_0_out
            z = torch.cat((audio_0, audio_1), 1)

        ctx.save_for_backward(audio_out.data, spect, speaker_ids, z)
        return z, -log_s

    def forward(ctx, z, y, F, *F_weights):
        ctx.F = F
        with torch.no_grad():
            za, zb = z.chunk(2, 1)
            za, zb = za.contiguous(), zb.contiguous()

            log_s, t = F(za, y)
            xb = (zb - t) / log_s.exp()
            xa = za
            x = torch.cat((xa, xb), 1)

        ctx.save_for_backward(z.data, y, x)
        return x, -log_s
Example #16
    def forward(ctx, z, spect, speaker_ids, F, *F_weights):
        ctx.F = F
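        # Affine coupling: the first half conditions F; the second half is
        # scaled by exp(log_s) and shifted by t, so the transform is invertible.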
        with torch.no_grad():
            audio_0, audio_1 = z.chunk(2, 1)
            audio_0, audio_1 = audio_0.contiguous(), audio_1.contiguous()

            log_s, t = F(audio_0, spect, speaker_ids)
            audio_1_out = audio_1 * log_s.exp() + t
            audio_0_out = audio_0
            audio_out = torch.cat((audio_0_out, audio_1_out), 1)

        ctx.save_for_backward(z.data, spect, speaker_ids, audio_out)
        return audio_out, log_s
    def forward(ctx, x, y, F, *F_weights):
        ctx.F = F
        with torch.no_grad():
            xa, xb = x.chunk(2, 1)
            xa, xb = xa.contiguous(), xb.contiguous()

            log_s, t = F(xa, y)
            zb = xb * log_s.exp() + t
            za = xa
            z = torch.cat((za, zb), 1)

        ctx.save_for_backward(x.data, y, z)
        return z, log_s
Example #18
File: Net_data.py  Project: GGOSinon/GCN
    def forward(self, x0):
        connections = self.connections
        nodes = self.nodes
        c = np.zeros(len(connections), dtype=int)  # 1 once a connection has fired
        deg = np.zeros(len(nodes), dtype=int)      # unresolved in-degree per node
        num_node = {}
        for i, key in enumerate(nodes):
            num_node[key] = i
        for key in connections:
            conn = connections[key]
            if not conn.enabled:
                continue
            deg[num_node[conn.e]] += 1
        x = {}
        x[0] = x0.view(-1, 3 * 32 * 32)

        # Propagate activations in topological order: a connection fires once
        # its source node has no unresolved inputs left.
        while True:
            for i, key in enumerate(connections):
                conn = connections[key]
                if not conn.enabled:
                    continue
                s, e = conn.s, conn.e
                ns, ne = num_node[s], num_node[e]
                if c[i] == 0 and deg[ns] == 0:
                    c[i] = 1
                    variable_name = 'self.fc' + str(key)
                    F = getattr(self, variable_name)
                    if s not in x:
                        deg[ne] = 0
                        continue
                    X = F(x[s])
                    if e in x:
                        x[e] += X
                    else:
                        x[e] = X
                    deg[ne] -= 1
                    if deg[ne] == 0:
                        # all inputs of node e accumulated; apply its activation
                        x[e] = nodes[e].actF(x[e])
            done = True
            for i in range(len(nodes)):
                if deg[i] > 0:
                    done = False
            if done:
                break
        return x[1]  # node 0 is always the input, node 1 the output
Example #19
    def forward(self,
                F,
                G,
                E,
                scale,
                alpha,
                z,
                labels=None,
                hessian_layers=[3],
                current_layer=[0]):
        F_z = F(z, scale, z2=None, p_mix=0)

        # Autoencoding loss in latent space
        G_z = G(F_z, scale, alpha)
        E_z = E(G_z, alpha)

        if labels is not None:
            E_z = E_z.reshape(E_z.shape[0], 1,
                              E_z.shape[1]).repeat(1, F_z.shape[1], 1)
            if self.use_dist:
                x = self.p_dist(F_z, E_z)
                y = torch.eq(labels, labels.T).float().to(x.device)
                loss = self.loss_fn(x, y)
            else:
                perm = torch.randperm(E_z.shape[0], device=E_z.device)
                E_z_hat = torch.index_select(E_z, 0, perm)
                F_z_hat = torch.index_select(F_z, 0, perm)
                F_hat = torch.cat([F_z, F_z_hat], 0)
                E_hat = torch.cat([E_z, E_z_hat], 0)
                loss = self.loss_fn(F_hat, E_hat, labels)
        else:
            F_x = F_z[:, 0, :]
            loss = self.loss_fn(F_x, E_z)

        if self.use_tv:
            loss += self.total_variation(G_z)

        # Hessian applied to G here
        if self.enable_hessian:
            h_loss = hessian_penalty(G,
                                     z=F_z,
                                     scale=scale,
                                     alpha=alpha,
                                     return_norm=hessian_layers)
            h_loss *= self.hessian_weight
            if current_layer in hessian_layers:
                h_loss = h_loss * alpha
            loss += h_loss
        return loss
Example #20
def feature(sourceset, targetset, F, device):
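    # Collect 256-d latents from the (frozen) extractor F over both domains,
    # together with labels and domain ids, for visualization via draw_features.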
    F.eval()
    features = np.empty((0, 256))
    targets = np.empty((0, ), dtype=np.int8)
    domains = np.empty((0, ), dtype=np.int8)
    with torch.no_grad():
        for data in sourceset:
            image = data['image'].to(device)
            target = data['label'].to(device)
            domain = data['domain'].to(device)
            latent = F(image)
            features = np.concatenate((features, latent.cpu().numpy()), axis=0)
            targets = np.concatenate((targets, target.cpu().numpy()), axis=0)
            domains = np.concatenate((domains, domain.cpu().numpy()), axis=0)
        for data in targetset:
            image = data['image'].to(device)
            target = data['label'].to(device)
            domain = data['domain'].to(device)
            latent = F(image)
            features = np.concatenate((features, latent.cpu().numpy()), axis=0)
            targets = np.concatenate((targets, target.cpu().numpy()), axis=0)
            domains = np.concatenate((domains, domain.cpu().numpy()), axis=0)
    print(features.shape, targets.shape, domains.shape)
    draw_features(features, targets, 10, domains)
Example #21
def test_gradient(input, kernel_size=3, padding=0, stride=1):

    F = RRSVM.RRSVM_F(kernel_size=kernel_size,
                      padding=padding,
                      stride=stride,
                      dilation=1)

    test = gradcheck(lambda i, s: F(i, s),
                     inputs=input,
                     eps=1e-3,
                     atol=1e-3,
                     rtol=1e-3)
    if test:
        print("Gradient Check Passed!")
    else:
        print("Gradient Check Failed!")
Example #22
    def backward(ctx, x_grad, log_s_grad):
        F = ctx.F
        audio_out, spect, speaker_ids, z = ctx.saved_tensors

        audio_0, audio_1 = z.chunk(2, 1)
        audio_0, audio_1 = audio_0.contiguous(), audio_1.contiguous()
        dxa, dxb = x_grad.chunk(2, 1)
        dxa, dxb = dxa.contiguous(), dxb.contiguous()

        with set_grad_enabled(True):
            audio_0_out = audio_0
            audio_0_out.requires_grad = True
            log_s, t = F(audio_0_out, spect, speaker_ids)
            s = log_s.exp()

        with torch.no_grad():
            audio_1_out = audio_1 * s + t

            audio_out.storage().resize_(reduce(mul, audio_1_out.shape) * 2)
            torch.cat((audio_0_out, audio_1_out), 1, out=audio_out)
            #audio_out.copy_(zout)

        with set_grad_enabled(True):
            param_list = [audio_0_out] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [spect]
            if ctx.needs_input_grad[2]:
                param_list += [speaker_ids]
            dtsdza, *dw = grad(
                torch.cat((-log_s, -t / s), 1),
                param_list,
                grad_outputs=torch.cat(
                    (dxb * audio_1_out / s.detach() + log_s_grad, dxb), 1))

            dza = dxa + dtsdza
            dzb = dxb / s.detach()
            dz = torch.cat((dza, dzb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None
            if ctx.needs_input_grad[2]:
                *dw, ds = dw
            else:
                ds = None

        return (dz, dy, ds, None) + tuple(dw)
Example #23
def loss_generator_hessian(G,
                           F,
                           z,
                           scale,
                           alpha,
                           scale_alpha=False,
                           hessian_layers=[3],
                           current_layer=[0],
                           hessian_weight=0.01):
    loss = hessian_penalty(G,
                           z=F(z, scale, z2=None, p_mix=0),
                           scale=scale,
                           alpha=alpha,
                           return_norm=hessian_layers)
    if current_layer in hessian_layers or scale_alpha:
        loss = loss * alpha
    return loss * hessian_weight
Example #24
def testing(dataset, F, C, device, output_path):
    F.eval()
    C.eval()
    with open(output_path, 'w') as f:
        f.write('image_name,label\n')
        with torch.no_grad():
            for data in dataset:
                image = data['image'].to(device)
                name = data['name']
                feature = F(image)
                logits = C(feature)
                _, predicts = torch.max(logits, 1)

                for image_name, pred in zip(name, predicts):
                    f.write(f'{image_name},{int(pred)}\n')
Example #25
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        z, spect, speaker_ids, audio_out = ctx.saved_tensors

        audio_0_out, audio_1_out = audio_out.chunk(2, 1)
        audio_0_out, audio_1_out = audio_0_out.contiguous(), audio_1_out.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        with set_grad_enabled(True):
            audio_0 = audio_0_out
            audio_0.requires_grad = True
            log_s, t = F(audio_0, spect, speaker_ids)

        with torch.no_grad():
            s = log_s.exp()
            audio_1 = (audio_1_out - t) / s
            z.storage().resize_(reduce(mul, audio_1.shape) * 2)
            torch.cat((audio_0, audio_1), 1, out=z)  #fp32  # .contiguous()
            #torch.cat((audio_0.half(), audio_1.half()), 1, out=z)#fp16  # .contiguous()
            #z.copy_(xout)  # .detach()

        with set_grad_enabled(True):
            param_list = [audio_0] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [spect]
            if ctx.needs_input_grad[2]:
                param_list += [speaker_ids]
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dzb * audio_1 * s + log_s_grad, dzb), 1))

            dxa = dza + dtsdxa
            dxb = dzb * s
            dx = torch.cat((dxa, dxb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None
            if ctx.needs_input_grad[2]:
                *dw, ds = dw
            else:
                ds = None

        return (dx, dy, ds, None) + tuple(dw)
def fullLoss(G, F, D_X, D_Y, X, Y, idx, lam=10.0):
    """
    Compute the full CycleGAN loss.
    """
    fake_X = F(Y)
    fake_Y = G(X)
    l_gan1 = AdversarialLoss(D_Y, fake_Y)
    l_gan2 = AdversarialLoss(D_X, fake_X)
    l_cyc = CycleConsistencyLoss(G, F, fake_X, fake_Y, X, Y)
    loss = l_gan1 + l_gan2 + lam * l_cyc

    if idx % 50 == 0:
        print(
            'Advers_loss_Dy_Gx: %.3f Advers_loss_Dx_Fy: %.3f Cyc_loss: %.3f' %
            (l_gan1, l_gan2, l_cyc))
    # store python floats rather than graph-holding tensors
    plt_Advers_loss_Dy_Gx.append(l_gan1.item())
    plt_Advers_loss_Dx_Fy.append(l_gan2.item())
    plt_Cyc_loss.append(l_cyc.item())
    plt_loss.append(loss.item())
    return loss
    def backward(ctx, x_grad, log_s_grad):
        F = ctx.F
        z, y, x = ctx.saved_tensors

        xa, xb = x.chunk(2, 1)
        xa, xb = xa.contiguous(), xb.contiguous()
        dxa, dxb = x_grad.chunk(2, 1)
        dxa, dxb = dxa.contiguous(), dxb.contiguous()

        with set_grad_enabled(True):
            za = xa
            za.requires_grad = True
            log_s, t = F(za, y)
            s = log_s.exp()

        with torch.no_grad():
            zb = xb * s + t

            z.storage().resize_(reduce(mul, zb.shape) * 2)
            torch.cat((za, zb), 1, out=z)
            #z.copy_(zout)

        with set_grad_enabled(True):
            param_list = [za] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [y]
            dtsdza, *dw = grad(torch.cat((-log_s, -t / s), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dxb * zb / s.detach() + log_s_grad, dxb),
                                   1))

            dza = dxa + dtsdza
            dzb = dxb / s.detach()
            dz = torch.cat((dza, dzb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None
        return (dz, dy, None) + tuple(dw)
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        x, y, z = ctx.saved_tensors

        za, zb = z.chunk(2, 1)
        za, zb = za.contiguous(), zb.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        with set_grad_enabled(True):
            xa = za
            xa.requires_grad = True
            log_s, t = F(xa, y)

        with torch.no_grad():
            s = log_s.exp()
            xb = (zb - t) / s
            x.storage().resize_(reduce(mul, xb.shape) * 2)
            torch.cat((xa, xb), 1, out=x)  # .contiguous()
            #x.copy_(xout)  # .detach()

        with set_grad_enabled(True):
            param_list = [xa] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [y]
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dzb * xb * s + log_s_grad, dzb), 1))

            dxa = dza + dtsdxa
            dxb = dzb * s
            dx = torch.cat((dxa, dxb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None

        return (dx, dy, None) + tuple(dw)
Example #29
    def forward(self,
                F: Module,
                G: Module,
                E: Module,
                scale: float,
                alpha: float,
                z: Tensor,
                loss_fn: Func,
                labels=None,
                use_tv=False,
                tv_weight=0.001):
        # Hessian applied to G here
        F_z = F(z, scale, z2=None, p_mix=0)

        # Autoencoding loss in latent space
        G_z = G(F_z, scale, alpha)
        E_z = E(G_z, alpha)

        F_x = F_z[:, 0, :]
        if labels is not None:
            # I don't remember if I made this up or not,
            # but it is a noise-based regularization strategy that might
            # encourage variance through order invariance
            # by duplicating a single element at the same index in both
            # the projection space (`F`) and the latent space (`E`)
            # when using a distance based loss (metric learning)
            perm = torch.randperm(E_z.shape[0], device=E_z.device)
            E_z_hat = torch.index_select(E_z, 0, perm)
            F_x_hat = torch.index_select(F_x, 0, perm)
            F_hat = torch.cat([F_x, F_x_hat], 0)
            E_hat = torch.cat([E_z, E_z_hat], 0)
            loss = loss_fn(F_hat, E_hat, labels)
        else:
            loss = loss_fn(F_x, E_z)

        if use_tv:
            loss += self.total_variation(G_z) * tv_weight
        return loss
Example #30
def loss_autoencoder(F,
                     G,
                     E,
                     scale,
                     alpha,
                     z,
                     loss_fn,
                     labels=None,
                     use_tv=False,
                     tv_weight=0.001,
                     permute_regularize=False,
                     bbox=None):
    # Hessian applied to G here
    F_z = F(z, scale, z2=None, p_mix=0)

    # Autoencoding loss in latent space
    G_z = G(F_z, scale, alpha, bbox=bbox)
    E_z = E(G_z, alpha)

    #E_z = E_z.reshape(E_z.shape[0], 1, E_z.shape[1]).repeat(1, F_z.shape[1], 1)
    F_x = F_z[:, 0, :]
    if labels is not None:
        if permute_regularize:
            perm = torch.randperm(E_z.shape[0], device=E_z.device)
            E_z_hat = torch.index_select(E_z, 0, perm)
            F_x_hat = torch.index_select(F_x, 0, perm)
            F_hat = torch.cat([F_x, F_x_hat], 0)
            E_hat = torch.cat([E_z, E_z_hat], 0)
            loss = loss_fn(F_hat, E_hat, labels)
        else:
            loss = loss_fn(F_x, E_z, labels)
    else:
        loss = loss_fn(F_x, E_z)

    if use_tv:
        loss += total_variation(G_z) * tv_weight
    return loss