Esempio n. 1
0
    def forward(self, s0, s1, drop_prob):
        """Evaluate the discrete cell on the outputs of the two preceding cells.

        :param s0: output of the cell two positions back; fed through
            ``self.preprocess0``.
        :param s1: output of the immediately preceding cell; fed through
            ``self.preprocess1``.
        :param drop_prob: drop-path probability. It is only applied while
            ``self.training`` is true and the value is positive, and never to
            ``Identity`` ops.
        :return: the node tensors listed in ``self._concat``, concatenated
            along dim 1.
        """
        nodes = [self.preprocess0(s0), self.preprocess1(s1)]
        apply_drop = self.training and drop_prob > 0.

        for step in range(self._steps):
            left, right = 2 * step, 2 * step + 1
            op_left, op_right = self._ops[left], self._ops[right]
            out_left = op_left(nodes[self._indices[left]])
            out_right = op_right(nodes[self._indices[right]])

            if apply_drop:
                if not isinstance(op_left, Identity):
                    out_left = drop_path(out_left, drop_prob)
                if not isinstance(op_right, Identity):
                    out_right = drop_path(out_right, drop_prob)

            nodes.append(out_left + out_right)

        return torch.cat([nodes[idx] for idx in self._concat], dim=1)
Esempio n. 2
0
    def forward(self, s0, s1, drop_prob):
        # Typical call site: s0, s1 = s1, cell(s0, s1, self.drop_path_prob).
        # Both inputs are first passed through self.preprocess0 / self.preprocess1.
        states = [self.preprocess0(s0), self.preprocess1(s1)]  # outputs of all nodes so far

        # Each intermediate node consumes two entries of self._indices / self._ops.
        for node in range(self._steps):
            first, second = 2 * node, 2 * node + 1
            # Node 0 can only read s0 / s1; later nodes may also read earlier
            # intermediate nodes. Example first-input indices for one genotype:
            # [0, 0, 1, 0]; second-input indices: [1, 1, 0, 2]. In the last node
            # of that example, (0, 2), the input from state 0 acts as a skip
            # connection.
            inputs = (states[self._indices[first]], states[self._indices[second]])
            pair_ops = (self._ops[first], self._ops[second])
            outs = [op(x) for op, x in zip(pair_ops, inputs)]

            # Drop-path: only while training, never on Identity ops.
            if self.training and drop_prob > 0.:
                outs = [
                    out if isinstance(op, Identity) else drop_path(out, drop_prob)
                    for op, out in zip(pair_ops, outs)
                ]

            # Every intermediate node is the sum of its two transformed inputs.
            states.append(outs[0] + outs[1])

        # Concatenate the node outputs selected by self._concat along dim 1.
        return torch.cat([states[k] for k in self._concat], dim=1)
Esempio n. 3
0
    def forward(self, s0, s1, drop_prob):
        """Run the cell, lazily build its mask, then apply the output activation.

        While ``self.mask`` is None, ``self._create_mask`` is called with the
        preprocessed ``s1`` and the pre-activation concatenated output. The
        activated output is cached on ``self.act`` and returned.

        :param s0: output of the cell two positions back.
        :param s1: output of the previous cell.
        :param drop_prob: drop-path probability, used only while training and
            never on ``Identity`` ops.
        """
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for step in range(self._steps):
            a, b = 2 * step, 2 * step + 1
            op_a, op_b = self._ops[a], self._ops[b]
            y_a = op_a(states[self._indices[a]])
            y_b = op_b(states[self._indices[b]])
            if self.training and drop_prob > 0.:
                y_a = y_a if isinstance(op_a, Identity) else drop_path(y_a, drop_prob)
                y_b = y_b if isinstance(op_b, Identity) else drop_path(y_b, drop_prob)
            states.append(y_a + y_b)

        out = torch.cat([states[k] for k in self._concat], dim=1)
        if self.mask is None:
            # Built from the preprocessed s1 and the raw (pre-activation) output.
            self._create_mask(s1, out)

        activated = self.activation(out)
        self.act = activated
        return activated
