Example #1
    def __init__(self):
        super().__init__()
        self.encoder = downscale16_encoder_block()  # 64×64 input → 4×4 feature map
        self.logit = DiscriminatorLogitBlock()      # 4×4 features → real/fake logit

        p_trainable, p_non_trainable = count_params(self)
        print(
            f'Discriminator64 params: trainable {p_trainable} - non_trainable {p_non_trainable}'
        )
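
Every example reports its parameter counts through the same count_params helper, whose definition is not shown here. A minimal sketch, assuming it simply sums parameter element counts split by requires_grad:

    import torch.nn as nn

    def count_params(model: nn.Module):
        # Assumed behaviour: count elements of trainable vs. frozen parameters.
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        non_trainable = sum(p.numel() for p in model.parameters() if not p.requires_grad)
        return trainable, non_trainable
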
Example #2
    def __init__(self):
        super().__init__()
        self.img = nn.Sequential(
            conv3x3(D_GF, 3),  # project D_GF feature channels to 3 RGB channels
            nn.Tanh()          # bound pixel values to [-1, 1]
        )

        p_trainable, p_non_trainable = count_params(self)
        print(f'Image output params: trainable {p_trainable} - non_trainable {p_non_trainable}')
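
The conv3x3 helper recurs throughout these examples but is not defined in any of them. A plausible sketch, assuming the common size-preserving form used in AttnGAN-style code:

    import torch.nn as nn

    def conv3x3(in_planes: int, out_planes: int) -> nn.Conv2d:
        # Assumed form: 3x3 convolution with stride 1 and padding 1,
        # so spatial size is preserved; no bias term.
        return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                         stride=1, padding=1, bias=False)
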
Example #3
    def __init__(self):
        super().__init__()
        self.downscale_encoder_16 = downscale16_encoder_block()  # 128×128 input → 8×8
        self.downscale_encoder_32 = downscale2_encoder_block(
            D_DF * 8, D_DF * 16)  # 8×8 → 4×4, doubling channels
        self.encoder32 = conv3x3_LReLU(D_DF * 16, D_DF * 8)  # reduce channels back to D_DF * 8
        self.logit = DiscriminatorLogitBlock()  # 4×4 features → real/fake logit

        p_trainable, p_non_trainable = count_params(self)
        print(
            f'Discriminator128 params: trainable {p_trainable} - non_trainable {p_non_trainable}'
        )
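
Neither downscale2_encoder_block nor conv3x3_LReLU is defined in these snippets. A sketch of plausible implementations, assuming the usual strided-conv + BatchNorm + LeakyReLU pattern from GAN discriminators (the kernel sizes are a guess):

    import torch.nn as nn

    def downscale2_encoder_block(in_planes: int, out_planes: int) -> nn.Sequential:
        # Assumed: strided 4x4 convolution that halves the spatial size.
        return nn.Sequential(
            nn.Conv2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def conv3x3_LReLU(in_planes: int, out_planes: int) -> nn.Sequential:
        # Assumed: size-preserving 3x3 conv followed by BatchNorm and LeakyReLU.
        return nn.Sequential(
            conv3x3(in_planes, out_planes),
            nn.BatchNorm2d(out_planes),
            nn.LeakyReLU(0.2, inplace=True),
        )
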
Example #4
    def __init__(self, use_self_attention=False):
        super().__init__()
        # The residual stack runs on D_GF * 2 channels, presumably image
        # features concatenated with the word-attention context
        self.residuals = nn.Sequential(*[Residual(D_GF * 2) for _ in range(RESIDUALS)])
        self.attn = Attention(D_GF, D_HIDDEN)           # word-to-image attention
        self.upsample = upsample_block(D_GF * 2, D_GF)  # double spatial size, halve channels
        self.use_self_attention = use_self_attention

        if self.use_self_attention:
            self.self_attn = self_attn_block()  # optional self-attention stage

        p_trainable, p_non_trainable = count_params(self)
        print(f'GeneratorN params: trainable {p_trainable} - non_trainable {p_non_trainable}')
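
upsample_block is shared by GeneratorN here and Generator0 below. A sketch under the assumption that it follows the common upsample + conv + BatchNorm + GLU recipe, where the conv emits twice the target channels so the GLU can gate them back down:

    import torch.nn as nn

    def upsample_block(in_planes: int, out_planes: int) -> nn.Sequential:
        # Assumed recipe: 2x nearest-neighbour upsampling, then a 3x3 conv
        # producing 2 * out_planes channels for the GLU to halve.
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            conv3x3(in_planes, out_planes * 2),
            nn.BatchNorm2d(out_planes * 2),
            nn.GLU(dim=1),
        )
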
Example #5
    def __init__(self):
        super().__init__()
        self.d_gf = D_GF * 16
        self.fc = nn.Sequential(
            nn.Linear(D_Z + D_COND, self.d_gf * 4 * 4 * 2, bias=False),  # noise + condition → 4×4 seed (×2 for GLU)
            nn.BatchNorm1d(self.d_gf * 4 * 4 * 2),
            nn.GLU(dim=1)  # gates away half the features, leaving d_gf * 4 * 4
        )

        # Four upsampling stages: 4×4 → 64×64 spatially, halving channels
        # each step from D_GF * 16 down to D_GF
        self.upsample_steps = nn.Sequential(
            *[upsample_block(self.d_gf // (2 ** i), self.d_gf // (2 ** (i + 1))) for i in range(4)]
        )

        p_trainable, p_non_trainable = count_params(self)
        print(f'Generator0 params: trainable {p_trainable} - non_trainable {p_non_trainable}')
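
The snippet shows only the constructor. A hypothetical forward pass, assuming the fc output is reshaped into the 4×4 seed map before entering the upsampling stack (the method body below is an assumption, not code from the repository):

    def forward(self, z_cond):
        # z_cond: (B, D_Z + D_COND), concatenated noise and condition vector
        out = self.fc(z_cond)                # (B, d_gf * 4 * 4) after the GLU
        out = out.view(-1, self.d_gf, 4, 4)  # reshape into the 4x4 seed map
        return self.upsample_steps(out)      # (B, D_GF, 64, 64)
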
Example #6
    def __init__(self, device=DEVICE):
        super().__init__()
        self.device = device
        self.inception_model = torchvision.models.inception_v3(
            pretrained=True).to(self.device).eval()
        # Freeze Inception V3 parameters; only the projection layers train
        freeze_params_(self.inception_model)
        # 768: channel dim of the mixed_6e feature map (17 × 17 = 289 sub-regions)
        self.local_proj = conv1x1(768, D_HIDDEN).to(self.device)
        # 2048: dimension of the final average-pool output
        self.global_proj = nn.Linear(2048, D_HIDDEN).to(self.device)

        self.local_proj.weight.data.uniform_(-IMG_WEIGHT_INIT_RANGE, IMG_WEIGHT_INIT_RANGE)
        self.global_proj.weight.data.uniform_(-IMG_WEIGHT_INIT_RANGE, IMG_WEIGHT_INIT_RANGE)

        p_trainable, p_non_trainable = count_params(self)
        print(f'Image encoder params: trainable {p_trainable} - non_trainable {p_non_trainable}')
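
freeze_params_ and conv1x1 are also undefined in these snippets. Plausible sketches, assuming the names mean what they say:

    import torch.nn as nn

    def freeze_params_(model: nn.Module) -> None:
        # Assumed: disable gradients so the parameters count as non-trainable.
        for p in model.parameters():
            p.requires_grad = False

    def conv1x1(in_planes: int, out_planes: int) -> nn.Conv2d:
        # Assumed: pointwise convolution acting as a per-region linear projection.
        return nn.Conv2d(in_planes, out_planes, kernel_size=1,
                         stride=1, padding=0, bias=False)
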
Example #7
    def __init__(self, vocab_size, device=DEVICE):
        super().__init__()
        self.vocab_size = vocab_size
        self.device = device
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=D_WORD).to(self.device)
        self.emb_dropout = nn.Dropout(P_DROP).to(self.device)
        with warnings.catch_warnings():
            # A single-layer LSTM with dropout > 0 emits a warning; suppress it
            warnings.simplefilter('ignore')
            self.rnn = nn.LSTM(
                input_size=D_WORD,
                hidden_size=D_HIDDEN // 2,  # bidirectional: each direction gets half of D_HIDDEN
                batch_first=True,
                dropout=P_DROP,
                bidirectional=True).to(self.device)
        # Learned initial hidden and cell states, expanded per direction and
        # per batch element in the forward pass
        hidden0_weights = torch.randn(D_HIDDEN // 2)
        self.hidden0 = nn.Parameter(hidden0_weights.to(self.device), requires_grad=True)
        cell0_weights = torch.randn(D_HIDDEN // 2)
        self.cell0 = nn.Parameter(cell0_weights.to(self.device), requires_grad=True)

        p_trainable, p_non_trainable = count_params(self)
        print(f'Text encoder params: trainable {p_trainable} - non_trainable {p_non_trainable}')
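
hidden0 and cell0 are stored as flat vectors of size D_HIDDEN // 2, while nn.LSTM expects initial states shaped (num_layers * num_directions, batch, hidden_size). A hypothetical helper showing how they might be expanded at the start of the forward pass (the method name and body are assumptions):

    def init_states(self, batch_size):
        # Broadcast the learned vectors to (2, B, D_HIDDEN // 2):
        # 2 = num_layers (1) * num_directions (2) for this bidirectional LSTM.
        h0 = self.hidden0.view(1, 1, -1).expand(2, batch_size, -1).contiguous()
        c0 = self.cell0.view(1, 1, -1).expand(2, batch_size, -1).contiguous()
        return h0, c0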