Example #1
0
def get_transformer(device: torch.device) -> GetterReturnType:
    """Set up the TransformerModel language-modelling benchmark on ``device``.

    Returns ``(forward, params)``. ``params`` are the weights taken out of the
    model by ``extract_weights``. ``forward(*new_params)`` loads them back with
    ``load_weights`` and returns the NLL loss of the model on a fixed batch of
    random token ids, where the targets are the inputs shifted by one position.
    """
    # For most SOTA research, you would like to have embed to 720, nhead to 12, bsz to 64, tgt_len/src_len to 128.
    batch_size = 64
    seq_length = 128
    ntoken = 50
    model = models.TransformerModel(
        ntoken=ntoken, ninp=720, nhead=12, nhid=2048, nlayers=2
    )
    model.to(device)

    if has_functorch:
        # Switch dropout off so repeated evaluations agree (consistency checking).
        model.eval()

    criterion = nn.NLLLoss()
    params, names = extract_weights(model)

    # seq_length + 1 token ids per row: inputs and targets are two overlapping
    # windows over the same row, offset by one position.
    tokens = torch.rand(batch_size, seq_length + 1, device=device).mul(ntoken).long()
    inputs = tokens[:, :-1]
    targets = tokens[:, 1:]
    n_positions = batch_size * seq_length

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        scores = model(inputs).reshape(n_positions, ntoken)
        return criterion(scores, targets.reshape(n_positions))

    return forward, params
