Exemplo n.º 1
0
    def __init__(self,
                 n_in_chan,
                 n_out_chan,
                 filter_sz,
                 stride=1,
                 do_res=True,
                 parent_rf=None,
                 name=None):
        """Build a bias-free Conv1d followed by a ReLU, optionally residual.

        Args:
            n_in_chan: number of input channels of the convolution.
            n_out_chan: number of output channels of the convolution.
            filter_sz: convolution kernel size; also passed to
                ``rfield.Rfield`` as ``filter_info``.
            stride: convolution stride. Must be 1 when ``do_res`` is True.
            do_res: whether this unit is residually connected (the residual
                addition itself presumably happens in ``forward`` -- confirm).
            parent_rf: forwarded to ``rfield.Rfield`` as its ``parent``.
            name: forwarded to ``rfield.Rfield`` as its ``name``.

        Raises:
            ValueError: if ``do_res`` is True and ``stride != 1``.
        """
        super(ConvReLURes, self).__init__()
        self.do_res = do_res
        # Reject the invalid configuration up front; the explanation travels
        # with the exception instead of being printed separately to stderr.
        if self.do_res and stride != 1:
            raise ValueError(
                'Stride must be 1 for residually connected convolution')

        self.n_in = n_in_chan
        self.n_out = n_out_chan
        # No padding, so the output is shorter than the input whenever
        # filter_sz > 1; no bias term.
        self.conv = nn.Conv1d(n_in_chan,
                              n_out_chan,
                              filter_sz,
                              stride,
                              padding=0,
                              bias=False)
        self.relu = nn.ReLU()

        # Receptive-field bookkeeping object for this layer; it receives
        # parent_rf as its parent.
        self.rf = rfield.Rfield(filter_info=filter_sz,
                                stride=stride,
                                parent=parent_rf,
                                name=name)
        netmisc.xavier_init(self.conv)
Exemplo n.º 2
0
    def __init__(self, n_in, n_out, n_sam_per_datapoint=1, bias=True):
        """Create the VAE bottleneck projection layers.

        Two kernel-size-1 convolutions (with a Tanh module in between,
        presumably applied in ``forward``) map ``n_in`` channels to
        ``2 * n_out`` channels. The doubling presumably holds the mean and
        scale halves -- confirm in ``forward``.

        Args:
            n_in: number of input channels.
            n_out: latent size; both convolutions output ``2 * n_out`` channels.
            n_sam_per_datapoint: L in equation 7 of
                https://arxiv.org/pdf/1312.6114.pdf
            bias: whether the second convolution has a bias (the first never
                does).
        """
        super(VAE, self).__init__()
        n_hidden = 2 * n_out
        self.linear = nn.Conv1d(n_in, n_hidden, kernel_size=1, bias=False)
        self.tanh = nn.Tanh()
        self.linear2 = nn.Conv1d(n_hidden, n_hidden, kernel_size=1, bias=bias)
        self.n_sam_per_datapoint = n_sam_per_datapoint
        for layer in (self.linear, self.linear2):
            netmisc.xavier_init(layer)

        # Placeholders, later filled in for the objective function to read.
        self.mu, self.sigma = None, None
Exemplo n.º 3
0
 def __init__(self, n_in, n_out, vq_gamma, vq_n_embed):
     """Create a vector-quantization layer with a learned codebook.

     Args:
         n_in: number of input channels.
         n_out: embedding dimension ``d`` (width of each codebook vector).
         vq_gamma: stored as ``self.gamma`` (presumably a loss weight --
             confirm where it is used).
         vq_n_embed: number of codebook vectors ``k``.
     """
     super(VQ, self).__init__()
     self.d, self.gamma, self.k = n_out, vq_gamma, vq_n_embed

     # Kernel-size-1 projection of the input down to the embedding width.
     self.linear = nn.Conv1d(n_in, self.d, kernel_size=1, bias=False)
     self.sg = StopGrad()
     self.rg = ReplaceGrad()

     # State that is filled in later, outside the constructor.
     self.ze = self.l2norm_min = self.circ_inds = None

     # Per-codebook-entry counter of length k (presumably usage counts).
     self.register_buffer('ind_hist', torch.zeros(self.k))

     # Learned (k, d) codebook; storage is allocated here and filled below.
     self.emb = nn.Parameter(data=torch.empty(self.k, self.d))
     netmisc.xavier_init(self.linear)
     nn.init.xavier_uniform_(self.emb, gain=1)