Esempio n. 4
0
    def forward(self, s0, s1, drop_prob):
        """Run the cell and, after the first layer, add a residual from ``s1``.

        :param s0: output of the cell two positions back (via ``self.preprocess0``).
        :param s1: output of the previous cell (via ``self.preprocess1``); its raw,
            un-preprocessed value is also the residual source.
        :param drop_prob: drop-path probability, used only while training and
            never on ``Identity`` ops.
        :return: the ``self._concat`` nodes concatenated along dim 1. When
            ``self.layer != 0`` the residual is added in place: it is
            ``self.preprocess_res(s1)`` in a reduction cell and the raw ``s1``
            otherwise.
        """
        residual = s1  # keep the raw input; s1 is rebound below
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)

        # Concatenate once; every branch below starts from this tensor.
        states_out = torch.cat([states[i] for i in self._concat], dim=1)
        if self.layer == 0:
            return states_out

        if self.reduction:
            # A reduction cell changes the resolution/width, so the residual
            # needs its own projection.
            states_out += self.preprocess_res(residual)
        else:
            states_out += residual
        return states_out
Esempio n. 5
0
    def forward(self, s0, s1, drop_prob):
        """Run the cell, optionally weighting each node's two inputs.

        Each intermediate node adds its two (possibly drop-pathed) op outputs.
        When ``self._weight`` is set, the outputs are first scaled by
        ``self._weight[2 * i]`` and ``self._weight[2 * i + 1]``.

        :param s0: output of the cell two positions back (via ``self.preprocess0``).
        :param s1: output of the previous cell (via ``self.preprocess1``).
        :param drop_prob: drop-path probability, used only while training and
            never on ``Identity`` ops.
        :return: the ``self._concat`` nodes concatenated along dim 1.
        """
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for i in range(self._steps):
            h1 = states[self._indices[2 * i]]
            h2 = states[self._indices[2 * i + 1]]
            op1 = self._ops[2 * i]
            op2 = self._ops[2 * i + 1]
            h1 = op1(h1)
            h2 = op2(h2)
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            if self._weight is None:
                s = h1 + h2
            else:
                s = self._weight[2 * i] * h1 + self._weight[2 * i + 1] * h2
            # Exactly one new state per node: self._indices and self._concat
            # address states by node number, so a second append would shift them.
            states.append(s)
        return torch.cat([states[i] for i in self._concat], dim=1)  # presumably N, C, H, W
