def call(self, x):
    """Run the stem, the stack of searchable cells, and the classifier head.

    One set of hard architecture weights is drawn per cell type (reduction
    and normal) with ``gumbel_softmax(..., hard=True)`` at temperature
    ``self.tau``. Every cell of a given type reuses that same draw within
    this call.

    Args:
        x: Network input, fed directly to ``self.stem``.

    Returns:
        The result of ``self.classifier`` applied to the pooled output of
        the last cell.
    """
    # Both cell inputs start out as the stem output.
    prev_prev = self.stem(x)
    prev = prev_prev

    # Keep the stem -> reduce draw -> normal draw ordering unchanged, in case
    # these ops consume a shared random stream.
    reduce_weights = gumbel_softmax(self.alphas_reduce, self.tau, hard=True)
    normal_weights = gumbel_softmax(self.alphas_normal, self.tau, hard=True)

    for cell in self.cells:
        if cell.reduction:
            cell_weights = reduce_weights
        else:
            cell_weights = normal_weights
        # Cast to the dtype of the first cell input on every iteration; that
        # input is rebound each step, so the dtype is re-read each time.
        cell_weights = tf.cast(cell_weights, prev_prev.dtype)
        # Each cell consumes the two previous outputs and becomes the newest.
        prev_prev, prev = prev, cell(prev_prev, prev, cell_weights)

    pooled = self.avg_pool(prev)
    return self.classifier(pooled)
def call(self, x):
    """Route ``x`` through one branch selected by a hard Gumbel-softmax draw.

    ``gumbel_softmax`` is sampled from ``self.alpha`` at temperature 1.0
    with ``hard=True`` and ``return_index=True``. The returned index picks,
    via ``tf.switch_case``, which of the prebuilt branch callables runs.

    Args:
        x: Input handed to every branch factory.

    Returns:
        The output of the selected branch.
    """
    sampled_weights, chosen = gumbel_softmax(
        self.alpha, tau=1.0, hard=True, return_index=True
    )
    # NOTE(review): the branch count is fixed at 3; presumably it must equal
    # the number of candidates in self.alpha -- confirm.
    num_branches = 3
    candidates = [
        self._create_branch_fn(x, branch_id, sampled_weights)
        for branch_id in range(num_branches)
    ]
    return tf.switch_case(chosen, candidates)
def call(self, x):
    """Forward pass: stem, alternating normal/reduce stages, pooling, head.

    A single hard Gumbel-softmax draw from ``self.alpha_normal`` (temperature
    ``self.tau``) is shared by all three normal cells. The reduce cells take
    only the activations.

    Args:
        x: Network input, fed to ``self.stem``.

    Returns:
        The output of ``self.fc`` on the pooled features.
    """
    arch_weights, arch_index = gumbel_softmax(
        self.alpha_normal, tau=self.tau, hard=True, return_index=True
    )

    features = self.stem(x)
    # normal1 -> reduce1 -> normal2 -> reduce2, then a final normal3.
    for normal_cell, reduce_cell in (
        (self.normal1, self.reduce1),
        (self.normal2, self.reduce2),
    ):
        features = normal_cell(features, arch_weights, arch_index)
        features = reduce_cell(features)
    features = self.normal3(features, arch_weights, arch_index)

    pooled = self.avg_pool(features)
    return self.fc(pooled)
def step_fn(weights):
    """Sample hard Gumbel-softmax weights plus the selected index.

    Args:
        weights: Passed as the first argument to ``gumbel_softmax``.

    Returns:
        A 2-tuple ``(hard_weights, index)`` unpacked from
        ``gumbel_softmax(..., return_index=True)``.
    """
    # Second positional is the temperature (1.0). The third (True) is
    # presumably ``hard``, matching the keyword calls elsewhere -- confirm
    # against gumbel_softmax's signature.
    sampled = gumbel_softmax(weights, 1.0, True, return_index=True)
    hard_weights, index = sampled
    return (hard_weights, index)