def prior_sharing(task_outputs, num_modules):
    """Build a network whose first two conv layers are common to all tasks.

    The first two conv layers are each a single module wrapped in
    ``IgnoreTaskRouting``. The next two conv layers and the output decoder are
    wrapped in ``StaticTaskRouting``, each with one freshly built module per
    entry of ``task_outputs``.

    Args:
        task_outputs: Sequence with one entry per task; entry ``s`` is passed
            as the output size to that task's ``decoder(53, s)``.
        num_modules: Not used by this builder (presumably accepted so every
            sharing builder has the same signature).

    Returns:
        An ``nn.Sequential`` containing the five routed layers in order.
    """
    width = 53  # channel width of every hidden conv layer
    num_tasks = len(task_outputs)

    def per_task_conv():
        # One independent width->width conv per task.
        return StaticTaskRouting(
            num_tasks, [conv_layer(width, width) for _ in task_outputs])

    # Built in the same order as the layers appear in the network.
    shared = [
        IgnoreTaskRouting(conv_layer(1, width)),
        IgnoreTaskRouting(conv_layer(width, width)),
    ]
    task_specific = [
        per_task_conv(),
        per_task_conv(),
        StaticTaskRouting(num_tasks, [decoder(width, s) for s in task_outputs]),
    ]
    return nn.Sequential(*shared, *task_specific)
def prior_sharing(task_outputs, num_modules):
    """Build a conv + dense network whose first two conv layers are common.

    Layer layout, in order:
      1-2. two single conv layers (1->32, 32->32) wrapped in IgnoreTaskRouting;
      3.   one 32->32 conv per task (StaticTaskRouting);
      4.   a single Flatten wrapped in IgnoreTaskRouting;
      5.   one dense_layer(288, 128) per task (StaticTaskRouting);
      6.   one nn.Linear(128, s) per task, where s is that task's output size.

    Args:
        task_outputs: Sequence with one entry per task; entry ``s`` is the
            output width of that task's final ``nn.Linear``.
        num_modules: Not used by this builder (presumably accepted so every
            sharing builder has the same signature).

    Returns:
        An ``nn.Sequential`` of the six routed layers.
    """
    num_tasks = len(task_outputs)

    def per_task(build):
        # Call ``build`` once per task (with that task's output size) and
        # route between the resulting modules statically.
        return StaticTaskRouting(num_tasks, [build(s) for s in task_outputs])

    return nn.Sequential(
        IgnoreTaskRouting(conv_layer(1, 32)),
        IgnoreTaskRouting(conv_layer(32, 32)),
        per_task(lambda _size: conv_layer(32, 32)),
        IgnoreTaskRouting(Flatten()),
        # NOTE(review): 288 presumably equals 32 channels * 3 * 3 spatial
        # positions after the conv stack; it depends on the input size and
        # on conv_layer's downsampling -- confirm before changing inputs.
        per_task(lambda _size: dense_layer(288, 128)),
        per_task(lambda size: nn.Linear(128, size)),
    )
def make_wrapped_block(self, channels, stride):
    """Build a block for ``channels``/``stride`` and wrap it with task routing.

    For 64- or 128-channel blocks a single block is built and wrapped in
    ``IgnoreTaskRouting``. For any other width, ``self.num_tasks`` separate
    blocks are built and wrapped in ``StaticTaskRouting``.

    Args:
        channels: Channel count forwarded to ``self.make_block``.
        stride: Stride forwarded to ``self.make_block``.

    Returns:
        The routing wrapper around the built block(s).
    """
    shared_widths = (64, 128)
    if channels in shared_widths:
        return IgnoreTaskRouting(self.make_block(channels, stride))

    per_task_blocks = [
        self.make_block(channels, stride) for _ in range(self.num_tasks)
    ]
    return StaticTaskRouting(self.num_tasks, per_task_blocks)
def learned_sharing(task_outputs, num_modules):
    """Build a network whose four conv layers use learned task routing.

    Each of the four conv layers is a pool of ``num_modules`` conv modules
    wrapped in ``LearnedTaskRouting``. The final decoder is one module per
    task wrapped in ``StaticTaskRouting``.

    Args:
        task_outputs: Sequence with one entry per task; entry ``s`` is passed
            as the output size to that task's ``decoder(53, s)``.
        num_modules: Number of candidate conv modules in each routed layer.

    Returns:
        An ``nn.Sequential`` of the five routed layers.
    """
    width = 53  # channel width of every hidden conv layer
    num_tasks = len(task_outputs)

    def routed_conv(in_ch, out_ch):
        # Pool of num_modules interchangeable convs; LearnedTaskRouting
        # presumably learns which module(s) each task uses.
        pool = [conv_layer(in_ch, out_ch) for _ in range(num_modules)]
        return LearnedTaskRouting(num_tasks, pool)

    return nn.Sequential(
        routed_conv(1, width),
        routed_conv(width, width),
        routed_conv(width, width),
        routed_conv(width, width),
        StaticTaskRouting(num_tasks, [decoder(width, s) for s in task_outputs]),
    )
def __init__(self, task_outputs, num_modules):
    """Build the stem, four ``make_layer`` stages, pooling, and task heads.

    Args:
        task_outputs: Sequence with one entry per task; entry ``s`` is the
            output width of that task's ``nn.Linear(512, s)`` head.
        num_modules: Stored as ``self.num_modules``; not otherwise used here.
    """
    super().__init__()
    self.num_tasks = len(task_outputs)
    self.num_modules = num_modules
    # Starting input channel count for the stages; presumably read and
    # advanced by make_layer as it builds blocks -- confirm in make_layer.
    self.ch_in = 64

    # Stem: 3x3 conv (3 -> 64 channels, no bias) + BatchNorm + ReLU, as a
    # single module wrapped in IgnoreTaskRouting.
    stem_conv = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
    stem = nn.Sequential(stem_conv, nn.BatchNorm2d(64), nn.ReLU())
    self.conv1 = IgnoreTaskRouting(stem)

    # Four stages; the second argument (2) is presumably the block count per
    # stage. Construction order is kept fixed so parameter initialisation
    # and registration order stay stable.
    self.layer1 = self.make_layer(64, 2, stride=1)
    self.layer2 = self.make_layer(128, 2, stride=2)
    self.layer3 = self.make_layer(256, 2, stride=2)
    self.layer4 = self.make_layer(512, 2, stride=2)

    # NOTE(review): the 512-input heads below presumably assume AvgPool2d(4)
    # reduces the final feature map to 1x1 (e.g. 32x32 inputs) -- confirm.
    pooling = nn.AvgPool2d(4)
    self.pool_flatten = IgnoreTaskRouting(nn.Sequential(pooling, Flatten()))

    heads = [nn.Linear(512, s) for s in task_outputs]
    self.fc = StaticTaskRouting(self.num_tasks, heads)
def make_wrapped_block(self, channels, stride):
    """Build one block per task and wrap them in ``StaticTaskRouting``.

    Args:
        channels: Channel count forwarded to ``self.make_block``.
        stride: Stride forwarded to ``self.make_block``.

    Returns:
        A ``StaticTaskRouting`` over ``self.num_tasks`` independently built
        blocks.
    """
    task_count = self.num_tasks
    blocks = [self.make_block(channels, stride) for _ in range(task_count)]
    return StaticTaskRouting(task_count, blocks)