Esempio n. 6
0
    def forward(self, s0, s1, drop_prob):
        """Compute the cell output from the outputs of the two previous cells.

        :param s0: output of the cell two positions back; passed through
            ``self.preprocess0``.
        :param s1: output of the previous cell; passed through
            ``self.preprocess1``.
        :param drop_prob: drop-path probability, applied only while training,
            only when positive, and never to ``Identity`` ops.
        :return: the states selected by ``self._concat`` concatenated along dim 1.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        training_drop = self.training and drop_prob > 0.

        for node in range(self._steps):
            # For each node, find which two previous states it connects to and
            # the op applied to each of them.
            branches = []
            for k in (2 * node, 2 * node + 1):
                op = self._ops[k]
                branches.append((op, op(states[self._indices[k]])))

            if training_drop:
                branches = [
                    (op, y if isinstance(op, Identity) else drop_path(y, drop_prob))
                    for op, y in branches
                ]

            # The two branch results are aggregated by arithmetic sum.
            (_, first), (_, second) = branches
            states.append(first + second)

        # The cell result is the concatenation (dim 1) of the selected nodes, so
        # all of them must agree on every other dimension (e.g. spatial size).
        return torch.cat([states[n] for n in self._concat], dim=1)
Esempio n. 7
0
    def forward(self, s0, s1, drop_prob):
        """Evaluate the cell on the outputs of the two previous cells.

        For each of the ``self._steps`` feed-forward steps, two earlier states
        (h1 and h2) are selected through ``self._indices`` and transformed by the
        matching operations in ``self._ops``. While training with a positive
        ``drop_prob``, drop-path is applied to non-Identity ops. The two results
        are summed and appended as a new state.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for step in range(self._steps):
            i1 = self._indices[2 * step]
            i2 = self._indices[2 * step + 1]
            f1 = self._ops[2 * step]
            f2 = self._ops[2 * step + 1]
            h1, h2 = f1(states[i1]), f2(states[i2])
            if self.training and drop_prob > 0.:
                if not isinstance(f1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(f2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)
        return torch.cat([states[k] for k in self._concat], dim=1)
Esempio n. 8
0
    def forward(self, s0, s1, drop_prob):
        """Run the cell and add weighted residual connections from both raw inputs.

        The cell nodes are either summed (``self.shrink_channel``) or
        concatenated along dim 1. Then ``self.residual_wei``-scaled residuals of
        the raw ``s0``/``s1`` are added:

        * reduction cell: ``residual_reduce`` on both inputs;
        * cell right after a reduction (``self.reduction_prev``):
          ``residual_reduce`` on ``s0`` and ``residual_norm`` on ``s1``;
        * otherwise: ``residual_norm`` on both.

        Before ``s1`` (reduction_prev case) or ``s0`` (normal case) goes through
        ``residual_norm``, it is tiled along dim 1 with ``repeat`` by the integer
        factor ``out_channels // in_channels`` if it has fewer channels than the
        output.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for i in range(self._steps):
            op1, op2 = self._ops[2 * i], self._ops[2 * i + 1]
            h1 = op1(states[self._indices[2 * i]])
            h2 = op2(states[self._indices[2 * i + 1]])
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)

        if self.shrink_channel:
            # Sum the trailing len(self._concat) states instead of concatenating.
            tail = states[-len(self._concat):]
            out = tail[0] if tail else None
            for node_out in tail[1:]:
                out = out + node_out
        else:
            out = torch.cat([states[i] for i in self._concat], dim=1)

        def widen(x):
            # Tile x along dim 1 when it has fewer channels than the current output.
            if x.shape[1] < out.shape[1]:
                return x.repeat(1, out.shape[1] // x.shape[1], 1, 1)
            return x

        wei = self.residual_wei
        if self.reduction:
            out = out + wei * self.residual_reduce(s0)
            out = out + wei * self.residual_reduce(s1)
        elif self.reduction_prev:
            s1 = widen(s1)  # checked against out before the s0 residual is added
            out = out + wei * self.residual_reduce(s0)
            out = out + wei * self.residual_norm(s1)
        else:
            s0 = widen(s0)
            out = out + wei * self.residual_norm(s0)
            out = out + wei * self.residual_norm(s1)

        return out
Esempio n. 9
0
    def forward(self, s0, s1, drop_prob):
        """Run a cell whose node wiring comes from ``self._genotype_nodes``.

        Node ``j`` uses the ``j``-th entry ``node`` of ``self._genotype_nodes``.
        It reads ``states[node[0]]`` and ``states[node[1]]``, applies
        ``self._ops[2 * j]`` / ``self._ops[2 * j + 1]`` to them, and sums the
        results. Drop-path is applied to non-Identity ops while training with a
        positive ``drop_prob``.

        :return: the first ``self._steps`` intermediate nodes concatenated along
            dim 1 (``self._concat`` is not consulted).
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for j, node in enumerate(self._genotype_nodes):
            op1, op2 = self._ops[2 * j], self._ops[2 * j + 1]
            h1 = op1(states[node[0]])
            h2 = op2(states[node[1]])

            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)

        # Intermediate nodes start at index 2 (after the two preprocessed inputs).
        return torch.cat([states[i + 2] for i in range(self._steps)], dim=1)
    def forward(self, s0, s1, drop_prob=-1):
        """Apply two op pairs to the preprocessed inputs and concatenate four results.

        ``self.ops1`` and ``self.ops2`` each contribute ``pair[0](s0')`` and
        ``pair[1](s1')``, where ``s0'``/``s1'`` are the preprocessed inputs.
        While training with ``drop_prob > 0``, drop-path is applied to every
        output; there is no Identity exemption here. The default of -1 disables
        drop-path.

        :return: ``[ops1[0], ops1[1], ops2[0], ops2[1]]`` outputs concatenated along dim 1.
        """
        x0 = self.preprocess0(s0)
        x1 = self.preprocess1(s1)

        outputs = []
        for pair in (self.ops1, self.ops2):
            left, right = pair[0](x0), pair[1](x1)
            if self.training and drop_prob > 0.:
                left = drop_path(left, drop_prob)
                right = drop_path(right, drop_prob)
            outputs += [left, right]
        return torch.cat(outputs, dim=1)
