def forward(self, s0, s1, drop_prob):
    """Evaluate a fixed-topology cell on two input tensors.

    Both inputs are passed through ``self.preprocess0`` / ``self.preprocess1``
    to seed the state list. Each of the ``self._steps`` nodes then takes two
    earlier states (chosen by ``self._indices[2*i]`` and
    ``self._indices[2*i + 1]``). It applies the matching ops
    ``self._ops[2*i]`` / ``self._ops[2*i + 1]`` to them and sums the two
    results. The sum is appended as a new state.

    When ``self.training`` is true and ``drop_prob > 0``, ``drop_path`` is
    applied to every op output except those produced by ``Identity`` ops.

    Args:
        s0: first cell input, fed to ``self.preprocess0``.
        s1: second cell input, fed to ``self.preprocess1``.
        drop_prob: probability handed to ``drop_path``; ignored outside
            training or when not positive.

    Returns:
        ``torch.cat`` along dim 1 of the states whose positions are listed
        in ``self._concat``.
    """
    states = [self.preprocess0(s0), self.preprocess1(s1)]
    apply_drop = self.training and drop_prob > 0.

    for step in range(self._steps):
        left, right = 2 * step, 2 * step + 1
        pair_ops = (self._ops[left], self._ops[right])
        pair_inputs = (states[self._indices[left]], states[self._indices[right]])

        # Run both ops first, then drop-path both, matching the original order.
        outputs = [op(x) for op, x in zip(pair_ops, pair_inputs)]
        if apply_drop:
            # Identity outputs are left untouched by drop-path.
            outputs = [
                out if isinstance(op, Identity) else drop_path(out, drop_prob)
                for op, out in zip(pair_ops, outputs)
            ]

        states.append(outputs[0] + outputs[1])

    return torch.cat([states[pos] for pos in self._concat], dim=1)
def forward(self, s0, s1, weights, drop_prob=0.):
    """Evaluate a densely connected, weighted cell on two input tensors.

    After preprocessing the two inputs, each of the ``self._steps`` nodes is
    the sum over *all* states built so far. Each state ``h`` is sent through
    its own edge op, called as ``self._ops[k](h, weights[k])``. Edges are
    numbered consecutively across nodes, so ``self._ops`` and ``weights`` are
    indexed in parallel by the running edge counter.

    When ``drop_prob > 0`` and ``self.training`` is true, every edge output
    is passed through ``drop_path`` before being summed.

    Args:
        s0: first cell input, fed to ``self.preprocess0``.
        s1: second cell input, fed to ``self.preprocess1``.
        weights: per-edge weights, indexed in step with ``self._ops``.
        drop_prob: probability handed to ``drop_path``; defaults to 0.

    Returns:
        ``torch.cat`` along dim 1 of the last ``self._multiplier`` states.
    """
    states = [self.preprocess0(s0), self.preprocess1(s1)]
    use_drop_path = drop_prob > 0. and self.training
    edge = 0

    for _ in range(self._steps):
        # Start from 0 and add one edge at a time, exactly like built-in sum().
        node = 0
        for j, h in enumerate(states):
            contribution = self._ops[edge + j](h, weights[edge + j])
            if use_drop_path:
                contribution = drop_path(contribution, drop_prob)
            node = node + contribution
        # Advance past this node's edges before the new state is appended.
        edge += len(states)
        states.append(node)

    return torch.cat(states[-self._multiplier:], dim=1)