Example #1
def get_model(device):
    """
    Builds the Linformer model and moves it to the given device.
    This configuration runs with full (standard) attention enabled.
    """
    model = Linformer(
        input_size=config["dummy_seq_len"], channels=config["dummy_ch"],
        dim_d=config["dummy_ch"], dim_k=64, dim_ff=64, nhead=4, depth=2,
        activation="gelu", checkpoint_level="C0", full_attention=True,
        include_ff=False, parameter_sharing="none",
    )
    model.to(device)
    return model
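Note: config is assumed to be a dict defined elsewhere in this script that supplies the sequence length and channel count. With full_attention=True and include_ff=False, this configuration runs plain softmax attention with no feed-forward blocks, which is useful as an ablation baseline.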
Example #2
def get_model(device):
    """
    Builds the Linformer model and moves it to the given device.
    Currently running a standard Linformer.
    """
    model = Linformer(
        input_size=config["dummy_seq_len"], channels=config["dummy_ch"],
        dim_k=64, dim_ff=64, nhead=4, depth=4,
        activation="gelu", checkpoint_level="C0",
    )
    model.to(device)
    return model
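Note: checkpoint_level="C0" disables gradient checkpointing; the library's "C1" and "C2" levels checkpoint progressively more of each layer, trading extra compute for lower memory use.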
Example #3
def __init__(self, embed_size, num_layers=2):
    super().__init__()
    self.path_encoder = nn.GRU(embed_size, embed_size, 2)
    # FIXME: computing a new representation of r and of every path at each
    #  layer is very expensive (attention bottleneck); a cheaper alternative
    #  is needed.
    self.layers = Padder(
        Linformer(256,
                  embed_size,
                  dim_d=None,
                  dim_k=embed_size * 2,
                  dim_ff=embed_size * 2,
                  nhead=4,
                  depth=num_layers))
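Note: this snippet is assumed to live inside an nn.Module subclass (hence self; the imports for torch.nn, Padder, and Linformer are elided). The Padder wrapper lets callers pass sequences shorter than the fixed input_size of 256 by padding them up and trimming the output.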
Example #4
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer, Visualizer

model = Linformer(
    input_size=512,
    channels=16,
    dim_k=128,
    dim_ff=32,
    nhead=4,
    depth=3,
    activation="relu",
    checkpoint_level="C0",
    parameter_sharing="layerwise",
    k_reduce_by_layer=1,
)
x = torch.randn(1, 512, 16)
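# visualize=True makes the forward pass keep each head's P_bar (projected
# attention) matrix so the Visualizer can plot it below.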
y = model(x, visualize=True)
vis = Visualizer(model)
vis.plot_all_heads(title="All P_bar matrices",
                   show=True,
                   save_file=None,
                   figsize=(30, 20),
                   n_limit=256)
print(y)  # (1, 512, 16)
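Note: parameter_sharing="layerwise" shares one downsampling projection across all layers, and k_reduce_by_layer=1 shrinks dim_k by one per layer, in line with the paper's observation that attention in higher layers has lower effective rank.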
Example #5
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer, Padder

model = Linformer(
    input_size=512,
    channels=16,
    dim_d=32,
    dim_k=16,
    dim_ff=32,
    nhead=6,
    depth=3,
    checkpoint_level="C1",
)
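# Padder zero-pads inputs shorter than input_size before the forward pass
# and trims the output back to the original length.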
model = Padder(model)
x = torch.randn(1, 500, 16)  # This does not match the input size!
y = model(x)
print(y)  # (1, 500, 16)
Example #6
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer

model = Linformer(
    input_size=510,
    channels=21,
    dim_d=26,
    dim_k=61,
    dim_ff=32,
    nhead=4,
    depth=3,
    activation="relu",
    checkpoint_level="C0",
    parameter_sharing="none",
    k_reduce_by_layer=1,
    include_ff=True,
    method="convolution",
)
x = torch.randn(1, 510, 21)
y = model(x)
print(y)  # (1, 510, 21)
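Note: method="convolution" replaces the learned projection matrices with 1-D convolutions (kernel size and stride of roughly input_size/dim_k) for the downsampling step; the odd shapes here (510 tokens, 21 channels, dim_k=61) show the dimensions need not be powers of two.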
Example #7
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer

model = Linformer(
    input_size=512,
    channels=16,
    dim_k=16,
    dim_ff=32,
    nhead=4,
    depth=3,
    activation="relu",
    checkpoint_level="C1",
    parameter_sharing="none",
    k_reduce_by_layer=1,
    include_ff=True,
    w_o_intermediate_dim=4,
)
x = torch.randn(1, 512, 16)
y = model(x)
print(y)  # (1, 512, 16)
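Note: w_o_intermediate_dim=4 splits each head's output projection into two smaller matrices (down to dimension 4, then back up to the output width), cutting parameters at some cost in capacity.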
Example #8
import sys
import torch

sys.path.insert(0, "../")
from linformer_pytorch import Linformer

model = Linformer(
    input_size=512,
    channels=16,
    dim_k=16,
    dim_ff=32,
    nhead=4,
    depth=3,
    activation="relu",
    checkpoint_level="C1",
    parameter_sharing="none",
    k_reduce_by_layer=1,
    full_attention=True,
)
x = torch.randn(1, 512, 16)
y = model(x)
print(y)  # (1, 512, 16)
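Note: with full_attention=True the rest of the stack is unchanged, but attention falls back to the standard quadratic form; this makes a convenient baseline when benchmarking the linear variant.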
Example #9
from linformer_pytorch import Linformer
import torch

device = torch.device("cuda")
model = Linformer(
    input_size=262144,
    channels=64,
    dim_d=256,
    dim_k=64,
    dim_ff=128,
).to(device)
x = torch.randn(1, 262144, 64, device=device)
y = model(x)
print(y)  # (1, 262144, 64)

# To see memory usage, uncomment the line below.
# print('Allocated Memory:', round(torch.cuda.memory_allocated(0) / 1024**3, 1), 'GB')
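Note: the sequence length here is 2^18 = 262,144, which is where the linear projection pays off. Rough fp32 arithmetic: a full n-by-n attention matrix would need 2^18 * 2^18 * 4 bytes = 256 GiB per head, while the projected n-by-k matrix with dim_k=64 needs 2^18 * 64 * 4 bytes = 64 MiB.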