Esempio n. 11
0
    def forward(self, s0, s1, drop_path_prob, weights):
        """Search-cell forward pass with optional gradient checkpointing.

        Every intermediate node sums one edge per earlier state; edge ``j`` of a
        node calls ``self._ops[offset + j](state_j, weights[offset + j])``. When
        ``self._use_ckpt`` is set, the preprocessing layers and every edge op are
        run through ``cp.checkpoint``. Drop-path is applied to non-Identity edges
        while training with ``drop_path_prob > 0``.

        :return: the last four states concatenated along dim 1.
        """
        use_ckpt = self._use_ckpt

        def run(fn, *args):
            # Optionally route the call through activation checkpointing.
            return cp.checkpoint(fn, *args) if use_ckpt else fn(*args)

        states = [run(self.preprocess0, s0), run(self.preprocess1, s1)]
        edge_base = 0
        for _ in range(self._steps):
            node_sum = 0
            for j, h in enumerate(states):
                op = self._ops[edge_base + j]
                h = run(op, h, weights[edge_base + j])
                if self.training and drop_path_prob > 0. and not isinstance(op, Identity):
                    h = drop_path(h, drop_path_prob)
                node_sum += h
            # Every earlier state contributed exactly one edge to this node.
            edge_base += len(states)
            states.append(node_sum)
        return torch.cat(states[-4:], dim=1)
Esempio n. 12
0
    def forward(self, s0, s1, drop_prob):
        """Run a cell whose topology is the nested mapping ``self._ops``.

        ``self._ops`` maps each intermediate node id to a mapping
        ``from_id -> ops``. Ids are presumably strings such as ``'2'``, matching
        the ``'0'``/``'1'`` keys used for the inputs. A node's value is the sum
        of every op applied to every source node that has already been
        computed. Sources missing from ``states`` (isolated/pruned nodes) are
        skipped. A node with no available source at all is left out entirely;
        storing the integer 0 for it would break ``torch.cat`` and any op that
        later reads it.

        :param drop_prob: drop-path probability applied to each node's sum
            while training.
        :return: all computed intermediate nodes, in insertion order,
            concatenated along dim 1.
        """
        states = {'0': self.preprocess0(s0), '1': self.preprocess1(s1)}

        for to_i, sources in self._ops.items():
            # An edge may carry several ops; their outputs are summed.
            edge_outs = [
                sum(op(states[from_i]) for op in op_i)
                for from_i, op_i in sources.items()
                if from_i in states
            ]
            if not edge_outs:
                # Isolated node: nothing reachable feeds it, so skip it; later
                # nodes ignore it through the `from_i in states` check above.
                continue
            out = sum(edge_outs)
            if self.training and drop_prob > 0:
                out = drop_path(out, drop_prob)
            states[to_i] = out

        return torch.cat(list(states.values())[2:], dim=1)
