Example #1
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        z, spect, speaker_ids, audio_out = ctx.saved_tensors

        audio_0_out, audio_1_out = audio_out.chunk(2, 1)
        audio_0_out, audio_1_out = audio_0_out.contiguous(), audio_1_out.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        with set_grad_enabled(True):
            audio_0 = audio_0_out
            audio_0.requires_grad = True
            log_s, t = F(audio_0, spect, speaker_ids)

        with torch.no_grad():
            # exp is not implemented for fp16, so it is computed in fp32 by Nvidia/Apex.
            s = torch.exp(log_s).half()
            # s is fp32, therefore audio_1 is cast to fp32.
            audio_1 = (audio_1_out - t) / s
            # Re-grow the input storage that was freed in forward() and rebuild z in place.
            z.storage().resize_(reduce(mul, audio_1.shape) * 2)  # z is fp16
            if z.dtype == torch.float16:  # cast audio_0 and audio_1 back to fp16
                torch.cat((audio_0.half(), audio_1.half()), 1, out=z)  # fp16
            else:
                torch.cat((audio_0, audio_1), 1, out=z)  # fp32
            #z.copy_(xout)  # .detach()

        with set_grad_enabled(True):
            param_list = [audio_0] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [spect]
            if ctx.needs_input_grad[2]:
                param_list += [speaker_ids]
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dzb * audio_1 * s + log_s_grad, dzb), 1))

            dxa = dza + dtsdxa
            dxb = dzb * s
            dx = torch.cat((dxa, dxb), 1)
            # Unpack the extra gradients in reverse order of how they were
            # appended to param_list: speaker_ids grad comes off first, then spect.
            if ctx.needs_input_grad[2]:
                *dw, ds = dw
            else:
                ds = None
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None

        return (dx, dy, ds, None) + tuple(dw)
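This backward() is one half of a memory-efficient invertible affine coupling: the forward pass frees the input tensor's storage (which is why backward() has to call z.storage().resize_() and rebuild z before it can return gradients), and the coupling network F is re-run under set_grad_enabled(True) so torch.autograd.grad can recover the weight gradients. Below is a minimal, self-contained sketch of the whole pattern; the class name AffineCouplingFn, the single-argument F, and the forward() body are reconstructions for illustration, not the original source.

from functools import reduce
from operator import mul

import torch
from torch.autograd import Function, grad, set_grad_enabled


class AffineCouplingFn(Function):
    @staticmethod
    def forward(ctx, x, F, *F_weights):
        # F_weights must be F.parameters() so that the gradients returned by
        # backward() are routed to them by autograd.
        ctx.F = F
        with torch.no_grad():
            xa, xb = x.chunk(2, 1)
            log_s, t = F(xa)
            z = torch.cat((xa, xb * log_s.exp() + t), 1)
        x.storage().resize_(0)  # free the input; backward() rebuilds it
        ctx.save_for_backward(x, z)
        return z, log_s

    @staticmethod
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        x, z = ctx.saved_tensors
        za, zb = z.chunk(2, 1)
        za, zb = za.contiguous(), zb.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        # Re-run the coupling network with autograd enabled.
        with set_grad_enabled(True):
            xa = za
            xa.requires_grad = True
            log_s, t = F(xa)

        # Invert the coupling and rebuild the freed input in place.
        with torch.no_grad():
            s = log_s.exp()
            xb = (zb - t) / s
            x.storage().resize_(reduce(mul, xb.shape) * 2)
            torch.cat((xa, xb), 1, out=x)

        # Chain rule through log_s and t to the first half and to F's weights.
        with set_grad_enabled(True):
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               [xa] + list(F.parameters()),
                               grad_outputs=torch.cat(
                                   (dzb * xb * s + log_s_grad, dzb), 1))
            dx = torch.cat((dza + dtsdxa, dzb * s), 1)

        return (dx, None) + tuple(dw)

Calling it as z, log_s = AffineCouplingFn.apply(x, F, *F.parameters()) makes the trailing tuple(dw) land on F's parameters, which is why the original backward() returns the weight gradients positionally.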
Example #2
    def backward(ctx, x_grad, log_s_grad):
        F = ctx.F
        audio_out, spect, speaker_ids, z = ctx.saved_tensors

        audio_0, audio_1 = z.chunk(2, 1)
        audio_0, audio_1 = audio_0.contiguous(), audio_1.contiguous()
        dxa, dxb = x_grad.chunk(2, 1)
        dxa, dxb = dxa.contiguous(), dxb.contiguous()

        with set_grad_enabled(True):
            audio_0_out = audio_0
            audio_0_out.requires_grad = True
            log_s, t = F(audio_0_out, spect, speaker_ids)
            s = log_s.exp()

        with torch.no_grad():
            audio_1_out = audio_1 * s + t

            audio_out.storage().resize_(reduce(mul, audio_1_out.shape) * 2)
            torch.cat((audio_0_out, audio_1_out), 1, out=audio_out)
            #audio_out.copy_(zout)

        with set_grad_enabled(True):
            param_list = [audio_0_out] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [spect]
            if ctx.needs_input_grad[2]:
                param_list += [speaker_ids]
            dtsdza, *dw = grad(
                torch.cat((-log_s, -t / s), 1),
                param_list,
                grad_outputs=torch.cat(
                    (dxb * audio_1_out / s.detach() + log_s_grad, dxb), 1))

            dza = dxa + dtsdza
            dzb = dxb / s.detach()
            dz = torch.cat((dza, dzb), 1)
            # Unpack the extra gradients in reverse order of how they were
            # appended to param_list: speaker_ids grad comes off first, then spect.
            if ctx.needs_input_grad[2]:
                *dw, ds = dw
            else:
                ds = None
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None

        return (dz, dy, ds, None) + tuple(dw)
Example #3
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        z, spect, speaker_ids, audio_out = ctx.saved_tensors

        audio_0_out, audio_1_out = audio_out.chunk(2, 1)
        audio_0_out, audio_1_out = audio_0_out.contiguous(), audio_1_out.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        with set_grad_enabled(True):
            audio_0 = audio_0_out
            audio_0.requires_grad = True
            log_s, t = F(audio_0, spect, speaker_ids)

        with torch.no_grad():
            s = log_s.exp()
            audio_1 = (audio_1_out - t) / s
            z.storage().resize_(reduce(mul, audio_1.shape) * 2)
            torch.cat((audio_0, audio_1), 1, out=z)  #fp32  # .contiguous()
            #torch.cat((audio_0.half(), audio_1.half()), 1, out=z)#fp16  # .contiguous()
            #z.copy_(xout)  # .detach()

        with set_grad_enabled(True):
            param_list = [audio_0] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [spect]
            if ctx.needs_input_grad[2]:
                param_list += [speaker_ids]
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dzb * audio_1 * s + log_s_grad, dzb), 1))

            dxa = dza + dtsdxa
            dxb = dzb * s
            dx = torch.cat((dxa, dxb), 1)
            # Unpack the extra gradients in reverse order of how they were
            # appended to param_list: speaker_ids grad comes off first, then spect.
            if ctx.needs_input_grad[2]:
                *dw, ds = dw
            else:
                ds = None
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None

        return (dx, dy, ds, None) + tuple(dw)
Example #4
    def backward(ctx, x_grad, log_s_grad):
        F = ctx.F
        z, y, x = ctx.saved_tensors

        xa, xb = x.chunk(2, 1)
        xa, xb = xa.contiguous(), xb.contiguous()
        dxa, dxb = x_grad.chunk(2, 1)
        dxa, dxb = dxa.contiguous(), dxb.contiguous()

        with set_grad_enabled(True):
            za = xa
            za.requires_grad = True
            log_s, t = F(za, y)
            s = log_s.exp()

        with torch.no_grad():
            zb = xb * s + t

            z.storage().resize_(reduce(mul, zb.shape) * 2)
            torch.cat((za, zb), 1, out=z)
            #z.copy_(zout)

        with set_grad_enabled(True):
            param_list = [za] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [y]
            dtsdza, *dw = grad(torch.cat((-log_s, -t / s), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dxb * zb / s.detach() + log_s_grad, dxb),
                                   1))

            dza = dxa + dtsdza
            dzb = dxb / s.detach()
            dz = torch.cat((dza, dzb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None
        return (dz, dy, None) + tuple(dw)
Example #5
    def backward(ctx, z_grad, log_s_grad):
        F = ctx.F
        x, y, z = ctx.saved_tensors

        za, zb = z.chunk(2, 1)
        za, zb = za.contiguous(), zb.contiguous()
        dza, dzb = z_grad.chunk(2, 1)
        dza, dzb = dza.contiguous(), dzb.contiguous()

        with set_grad_enabled(True):
            xa = za
            xa.requires_grad = True
            log_s, t = F(xa, y)

        with torch.no_grad():
            s = log_s.exp()
            xb = (zb - t) / s
            x.storage().resize_(reduce(mul, xb.shape) * 2)
            torch.cat((xa, xb), 1, out=x)  # .contiguous()
            #x.copy_(xout)  # .detach()

        with set_grad_enabled(True):
            param_list = [xa] + list(F.parameters())
            if ctx.needs_input_grad[1]:
                param_list += [y]
            dtsdxa, *dw = grad(torch.cat((log_s, t), 1),
                               param_list,
                               grad_outputs=torch.cat(
                                   (dzb * xb * s + log_s_grad, dzb), 1))

            dxa = dza + dtsdxa
            dxb = dzb * s
            dx = torch.cat((dxa, dxb), 1)
            if ctx.needs_input_grad[1]:
                *dw, dy = dw
            else:
                dy = None

        return (dx, dy, None) + tuple(dw)
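The two unconditional backward() variants above invert the same coupling without the spect/speaker_ids conditioning. A quick way to sanity-check any hand-written backward of this kind is to compare it against a plain implementation that keeps the input alive and lets autograd do all the work. A sketch, assuming the AffineCouplingFn from the earlier sketch is in scope (Coupling, the shapes, and the loss are arbitrary test fixtures, not part of the original code):

import torch
import torch.nn as nn

torch.manual_seed(0)


class Coupling(nn.Module):
    # Tiny coupling net: first half of the channels -> (log_s, t).
    def __init__(self, c):
        super().__init__()
        self.net = nn.Conv1d(c, 2 * c, 1)

    def forward(self, xa):
        log_s, t = self.net(xa).chunk(2, 1)
        return log_s, t


F = Coupling(4)
x = torch.randn(2, 8, 16)

# Reference: straightforward version, autograd keeps everything in memory.
x_ref = x.clone().requires_grad_(True)
xa, xb = x_ref.chunk(2, 1)
log_s_ref, t_ref = F(xa)
z_ref = torch.cat((xa, xb * log_s_ref.exp() + t_ref), 1)
(z_ref.sum() + log_s_ref.sum()).backward()
w_ref = [p.grad.clone() for p in F.parameters()]
F.zero_grad()

# Memory-efficient version: the coupling input's storage is freed and rebuilt.
x_eff = x.clone().requires_grad_(True)
x_in = x_eff * 1.0  # non-leaf copy, as the coupling input would be inside a real network
z_eff, log_s_eff = AffineCouplingFn.apply(x_in, F, *F.parameters())
(z_eff.sum() + log_s_eff.sum()).backward()

print(torch.allclose(z_ref, z_eff, atol=1e-6))
print(torch.allclose(x_ref.grad, x_eff.grad, atol=1e-6))
print(all(torch.allclose(a, p.grad, atol=1e-6) for a, p in zip(w_ref, F.parameters())))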
Example #6
        # self.sigmoid = nn.Sigmoid()
    def forward(self, h):
        y = self.layer(h)
        # tsne = self.lkr1(self.linear1(y))
        # y = self.sigmoid(self.linear2(tsne))
        return y


F = FeatureExtractor().to("cuda")
C = Classifier().to("cuda")
D = Discriminator().to("cuda")
# print(D)
d_critirion = nn.BCELoss()
c_critirion = nn.CrossEntropyLoss()

F_opt = torch.optim.Adam(F.parameters(), lr=0.0001)
C_opt = torch.optim.Adam(C.parameters(), lr=0.0001)
D_opt = torch.optim.Adam(D.parameters(), lr=0.0001)

#%%
max_epoch = 50


def get_lambda(epoch, max_epoch):
    p = epoch / max_epoch
    return 2. / (1 + np.exp(-10. * p)) - 1.


def train(step=step):
    final_acc = 0
    ll_c, ll_d = [], []
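This example sets up the usual DANN-style trio (feature extractor F, label classifier C, domain discriminator D with BCELoss), and get_lambda is the standard schedule 2 / (1 + exp(-10 p)) - 1 for the gradient-reversal coefficient. The discriminator branch is typically wired through a gradient reversal layer, which the truncated train() above does not show; a minimal sketch of such a layer (the GradReverse name and the usage lines are assumptions, not the original code):

import torch
from torch.autograd import Function


class GradReverse(Function):
    """Identity in the forward pass; multiplies the gradient by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lamda):
        ctx.lamda = lamda
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lamda * grad_output, None


def grad_reverse(x, lamda=1.0):
    return GradReverse.apply(x, lamda)


# Hypothetical use inside the training loop:
#   lamda = get_lambda(epoch, max_epoch)
#   domain_pred = D(grad_reverse(F(x), lamda))
#   loss_d = d_critirion(domain_pred, domain_labels)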
Example #7
def backprop_deep(G,
                  F,
                  D_X,
                  D_Y,
                  content_loader,
                  style_loader,
                  start=0,
                  T=200,
                  gamma=0.0002):
    params = list(G.parameters()) + list(F.parameters())
    optimizer_Ge = torch.optim.Adam(params, lr=gamma, betas=(0.5, 0.999))
    optimizer_Dx = torch.optim.Adam(D_X.parameters(),
                                    lr=gamma,
                                    betas=(0.5, 0.999))
    optimizer_Dy = torch.optim.Adam(D_Y.parameters(),
                                    lr=gamma,
                                    betas=(0.5, 0.999))

    buffer_X_fromY = ImageHistoryBuffer()
    buffer_Y_fromX = ImageHistoryBuffer()

    for epoch in range(start, T):
        #  linearly decay the rate to zero over the next 100 epochs.
        if epoch >= 100:
            for g in optimizer_Ge.param_groups:
                g['lr'] = gamma - gamma / 100 * (epoch - 100 + 1)
            for g in optimizer_Dx.param_groups:
                g['lr'] = gamma - gamma / 100 * (epoch - 100 + 1)
            for g in optimizer_Dy.param_groups:
                g['lr'] = gamma - gamma / 100 * (epoch - 100 + 1)

        for idx, img in enumerate(zip(content_loader, style_loader)):
            X = img[0].to(device)
            Y = img[1].to(device)
            # Generators
            # Initialize the gradients to zero
            optimizer_Ge.zero_grad()
            # Forward propagation and Error evaluation
            loss_Ge = fullLoss(G, F, D_X, D_Y, X, Y, idx)
            # Back propagation
            loss_Ge.backward()
            # Parameter update
            optimizer_Ge.step()

            # Discriminator X
            # Initialize the gradients to zero
            optimizer_Dx.zero_grad()
            # Forward propagation and Error evaluation
            dx = D_X(X)
            # To reduce model oscillation, using a history of generated images rather than
            # the ones produced by the latest generators.
            dfy = buffer_X_fromY.get_from_image_history_buffer(D_X(F(Y)))
            loss_Dx = D_X.criterion(dx, dfy) / 2
            # Back propagation
            loss_Dx.backward(retain_graph=True)
            # Parameter update
            optimizer_Dx.step()

            # Discriminator Y
            # Initialize the gradients to zero
            optimizer_Dy.zero_grad()
            # Forward propagation and Error evaluation
            dy = D_Y(Y)
            dgx = buffer_Y_fromX.get_from_image_history_buffer(D_Y(G(X)))
            loss_Dy = D_Y.criterion(dy, dgx) / 2
            # Back propagation
            loss_Dy.backward(retain_graph=True)
            # Parameter update
            optimizer_Dy.step()

            if idx % 50 == 0:
                print('[%d, %d] G_loss: %.3f Dx_loss: %.3f Dy_loss: %.3f\n' %
                      (epoch, idx, loss_Ge.item(), loss_Dx.item(),
                       loss_Dy.item()))

            if idx % 100 == 0:
                fake = G(content_img)
                imshow(fake, title='G(content_img) [%d, %d]' % (epoch, idx))

                cycled = F(fake)
                imshow(cycled,
                       title='F(G(content_img)) [%d, %d]' % (epoch, idx))

        show_loss()

        # save model checkpoint
        torch.save(
            {
                'G_state_dict': G.state_dict(),
                'F_state_dict': F.state_dict(),
                'Dx_state_dict': D_X.state_dict(),
                'Dy_state_dict': D_Y.state_dict(),
            }, saving_dir + "%d.pth" % (epoch))

        # save losses
        saved_losses = {
            "a": plt_Advers_loss_Dy_Gx,
            "b": plt_Advers_loss_Dx_Fy,
            "c": plt_Cyc_loss,
            "d": plt_loss
        }
        with open(saving_dir + 'loss.backup.epoch_%d' % (epoch),
                  'wb') as backup_file:
            pickle.dump(saved_losses, backup_file)

    return G, F, D_X, D_Y
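backprop_deep feeds the discriminators through ImageHistoryBuffer.get_from_image_history_buffer, the history/pool trick from SimGAN and CycleGAN that mixes freshly generated samples with previously generated ones to damp discriminator oscillation (here it is applied to the discriminator's outputs on fresh fakes). The class itself is not shown above; a minimal sketch of what such a buffer usually looks like (pool size, 50/50 sampling policy, and the detach are the conventional choices, assumed rather than taken from the original):

import random
import torch


class ImageHistoryBuffer:
    """Pool of previously seen tensors, mixed 50/50 with the incoming batch."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.buffer = []

    def get_from_image_history_buffer(self, batch):
        out = []
        for sample in batch.detach():
            sample = sample.unsqueeze(0)
            if len(self.buffer) < self.max_size:
                # Pool not full yet: store the new sample and return it as-is.
                self.buffer.append(sample)
                out.append(sample)
            elif random.random() < 0.5:
                # Return a stored sample and replace it with the new one.
                idx = random.randrange(self.max_size)
                out.append(self.buffer[idx])
                self.buffer[idx] = sample
            else:
                out.append(sample)
        return torch.cat(out, 0)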
Example #8
def train():
    print('Load training data...')
    train_loader = utils.data_loader('train', 'test')
    print('Done!')

    # model
    F = SuperDuperFeatureExtractor()
    C = SuperDuperClassifier()

    # load pretrained model
    if Config.use_pretrain:
        F.load_state_dict(torch.load(Config.checkpoint + 'F.pth'))
        C.load_state_dict(torch.load(Config.checkpoint + 'C.pth'))

    F.to(Config.device)
    C.to(Config.device)

    op_F = optim.Adam(F.parameters(), lr=Config.learning_rate)
    op_C = optim.Adam(C.parameters(), lr=Config.learning_rate)

    criterion = nn.CrossEntropyLoss()

    accuracy_s = accuracy_v = accuracy_t = 0

    # plot training route
    plot_loss = []
    plot_s_acc = []
    plot_v_acc = []
    plot_t_acc = []

    print('Training...')
    for epoch in range(Config.num_epoch):
        for idx, batch in enumerate(train_loader):
            s_img = batch['source_image'].to(Config.device)
            t_img = batch['target_image'].to(Config.device)
            v_img = batch['val_image'].to(Config.device)
            # img: [batch_size, 1, 28, 28]
            s_label = batch['source_label'].to(Config.device)
            v_label = batch['val_label'].to(Config.device)
            t_label = batch['target_label'].to(Config.device)
            # label: [batch_size,]

            feat_s = F(s_img)
            feat_v = F(v_img)
            feat_t = F(t_img)
            pred_s = C(feat_s)
            pred_v = C(feat_v)
            pred_t = C(feat_t)

            loss_s = criterion(pred_s, s_label)
            loss_msda = k_moment([feat_t, feat_s], k=4)

            loss = loss_s + Config.lambda_msda * loss_msda

            op_F.zero_grad()
            op_C.zero_grad()

            loss.backward()

            op_F.step()
            op_C.step()

            # accuracy
            pred_label_s = pred_s.argmax(1)
            accuracy_s = (pred_label_s == s_label).sum().item() / s_label.size(0)
            pred_label_v = pred_v.argmax(1)
            accuracy_v = (pred_label_v == v_label).sum().item() / v_label.size(0)
            pred_label_t = pred_t.argmax(1)
            accuracy_t = (pred_label_t == t_label).sum().item() / t_label.size(0)

            # plot history
            plot_loss.append(loss.item())
            plot_s_acc.append(accuracy_s)
            plot_v_acc.append(accuracy_v)
            plot_t_acc.append(accuracy_t)

        print(
            'Epoch: {}, Loss_s: {}, Loss_msda: {}, Accuracy_s: {}, Accuracy_v: {}, Accuracy_t: {}'
            .format(epoch, loss_s, loss_msda, accuracy_s, accuracy_v,
                    accuracy_t))

    if Config.enable_plot:
        plt.figure('Accuracy & Loss')
        plt.ylim(0.0, 3.2)
        plt.plot(plot_loss, label='Loss')
        plt.plot(plot_s_acc, label='Source Accuracy')
        plt.plot(plot_v_acc, label='Validate Accuracy')
        plt.plot(plot_t_acc, label='Target Accuracy')
        plt.xlabel('Batch')
        plt.title('Train Accuracy & Loss')
        plt.legend()
        plt.show()

    print('Done!')

    # save model
    torch.save(F.state_dict(), Config.checkpoint + 'F.pth')
    torch.save(C.state_dict(), Config.checkpoint + 'C.pth')
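The only undefined helper in this example is k_moment([feat_t, feat_s], k=4), used for loss_msda and weighted by Config.lambda_msda. This is moment matching in the spirit of M3SDA: penalise the distance between the first k order-wise moments of the feature distributions of each pair of domains. A sketch under that assumption (the exact normalisation and any feature centering in the original helper may differ):

import torch


def k_moment(features, k=4):
    """Moment-distance loss: for orders 1..k, match the per-dimension moments
    E[f**i] across every pair of domains."""
    loss = 0.0
    for i in range(1, k + 1):
        # i-th order moment of each domain's features, per feature dimension.
        moments = [(f ** i).mean(dim=0) for f in features]
        for a in range(len(moments)):
            for b in range(a + 1, len(moments)):
                # Euclidean distance between the two domains' moment vectors.
                loss = loss + (moments[a] - moments[b]).pow(2).sum().sqrt()
    return loss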