Exemplo n.º 4
0
    def __init__(self, n_in, n_out, vq_gamma, vq_ema_gamma, vq_n_embed,
                 training):
        """Create a vector-quantization layer whose codebook is a buffer.

        The codebook ``emb`` of shape (vq_n_embed, n_out) is a buffer rather
        than a Parameter. Judging by the EMA statistics allocated here, it is
        presumably updated by exponential moving average in ``forward``.
        Confirm this there.

        Args:
            n_in: number of input channels.
            n_out: embedding dimension ``d``.
            vq_gamma: stored as ``self.gamma`` (presumably a loss weight --
                confirm where it is used).
            vq_ema_gamma: EMA decay; must lie strictly inside (0, 1).
            vq_n_embed: number of codebook vectors ``k``.
            training: if true, also allocate the buffers used for EMA
                bookkeeping. Later ``.train()``/``.eval()`` calls change
                ``self.training`` but do not add or remove these buffers.

        Raises:
            RuntimeError: if ``vq_ema_gamma`` is not in the open interval
                (0, 1). NaN is rejected too.
        """
        super(VQEMA, self).__init__()
        # Validate before allocating anything. Written as a negated chained
        # comparison so that NaN (for which every comparison is False) fails.
        if not 0.0 < vq_ema_gamma < 1.0:
            raise RuntimeError('VQEMA must use an EMA-gamma value in (0, 1)')

        self.training = training
        self.d = n_out
        self.gamma = vq_gamma
        self.ema_gamma = vq_ema_gamma
        self.ema_gamma_comp = 1.0 - self.ema_gamma
        self.k = vq_n_embed
        self.linear = nn.Conv1d(n_in, self.d, 1, bias=False)
        self.sg = StopGrad()
        self.rg = ReplaceGrad()
        self.ze = None
        self.register_buffer('emb', torch.empty(self.k, self.d))
        nn.init.xavier_uniform_(self.emb, gain=10)

        if self.training:
            self.min_dist = None
            self.circ_inds = None
            self.register_buffer('ind_hist', torch.zeros(self.k))
            # EMA numerator/denominator start at (1 - ema_gamma) times the
            # initial codebook and times ones, respectively. They are
            # registered directly with those values instead of being
            # allocated empty and then overwritten.
            self.register_buffer('ema_numer', self.emb * self.ema_gamma_comp)
            self.register_buffer('ema_denom',
                                 torch.ones(self.k) * self.ema_gamma_comp)
            # Zero-initialised (not torch.empty) so these persistent buffers,
            # and any checkpoint containing them, never hold uninitialised
            # memory.
            self.register_buffer('z_sum', torch.zeros(self.k, self.d))
            self.register_buffer('n_sum', torch.zeros(self.k))
            self.register_buffer('n_sum_ones', torch.ones(self.k))

        netmisc.xavier_init(self.linear)
Exemplo n.º 5
0
    def __init__(self,
                 n_in_chan,
                 n_out_chan,
                 filter_sz,
                 stride=1,
                 do_res=True,
                 parent_vc=None,
                 name=None):
        """Build a biased Conv1d followed by a ReLU, optionally residual.

        Args:
            n_in_chan: number of input channels of the convolution.
            n_out_chan: number of output channels of the convolution.
            filter_sz: convolution kernel size; also passed to
                ``vconv.VirtualConv`` as ``filter_info``.
            stride: convolution stride. Must be 1 when ``do_res`` is True.
            do_res: whether this unit is residually connected. If true, the
                offsets from ``vconv.output_offsets`` are stored in the
                ``residual_offsets`` buffer.
            parent_vc: forwarded to ``vconv.VirtualConv`` as its ``parent``.
            name: stored on the module and forwarded to ``vconv.VirtualConv``.

        Raises:
            ValueError: if ``do_res`` is True and ``stride != 1``.
        """
        super(ConvReLURes, self).__init__()
        self.do_res = do_res
        # Validate before constructing any sub-objects, so an invalid
        # configuration does not first build a VirtualConv around parent_vc.
        # The explanation travels with the exception instead of stderr.
        if self.do_res and stride != 1:
            raise ValueError(
                'Stride must be 1 for residually connected convolution')

        self.n_in = n_in_chan
        self.n_out = n_out_chan
        # No padding, so the output is shorter than the input whenever
        # filter_sz > 1.
        self.conv = nn.Conv1d(n_in_chan,
                              n_out_chan,
                              filter_sz,
                              stride,
                              padding=0,
                              bias=True)
        self.relu = nn.ReLU()
        self.name = name

        self.vc = vconv.VirtualConv(filter_info=filter_sz,
                                    stride=stride,
                                    parent=parent_vc,
                                    name=name)

        if self.do_res:
            # Left/right offsets computed by vconv.output_offsets for this
            # layer. They are presumably used in forward to align the
            # residual input with the shorter conv output -- confirm. Stored
            # as a buffer so they follow .to()/.cuda() and are in state_dict.
            l_off, r_off = vconv.output_offsets(self.vc, self.vc)
            self.register_buffer('residual_offsets',
                                 torch.tensor([l_off, r_off]))

        netmisc.xavier_init(self.conv)
Exemplo n.º 6
0
 def __init__(self, n_in, n_out, bias=True):
     """Create a single kernel-size-1 Conv1d mapping n_in to n_out channels.

     Args:
         n_in: number of input channels.
         n_out: number of output channels.
         bias: whether the convolution has a bias term.
     """
     super(AE, self).__init__()
     projection = nn.Conv1d(in_channels=n_in, out_channels=n_out,
                            kernel_size=1, bias=bias)
     self.linear = projection
     netmisc.xavier_init(projection)