Esempio n. 13
0
    def forward(self, s0, s1, drop_prob):
        """Run a cell whose edges are a first op optionally followed by a second layer.

        :param s0: output of the cell two positions back (via ``self.preprocess0``).
        :param s1: output of the previous cell (via ``self.preprocess1``).
        :param drop_prob: drop-path probability; only used while training.
        :return: the ``self._concat`` nodes concatenated along dim 1, passed
            through ``self._bottleneck`` when it is truthy.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for i in range(self._steps):
            a, b = 2 * i, 2 * i + 1
            op1, op2 = self._first_layers[a], self._first_layers[b]
            h1 = op1(states[self._indices[a]])
            h2 = op2(states[self._indices[b]])

            # Layers whose Identity-ness decides whether a branch is drop-pathed.
            guards1, guards2 = (op1,), (op2,)
            if self._second_layers:
                at1, at2 = self._second_layers[a], self._second_layers[b]
                h1 = at1(h1)
                h2 = at2(h2)
                guards1, guards2 = (op1, at1), (op2, at2)

            if self.training and drop_prob > 0.:
                # NOTE(review): with a second layer, a branch is dropped only when
                # NEITHER layer is Identity; confirm `and` (not `or`) is intended.
                if not any(isinstance(g, Identity) for g in guards1):
                    h1 = drop_path(h1, drop_prob)
                if not any(isinstance(g, Identity) for g in guards2):
                    h2 = drop_path(h2, drop_prob)

            states.append(h1 + h2)

        out = torch.cat([states[k] for k in self._concat], dim=1)
        return self._bottleneck(out) if self._bottleneck else out
Esempio n. 14
0
    def forward(self, s0, s1, drop_prob):
        """Run the cell and, after the first layer, add a residual from ``s1``.

        Reduction cells preprocess their inputs with ``preprocess0_red`` /
        ``preprocess1_red``; other cells use ``preprocess0`` / ``preprocess1``.
        For ``self.layer != 0`` the raw ``s1`` is added in place to the output.
        In a reduction cell it is first projected by one of
        ``preprocess_res_xpad`` (odd size on dim 2 only), ``preprocess_res_ypad``
        (odd size on dim 3 only) or ``preprocess_res`` (otherwise).

        :return: the ``self._concat`` nodes concatenated along dim 1, plus the
            residual when ``self.layer != 0``.
        """
        raw0, raw1 = s0, s1
        if self.reduction:
            states = [self.preprocess0_red(raw0), self.preprocess1_red(raw1)]
        else:
            states = [self.preprocess0(raw0), self.preprocess1(raw1)]

        for i in range(self._steps):
            op1, op2 = self._ops[2 * i], self._ops[2 * i + 1]
            h1 = op1(states[self._indices[2 * i]])
            h2 = op2(states[self._indices[2 * i + 1]])
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)

        states_out = torch.cat([states[i] for i in self._concat], dim=1)
        if self.layer == 0:
            return states_out

        if not self.reduction:
            states_out += raw1
            return states_out

        odd_dim2 = raw1.shape[2] % 2 == 1
        odd_dim3 = raw1.shape[3] % 2 == 1
        if odd_dim2 and not odd_dim3:
            residual = self.preprocess_res_xpad(raw1)
        elif odd_dim3 and not odd_dim2:
            residual = self.preprocess_res_ypad(raw1)
        else:
            residual = self.preprocess_res(raw1)
        states_out += residual
        return states_out
Esempio n. 15
0
    def forward(self, s0, s1, drop_prob):
        """Evaluate the discrete cell, asking each op whether it is an identity.

        Drop-path is skipped for any op whose ``is_identity_op()`` returns true.
        It is applied only while training with a positive ``drop_prob``.

        :return: the ``self._concat`` states concatenated along dim 1.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for step in range(self._steps):
            idx = (2 * step, 2 * step + 1)
            chosen = [self._ops[k] for k in idx]
            results = [op(states[self._indices[k]]) for op, k in zip(chosen, idx)]
            if self.training and drop_prob > 0.:
                for pos, op in enumerate(chosen):
                    if not op.is_identity_op():
                        results[pos] = drop_path(results[pos], drop_prob)
            states.append(results[0] + results[1])
        return torch.cat([states[k] for k in self._concat], dim=1)
Esempio n. 16
0
    def forward(self, x, drop_prob):
        """Gate ``x`` with a mask computed by the cell from ``x`` itself.

        Both cell inputs are ``x`` (no preprocessing). After all
        ``self._steps`` nodes are built, their sum is passed through
        ``self.activ`` and the result multiplies ``x`` element-wise.

        :param x: input tensor; it is both cell inputs and the gated value.
        :param drop_prob: drop-path probability, used only while training and
            never on ``Identity`` ops.
        :return: ``x * self.activ(sum of intermediate nodes)``.
        """
        states = [x, x]
        for i in range(self._steps):
            op1, op2 = self._ops[2 * i], self._ops[2 * i + 1]
            h1 = op1(states[self._indices[2 * i]])
            h2 = op2(states[self._indices[2 * i + 1]])
            if self.training and drop_prob > 0.:
                if not isinstance(op1, Identity):
                    h1 = drop_path(h1, drop_prob)
                if not isinstance(op2, Identity):
                    h2 = drop_path(h2, drop_prob)
            states.append(h1 + h2)

        # Only the final set of nodes matters, so build the mask once, after
        # every node exists, instead of on every step.
        mask = self.activ(sum(states[2:]))
        return x * mask