Example #2
0
def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    """Set up the FCN-ResNet50 segmentation benchmark on ``device``.

    Returns ``(forward, params)``. ``forward(*new_params)`` loads the weights
    and returns the MSE between the model's ``'out'`` output and a fixed
    random 21-channel target.
    """
    batch = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        # Presumably patches the BatchNorm layers so they can be used under
        # functorch transforms; eval() disables dropout for consistency checking.
        replace_all_batch_norm_modules_(model)
        model.eval()

    model.to(device)
    params, names = extract_weights(model)

    image_shape = [batch, 3, 480, 480]
    inputs = torch.rand(image_shape, device=device)
    # The model predicts 21 classes.
    labels = torch.rand([batch, 21, 480, 480], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        prediction = model(inputs)['out']
        return criterion(prediction, labels)

    return forward, params
Example #3
0
def get_deepspeech(device: torch.device) -> GetterReturnType:
    """Set up the DeepSpeech (bidirectional LSTM) benchmark on ``device``.

    The random inputs, lengths and targets are drawn before the model is
    built. Returns ``(forward, params)``. ``forward(*new_params)`` loads the
    weights and returns the CTC loss of the model output.
    """
    audio_conf = dict(
        sample_rate=16000,
        window_size=0.02,
        window="hamming",
        noise_dir=None,
    )

    batch = 10
    num_classes = 10
    spectrogram_size = 161
    # Reduced sizes; the original code used seq_length=1343, target_length=50.
    seq_length = 500
    target_length = 10

    labels = torch.rand(num_classes, device=device)
    inputs = torch.rand(batch, 1, spectrogram_size, seq_length, device=device)
    # Per-sample input lengths in [0.8 * seq_length, 0.9 * seq_length).
    inputs_sizes = torch.rand(batch, device=device).mul(seq_length * 0.1).add(
        seq_length * 0.8)
    targets = torch.rand(batch, target_length, device=device)
    targets_sizes = torch.full((batch, ), target_length, dtype=torch.int, device=device)

    model = models.DeepSpeech(
        rnn_type=nn.LSTM,
        labels=labels,
        rnn_hidden_size=1024,
        nb_layers=5,
        audio_conf=audio_conf,
        bidirectional=True,
    )

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    model = model.to(device)
    criterion = nn.CTCLoss()
    params, names = extract_weights(model)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out, out_sizes = model(inputs, inputs_sizes)
        # nn.CTCLoss takes the time dimension first, so swap dims 0 and 1.
        return criterion(out.transpose(0, 1), targets, out_sizes, targets_sizes)

    return forward, params
Example #4
0
def get_resnet18(device: torch.device) -> GetterReturnType:
    """Set up the ResNet-18 classification benchmark on ``device``.

    Returns ``(forward, params)``. ``forward(*new_params)`` loads the weights
    and returns the cross-entropy on a fixed random batch of 32 images whose
    integer labels lie in [0, 10).
    """
    batch = 32
    model = models.resnet18(pretrained=False)
    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([batch, 3, 224, 224], device=device)
    # Class ids 0..9, drawn uniformly.
    labels = torch.rand(batch, device=device).mul(10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        return criterion(model(inputs), labels)

    return forward, params
Example #5
0
def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    """Set up the FCN-ResNet50 segmentation benchmark on ``device``.

    Returns ``(forward, params)``. ``forward(*new_params)`` loads the weights
    and returns the MSE between the model's ``'out'`` output and a fixed
    random 21-channel target.
    """
    batch = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)
    model.to(device)
    params, names = extract_weights(model)

    image_shape = [batch, 3, 480, 480]
    inputs = torch.rand(image_shape, device=device)
    # The model predicts 21 classes.
    labels = torch.rand([batch, 21, 480, 480], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        prediction = model(inputs)['out']
        return criterion(prediction, labels)

    return forward, params
Example #6
0
def get_wav2letter(device: torch.device) -> GetterReturnType:
    """Set up the Wav2Letter speech benchmark on ``device``.

    Returns ``(forward, params)``. ``forward(*new_params)`` loads the weights
    and returns the NLL loss against fixed random labels of shape (10, 3).
    """
    batch = 10
    input_frames = 700
    vocab_size = 28
    model = models.Wav2Letter(num_classes=vocab_size)
    criterion = torch.nn.NLLLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([batch, 1, input_frames], device=device)
    # NOTE(review): with (batch, 3) targets, NLLLoss needs the model output's
    # trailing dim to be 3 — presumably what Wav2Letter yields for 700 frames.
    labels = torch.rand(batch, 3, device=device).mul(vocab_size).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        return criterion(model(inputs), labels)

    return forward, params
Example #7
0
def get_multiheadattn(device: torch.device) -> GetterReturnType:
    """Set up the torchtext MultiheadAttentionContainer benchmark on ``device``.

    Returns ``(forward, params)``. ``forward(*new_params)`` loads the weights,
    runs attention with a random boolean mask and ``bias_k``/``bias_v``, and
    returns the sum of both outputs (no specific loss is tested).
    """
    # From https://github.com/pytorch/text/blob/master/test/data/test_modules.py#L10
    embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64

    # Build torchtext MultiheadAttention module: bias-free query/key/value
    # input projections, scaled dot-product attention, bias-free output projection.
    query_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
    key_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
    value_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
    in_proj = models.InProjContainer(query_proj, key_proj, value_proj)
    attention = models.ScaledDotProduct()
    out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)
    model = models.MultiheadAttentionContainer(nhead, in_proj, attention, out_proj)
    model.to(device)
    params, names = extract_weights(model)

    query = torch.rand((tgt_len, bsz, embed_dim), device=device)
    key = value = torch.rand((src_len, bsz, embed_dim), device=device)
    mask_2d = torch.randint(0, 2, (tgt_len, src_len), device=device).to(torch.bool)
    bias = torch.rand((1, 1, embed_dim), device=device)

    # Replicate the mask and biases once per (batch element, head) pair.
    n_heads_total = bsz * nhead
    attn_mask = torch.stack([mask_2d] * n_heads_total)
    bias_k = bias.repeat(1, bsz, 1).reshape(1, n_heads_total, -1)
    bias_v = bias.repeat(1, bsz, 1).reshape(1, n_heads_total, -1)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        mha_output, attn_weights = model(
            query, key, value, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v
        )
        # Don't test any specific loss, just backprop ones for both outputs
        return mha_output.sum() + attn_weights.sum()

    return forward, params
Example #8
0
def get_resnet18(device: torch.device) -> GetterReturnType:
    """Set up the ResNet-18 classification benchmark on ``device``.

    When functorch is available, the model's BatchNorm modules are first
    processed by ``replace_all_batch_norm_modules_``. Returns
    ``(forward, params)``. ``forward(*new_params)`` loads the weights and
    returns the cross-entropy on a fixed random batch of 32 images whose
    integer labels lie in [0, 10).
    """
    batch = 32
    model = models.resnet18(pretrained=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([batch, 3, 224, 224], device=device)
    # Class ids 0..9, drawn uniformly.
    labels = torch.rand(batch, device=device).mul(10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        return criterion(model(inputs), labels)

    return forward, params
Example #9
0
def get_detr(device: torch.device) -> GetterReturnType:
    """Set up the DETR object-detection benchmark on ``device``.

    Builds ``models.DETR`` and a ``models.SetCriterion`` loss, a fixed random
    image batch, and ``N`` random target dicts with ``"labels"`` and
    ``"boxes"`` entries.

    Returns:
        ``(forward, params)``. ``params`` are the weights extracted from the
        model. ``forward(*new_params)`` loads them back and returns the sum of
        the criterion's losses, each scaled by its entry in
        ``criterion.weight_dict``. Losses without an entry are ignored.
    """
    # All values below are from CLI defaults in https://github.com/facebookresearch/detr
    N = 2
    num_classes = 91
    hidden_dim = 256
    nheads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

    model = models.DETR(num_classes=num_classes,
                        hidden_dim=hidden_dim,
                        nheads=nheads,
                        num_encoder_layers=num_encoder_layers,
                        num_decoder_layers=num_decoder_layers)
    losses = ['labels', 'boxes', 'cardinality']
    eos_coef = 0.1
    bbox_loss_coef = 5
    giou_loss_coef = 2
    weight_dict = {
        'loss_ce': 1,
        'loss_bbox': bbox_loss_coef,
        'loss_giou': giou_loss_coef
    }
    # NOTE(review): positional matching costs, presumably
    # (cost_class, cost_bbox, cost_giou) — confirm against models.HungarianMatcher.
    matcher = models.HungarianMatcher(1, 5, 2)
    criterion = models.SetCriterion(num_classes=num_classes,
                                    matcher=matcher,
                                    weight_dict=weight_dict,
                                    eos_coef=eos_coef,
                                    losses=losses)

    model = model.to(device)
    criterion = criterion.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand(N, 3, 800, 1200, device=device)
    labels = []
    for _ in range(N):
        n_targets: int = int(torch.randint(5, 10, size=tuple()).item())
        label = torch.randint(5, 10, size=(n_targets, ))
        boxes = torch.randint(100, 800, size=(n_targets, 4))
        # Order each box's coordinates so that column 0 <= column 2 and
        # column 1 <= column 3. Do NOT swap with
        # ``boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]``: integer
        # indexing of a tensor returns a view, so the right-hand side changes
        # during the assignment and both entries end up holding the smaller value.
        x_lo = torch.minimum(boxes[:, 0], boxes[:, 2])
        x_hi = torch.maximum(boxes[:, 0], boxes[:, 2])
        y_lo = torch.minimum(boxes[:, 1], boxes[:, 3])
        y_hi = torch.maximum(boxes[:, 1], boxes[:, 3])
        boxes = torch.stack((x_lo, y_lo, x_hi, y_hi), dim=1)
        labels.append({"labels": label, "boxes": boxes.float()})

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        weight_dict = criterion.weight_dict
        # Weighted sum of the losses that have a coefficient; cast() only
        # narrows the type of the built-in sum() for type checkers.
        final_loss = cast(
            Tensor,
            sum(loss[k] * weight_dict[k] for k in loss.keys()
                if k in weight_dict))
        return final_loss

    return forward, params
    def update_one_step(data, learning_rate, epsilon, n):
        """Apply one natural-gradient-style update to ``logphi_model`` in place.

        Appears to be a stochastic-reconfiguration step (NOTE(review): confirm).
        Importance weights are computed for the sampled states relative to the
        ``logphi0`` values stored in ``data``. If the normalised weight sum
        exceeds ``target_wn``, the update is skipped. Otherwise, for every
        parameter tensor of ``logphi_model``, an S-matrix and a force vector F
        are built from the per-sample Jacobian, ``S x = F`` is solved, and the
        parameter is decreased by ``learning_rate * Re(x)``.

        Args:
            data: dict with keys ``'state'``, ``'count'``, ``'logphi0'``,
                ``'update_states'`` and ``'update_coeffs'``.
            learning_rate: step size applied to the real part of the solution.
            epsilon: diagonal shift added to each S-matrix before solving.
            n: not used by this function.

        Returns:
            ``(weights_norm / count.sum(), t)``. ``t`` is the wall-clock time
            of the per-parameter solve loop (the Jacobian is excluded), or
            ``0`` when the update is skipped. NOTE(review): the first element
            is not detached, so it may still carry an autograd graph.

        Uses free names resolved outside this function: ``logphi_model``,
        ``target_wn``, ``Dp``, ``single_state_shape``, ``gpu``, ``jacobian``,
        ``extract_weights``, ``load_weights``, ``copy`` and ``time``.
        """
        state, count, logphi0  = data['state'], data['count'], data['logphi0']
        op_states, op_coeffs = data['update_states'], data['update_coeffs']
        # Column 0 of the model output is used as the log-amplitude, column 1 as the phase.
        psi = logphi_model(state)
        logphi = psi[:, 0].reshape(len(state), -1)
        theta = psi[:, 1].reshape(len(state), -1)

        # calculate the importance-sampling weights used to average the energy
        delta_logphi = logphi - logphi0[..., None]

        # delta_logphi = delta_logphi - delta_logphi.mean()*torch.ones(delta_logphi.shape)
        # Subtracting the mean leaves the normalised weights unchanged, but it
        # does rescale weights_norm and therefore affects the skip test below.
        delta_logphi = delta_logphi - delta_logphi.mean()
        weights = count[..., None]*torch.exp(delta_logphi * 2)
        weights_norm = weights.sum()
        weights = (weights/weights_norm).detach()

        # Skip the update when the weight-norm ratio exceeds target_wn.
        if weights_norm/count.sum() > target_wn:
            return weights_norm/count.sum(), 0
        else:
            n_sample = op_states.shape[0]
            n_updates = op_states.shape[1]
            # Flatten the first two dims (n_sample, n_updates) into one batch
            # dim; single_state_shape must be a list for the `+` to work.
            op_states = op_states.reshape([-1, Dp] + single_state_shape)
            psi_ops = logphi_model(op_states)
            logphi_ops = psi_ops[:, 0].reshape(n_sample, n_updates)
            theta_ops = psi_ops[:, 1].reshape(n_sample, n_updates)

            delta_logphi_os = logphi_ops - logphi*torch.ones_like(logphi_ops)
            # delta_logphi_os = torch.clamp(delta_logphi_os, max=5)
            delta_theta_os = theta_ops - theta*torch.ones_like(theta_ops)
            # Local estimator: sum over updates of
            # coeff * exp(delta_logphi) * (cos(delta_theta) + i sin(delta_theta)).
            ops_real = torch.sum(op_coeffs*torch.exp(delta_logphi_os)*torch.cos(delta_theta_os), 1)
            ops_imag = torch.sum(op_coeffs*torch.exp(delta_logphi_os)*torch.sin(delta_theta_os), 1)

            with torch.no_grad():
                ops = ops_real + 1j*ops_imag # batch_size
                # Weighted mean energy; only the real part enters.
                mean_e = (ops_real[...,None]*weights).sum(0)
            # update parameters with gradient descent
            # copy the model parameters
            # The Jacobian is taken on a deep copy; presumably because
            # extract_weights modifies the model it is given.
            op_logphi_model = copy.deepcopy(logphi_model)
            params, names = extract_weights(op_logphi_model)

            def forward(*new_param):
                load_weights(op_logphi_model, names, new_param)
                out = op_logphi_model(state)
                return out

            dydws = jacobian(forward, params, vectorize=True) # a tuple containing one Jacobian per parameter
            cnt = 0
            tic = time.time()
            # NOTE(review): assumes logphi_model.parameters() yields tensors in
            # the same order as the names/params returned by extract_weights.
            for param in logphi_model.parameters():
                param_len = len(param.data.reshape(-1))
                # Assumes the model output has 2 columns per sample.
                dydws_layer = dydws[cnt].reshape(n_sample,2,-1)
                with torch.no_grad():
                    grads_real = dydws_layer[:,0,:] # jacobian of d logphi wrt d w
                    grads_imag = dydws_layer[:,1,:] # jacobian of d theta wrt d w
                    Oks = grads_real + 1j*grads_imag
                    Oks_conj = grads_real - 1j*grads_imag
                    # OO_matrix[s, i, j] = Oks[s, i] * conj(Oks[s, j])
                    OO_matrix = Oks_conj.reshape(n_sample, 1, param_len)*Oks.reshape(n_sample, param_len, 1)


                Oks = Oks*weights
                Oks_conj = Oks_conj*weights
                # NOTE(review): the first term averages O_i * conj(O_j), while the
                # second is conj(<O_i>) * <O_j>; the two use opposite conjugation
                # conventions — confirm this is intended.
                Skk_matrix = (OO_matrix*weights[...,None]).sum(0) - Oks_conj.sum(0)[..., None]*Oks.sum(0)
                # Hermitian symmetrisation plus an epsilon * I diagonal shift.
                Skk_matrix = 0.5*(Skk_matrix + Skk_matrix.t().conj()) + epsilon*torch.eye(Skk_matrix.shape[0], device=gpu)
                # calculate Fk
                Fk = (ops[...,None]*Oks_conj).sum(0) - mean_e*(Oks_conj).sum(0)
                # update_k = torch.linalg.solve(Skk_matrix, Fk)
                # torch.solve(B, A) solves A X = B and returns (solution, LU);
                # it is deprecated in favour of torch.linalg.solve(A, B).
                update_k, _ = torch.solve(Fk[...,None], Skk_matrix)
                # Only the real part of the solution is applied, in place.
                param.data -= learning_rate*update_k.real.reshape(param.data.shape)
                cnt += 1
            t = time.time() - tic
            return weights_norm/count.sum(), t
    def update_one_step(data, learning_rate, epsilon):
        """Apply one natural-gradient-style update to ``psi_model``'s weights.

        Appears to be a stochastic-reconfiguration step for a network with
        separate real and imaginary weights (NOTE(review): confirm).

        Importance weights are computed for the sampled states relative to the
        ``logphi0`` values in ``data``. If the normalised weight sum exceeds
        ``target_wn``, the update is skipped.

        Otherwise the parameters are grouped by the integer prefix of their
        names (e.g. ``"0.weight"`` -> 0). Within each group, the first half is
        treated as real parts and the second half as the matching imaginary
        parts. For each real/imaginary pair:

        - a complex derivative O is built from both Jacobians;
        - an S-matrix and a force vector F are formed;
        - ``S x = F`` is solved;
        - the real weights are decreased by ``learning_rate * Re(x)`` and the
          imaginary weights by ``learning_rate * Im(x)``.

        Args:
            data: dict with keys ``'state'``, ``'count'``, ``'logphi0'``,
                ``'update_states'`` and ``'update_coeffs'``.
            learning_rate: step size for the parameter update.
            epsilon: diagonal shift added to each S-matrix before solving.

        Returns:
            ``(weights_norm / count.sum(), t)``. ``t`` is the wall-clock time
            of the solve loop (the Jacobian is excluded), or ``0`` when the
            update is skipped.

        Uses free names resolved outside this function: ``psi_model``,
        ``target_wn``, ``Dp``, ``single_state_shape``, ``gpu``, ``jacobian``,
        ``extract_weights``, ``load_weights``, ``np``, ``copy`` and ``time``.
        """
        state, count, logphi0 = data['state'], data['count'], data['logphi0']
        op_states, op_coeffs = data['update_states'], data['update_coeffs']
        # psi = logphi_model(state)
        # Column 0 of the model output is used as the log-amplitude, column 1 as the phase.
        psi = psi_model(state)
        logphi = psi[:, 0].reshape(len(state), -1)
        theta = psi[:, 1].reshape(len(state), -1)

        # calculate the importance-sampling weights used to average the energy
        delta_logphi = logphi - logphi0[..., None]
        #print (delta_logphi.cpu().detach().numpy())

        # delta_logphi = delta_logphi - delta_logphi.mean()*torch.ones(delta_logphi.shape)
        # Subtracting the mean leaves the normalised weights unchanged, but it
        # does rescale weights_norm and therefore affects the skip test below.
        delta_logphi = delta_logphi - delta_logphi.mean()
        weights = count[..., None] * torch.exp(delta_logphi * 2)
        weights_norm = weights.sum()
        weights = (weights / weights_norm).detach()

        # Skip the update when the weight-norm ratio exceeds target_wn.
        if weights_norm / count.sum() > target_wn:
            return weights_norm / count.sum(), 0
        else:
            n_sample = op_states.shape[0]
            n_updates = op_states.shape[1]
            # Flatten the first two dims (n_sample, n_updates) into one batch
            # dim; single_state_shape must be a list for the `+` to work.
            op_states = op_states.reshape([-1, Dp] + single_state_shape)
            # psi_ops = logphi_model(op_states)
            psi_ops = psi_model(op_states)
            logphi_ops = psi_ops[:, 0].reshape(n_sample, n_updates)
            theta_ops = psi_ops[:, 1].reshape(n_sample, n_updates)

            delta_logphi_os = logphi_ops - logphi * torch.ones_like(logphi_ops)
            # delta_logphi_os = torch.clamp(delta_logphi_os, max=5)
            delta_theta_os = theta_ops - theta * torch.ones_like(theta_ops)
            # Local estimator: sum over updates of
            # coeff * exp(delta_logphi) * (cos(delta_theta) + i sin(delta_theta)).
            ops_real = torch.sum(
                op_coeffs * torch.exp(delta_logphi_os) *
                torch.cos(delta_theta_os), 1)
            ops_imag = torch.sum(
                op_coeffs * torch.exp(delta_logphi_os) *
                torch.sin(delta_theta_os), 1)

            with torch.no_grad():
                ops = ops_real + 1j * ops_imag  # batch_size
                # Weighted mean energy; only the real part enters.
                mean_e = (ops_real[..., None] * weights).sum(0)
            # update parameters with gradient descent
            # copy the model parameters
            # The Jacobian is taken on a deep copy; presumably because
            # extract_weights modifies the model it is given.
            op_psi_model = copy.deepcopy(psi_model)
            params, names = extract_weights(op_psi_model)

            # Group parameters by the integer prefix of their names; this
            # requires every name to start with an integer (e.g. "0.weight").
            # cnts[k] is the number of parameters in the k-th group, in
            # ascending order of prefix (np.unique sorts).
            name_num = []
            for name in names:
                name_num.append(int(name.split(".")[0]))
            _, cnts = np.unique(name_num, return_counts=True, axis=0)
            # net_layers = len(cnts)

            # Updates below modify these tensors in place. NOTE(review):
            # state_dict() tensors normally share storage with the parameters,
            # so psi_model may already be updated before load_state_dict.
            ws = psi_model.state_dict()

            def forward(*new_param):
                load_weights(op_psi_model, names, new_param)
                out = op_psi_model(state)
                return out

            dpsidws = jacobian(forward, params,
                               vectorize=True)  # a tuple containing one Jacobian per parameter

            cnt = 0
            tic = time.time()
            for net_layer in cnts:
                # First half of the group: real-part weights; second half:
                # the matching imaginary-part weights.
                step = net_layer // 2
                for i in range(step):
                    index = cnt + i
                    param_len = len(ws[names[index]].reshape(-1))
                    # Assumes the model output has 2 columns per sample:
                    # [:, 0, :] is d logphi / dw, [:, 1, :] is d theta / dw.
                    dres_layer = dpsidws[index].reshape(n_sample, 2, -1)
                    dims_layer = dpsidws[index + step].reshape(n_sample, 2, -1)
                    with torch.no_grad():
                        # O = 0.5 * (dlogphi/dw_re + dtheta/dw_im)
                        #     - 0.5i * (dlogphi/dw_im - dtheta/dw_re),
                        # i.e. the Wirtinger derivative of (logphi + i*theta)
                        # w.r.t. w = w_re + i*w_im.
                        Oks = 0.5 * (dres_layer[:, 0, :] + dims_layer[:, 1, :]
                                     ) - 0.5 * 1j * (dims_layer[:, 0, :] -
                                                     dres_layer[:, 1, :])
                        Oks_conj = Oks.conj()
                        # OO_matrix[s, i, j] = Oks[s, i] * conj(Oks[s, j])
                        OO_matrix = Oks_conj.reshape(
                            n_sample, 1, param_len) * Oks.reshape(
                                n_sample, param_len, 1)

                    Oks = Oks * weights
                    Oks_conj = Oks_conj * weights
                    # NOTE(review): the first term averages O_i * conj(O_j), while
                    # the second is conj(<O_i>) * <O_j>; the two use opposite
                    # conjugation conventions — confirm this is intended.
                    Skk_matrix = (OO_matrix * weights[..., None]).sum(
                        0) - Oks_conj.sum(0)[..., None] * Oks.sum(0)
                    # Hermitian symmetrisation plus an epsilon * I diagonal shift.
                    Skk_matrix = 0.5 * (Skk_matrix + Skk_matrix.t().conj(
                    )) + epsilon * torch.eye(Skk_matrix.shape[0], device=gpu)
                    # calculate Fk
                    Fk = (ops[..., None] *
                          Oks_conj).sum(0) - mean_e * (Oks_conj).sum(0)
                    # torch.solve(B, A) solves A X = B and returns (solution, LU);
                    # it is deprecated in favour of torch.linalg.solve(A, B).
                    update_k, _ = torch.solve(Fk[..., None], Skk_matrix)
                    # real part
                    ws[names[index]] -= learning_rate * update_k.real.reshape(
                        ws[names[index]].shape)
                    # imag part (assumes the imaginary tensor has the same
                    # shape as its real counterpart)
                    ws[names[index +
                             step]] -= learning_rate * update_k.imag.reshape(
                                 ws[names[index]].shape)

                cnt += net_layer

            psi_model.load_state_dict(ws)
            t = time.time() - tic
            return weights_norm / count.sum(), t