Esempio n. 17
0
  def forward(self, s0, s1, drop_prob):
    """Evaluate the discrete cell.

    Node i reads the two earlier states selected by self._indices[2*i] and
    self._indices[2*i + 1], transforms them with self._ops[2*i] and
    self._ops[2*i + 1], and stores their element-wise sum.
    """
    states = [self.preprocess0(s0), self.preprocess1(s1)]
    for node in range(self._steps):  # visit every node and build its connections
      pair = (2 * node, 2 * node + 1)
      # For each of the node's two inputs: self._indices says which earlier
      # feature map it reads, self._ops says which op transforms it.
      feats = [self._ops[k](states[self._indices[k]]) for k in pair]
      if self.training and drop_prob > 0.:
        # Drop-path only on non-identity mappings.
        feats = [
          f if isinstance(self._ops[k], Identity) else drop_path(f, drop_prob)
          for k, f in zip(pair, feats)
        ]
      states.append(feats[0] + feats[1])  # the two op outputs are added element-wise
    # Concatenate the feature maps selected for output.
    return torch.cat([states[k] for k in self._concat], dim=1)
Esempio n. 18
0
    def forward(self, s0, s1, drop_prob):
        """Compute the cell output for ``s0`` (two cells back) and ``s1`` (previous cell).

        Each of the ``self._steps`` nodes transforms two earlier states with its
        two incoming ops and stores their sum. The ``self._concat`` states are
        then concatenated along dim 1.
        """
        def branch(slot, states):
            # One incoming edge: the op for this slot applied to its source state.
            op = self._ops[slot]
            return op, op(states[self._indices[slot]])

        states = [self.preprocess0(s0), self.preprocess1(s1)]
        for step in range(self._steps):
            (op_a, out_a), (op_b, out_b) = branch(2 * step, states), branch(2 * step + 1, states)
            if self.training and drop_prob > 0.:
                if not isinstance(op_a, Identity):
                    out_a = drop_path(out_a, drop_prob)
                if not isinstance(op_b, Identity):
                    out_b = drop_path(out_b, drop_prob)
            # The node's value is the sum of its two incoming edges.
            states.append(out_a + out_b)
        return torch.cat([states[k] for k in self._concat], dim=1)
Esempio n. 19
0
 def forward(self, x, weights, drop_prob, eta_min, node_sum):
     """Weighted sum of candidate ops, pruning those whose weight is too small.

     An op contributes ``w * op(x)`` when its weight ``w`` exceeds
     ``eta_min * node_sum``. Otherwise it is permanently replaced in
     ``self._ops`` by ``Zero(self.stride)``, unless it already is a ``Zero``.
     Drop-path is applied to non-Identity ops only while training with
     ``drop_prob > 0``, matching how the cells gate it.

     :param x: input of every candidate op.
     :param weights: one weight per entry of ``self._ops``.
     :return: the accumulated mixture; the integer 0 if no op passes the threshold.
     """
     threshold = eta_min * node_sum
     mix_op = 0
     for k, (w, op) in enumerate(zip(weights, self._ops)):
         if w > threshold:
             out = op(x)
             # Drop-path is a training-time regulariser; never apply it in eval.
             if self.training and drop_prob > 0. and not isinstance(op, Identity):
                 out = drop_path(out, drop_prob)
             mix_op = mix_op + w * out
         elif not isinstance(self._ops[k], Zero):
             # Prune: this op stays disabled for all later calls.
             self._ops[k] = Zero(self.stride)
     return mix_op
Esempio n. 20
0
    def forward(self, s0, s1, drop_prob):
        """Run a cell described as an explicit edge list.

        Edge ``k`` applies ``self._ops[k]`` to node ``self._f_nodes[k]`` and
        accumulates the result into node ``self._t_nodes[k]``, creating that
        node on first use. Drop-path is applied to non-Identity edges while
        training with a positive ``drop_prob``.

        :return: the ``self._concat`` nodes concatenated along dim 1.
        """
        states = {0: self.preprocess0(s0), 1: self.preprocess1(s1)}
        for op, src, dst in zip(self._ops, self._f_nodes, self._t_nodes):
            y = op(states[src])
            if self.training and drop_prob > 0. and not isinstance(op, Identity):
                y = drop_path(y, drop_prob)
            states[dst] = states[dst] + y if dst in states else y
        return torch.cat([states[n] for n in self._concat], dim=1)
Esempio n. 21
0
 def forward(self, s0, s1, drop_prob):
     """Run a cell given as parallel edge lists; always builds nodes 2..5.

     Edge ``j`` feeds ``self._ops[j](states[self._indices_input[j]])`` into
     node ``self._indices_output[j]``. Drop-path is applied to non-Identity
     edges while training with a positive ``drop_prob``. A node with no
     incoming edge is the integer 0.

     :return: the ``self._concat`` states concatenated along dim 1.
     """
     states = [self.preprocess0(s0), self.preprocess1(s1)]
     for node in range(2, 6):
         acc = 0
         for j, dst in enumerate(self._indices_output):
             if dst != node:
                 continue
             op = self._ops[j]
             h = op(states[self._indices_input[j]])
             if self.training and drop_prob > 0. and not isinstance(op, Identity):
                 h = drop_path(h, drop_prob)
             acc = acc + h
         states.append(acc)
     return torch.cat([states[k] for k in self._concat], dim=1)
Esempio n. 22
0
    def forward(self, s0, s1, weights, drop_prob=0.0):
        """Search-cell forward: each node sums one weighted edge per earlier state.

        :param weights: per-edge architecture weights, indexed by a running edge
            offset; edge ``j`` of a node calls
            ``self._ops[offset + j](state_j, weights[offset + j])``.
        :param drop_prob: drop-path probability applied to every edge output,
            Identity included, while training.
        :return: the last ``self._multiplier`` states concatenated along dim 1.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        offset = 0
        for _ in range(self._steps):
            drop = drop_prob > 0. and self.training
            contributions = (
                self._ops[offset + j](h, weights[offset + j])
                for j, h in enumerate(states)
            )
            if drop:
                contributions = (drop_path(c, drop_prob) for c in contributions)
            # The generators are consumed here, before offset/states change.
            node = sum(contributions)
            offset += len(states)
            states.append(node)

        return torch.cat(states[-self._multiplier:], dim=1)
Esempio n. 23
0
    def forward(self, s0, s1, drop_prob):
        """Run a cell described by ``self.dag``.

        ``self.dag`` holds one list of edge ops per intermediate node. Each op
        reads ``states[op.s_idx]``, and a node is the sum of its edge outputs.
        Drop-path is applied to non-Identity edges only while training with
        ``drop_prob > 0``, matching the other cells.

        :param s0: output of the cell two positions back (via ``self.preprocess0``).
        :param s1: output of the previous cell (via ``self.preprocess1``).
        :return: the ``self.concat`` states concatenated along dim 1.
        """
        states = [self.preprocess0(s0), self.preprocess1(s1)]
        # Drop-path is a training-time regulariser; drop_path itself is not told
        # whether the module is training, so gate it here.
        apply_drop = self.training and drop_prob > 0.
        for edges in self.dag:
            buffer = []
            for op in edges:
                a = op(states[op.s_idx])
                if apply_drop and not isinstance(op, ops.Identity):
                    a = drop_path(a, drop_prob)
                buffer.append(a)
            states.append(sum(buffer))

        return torch.cat([states[i] for i in self.concat], dim=1)
Esempio n. 24
0
    def forward(self, s0, s1, drop_prob):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        start = 0
        end = start + 2
        for i in range(self._steps):
            s = 0
            for j in range(start, end):
                h = states[self._indices[j]]
                op = self._ops[j]
                h = op(h)
                if self.training and drop_prob > 0.:
                    if not isinstance(op, Identity):
                        h = drop_path(h, drop_prob)
                s += h
            start = end
            end += (i + 3)
            states += [s]
        return torch.cat([states[i] for i in self._concat], dim=1)