def __init__(self, args, num_features):
    super().__init__()
    self.args = args
    self.remove_writer_dropout = args.control_remove_writer_dropout
    self.final_ln = (LayerNorm(num_features, elementwise_affine=True)
                     if args.control_add_final_ln else None)
    # Aggregation of the feature grid observed by the controller.
    if args.control_aggregation == 'max':
        self.aggregator = GridMAX(num_features)
    elif args.control_aggregation == 'cell':
        self.aggregator = lambda x: (x, None)
    else:
        raise ValueError('Unknown aggregation for the controller',
                         args.control_aggregation)
    # Stack of feature-halving layers applied before the gate.
    self.net = nn.Sequential()
    for i in range(args.control_num_layers):
        self.net.add_module('lin%d' % i, Linear(num_features, num_features // 2))
        num_features = num_features // 2
    # Oracle:
    self.oracle = SimulTransOracle(args.control_oracle_penalty)
    # Agent: observation >> binary READ/WRITE decision
    self.gate = nn.Linear(num_features, 1, bias=True)
    nn.init.normal_(self.gate.weight, 0, 1 / num_features)
    nn.init.constant_(self.gate.bias, 0)
    self.write_right = args.control_write_right
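
# --- Illustrative sketch (not part of the original module) ---
# A rough, self-contained approximation of how the controller above might turn
# the feature grid into a WRITE probability: max-pool over the source axis
# (assumed to be what GridMAX does), reduce through stacked linear layers, then
# squash the gate output with a sigmoid. The toy `mlp` and `gate` stand in for
# self.net and self.gate; the real forward pass may differ.
import torch
import torch.nn as nn

def toy_write_probability(grid, mlp, gate):
    """grid: (batch, tgt_len, src_len, channels) feature grid."""
    pooled, _ = grid.max(dim=2)          # aggregate over source positions
    hidden = mlp(pooled)                 # stacked feature-reducing layers
    return torch.sigmoid(gate(hidden))   # per-step probability of emitting (WRITE)

if __name__ == '__main__':
    C = 64
    mlp = nn.Sequential(nn.Linear(C, C // 2), nn.ReLU(), nn.Linear(C // 2, C // 4))
    gate = nn.Linear(C // 4, 1)
    p_write = toy_write_probability(torch.randn(2, 5, 7, C), mlp, gate)
    print(p_write.shape)                 # torch.Size([2, 5, 1])
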
def __init__(self, args, dictionary, embed_tokens, left_pad=False):
    super().__init__(dictionary)
    self.share_input_output_embed = args.share_decoder_input_output_embed
    self.decoder_dim = args.decoder_embed_dim
    embed_dim = embed_tokens.embedding_dim
    self.padding_idx = embed_tokens.padding_idx
    self.max_target_positions = args.max_target_positions
    self.embed_tokens = embed_tokens
    self.embed_scale = math.sqrt(embed_dim)
    self.embed_positions = PositionalEmbedding(
        args.max_target_positions, embed_dim, self.padding_idx,
        left_pad=args.left_pad_target,
        learned=args.learned_pos,
    ) if args.add_positional_embeddings else None
    self.embedding_dropout = nn.Dropout(args.embeddings_dropout)
    # The ConvNet reads the concatenation of encoder and decoder features.
    self.input_channels = args.encoder_embed_dim + args.decoder_embed_dim
    self.output_dim = args.output_dim
    print('Input channels:', self.input_channels)
    if args.network == 'resnet_addup_nonorm2':
        self.net = ResNetAddUpNoNorm2(self.input_channels, args)
    elif args.network == 'resnet_addup_nonorm':
        self.net = ResNetAddUpNoNorm(self.input_channels, args)
    else:
        raise ValueError('Unknown architecture %s' % args.network)
    self.output_channels = self.net.output_channels
    self.aggregator = GridMAX(self.output_channels)
    print('Decoder dim:', self.decoder_dim)
    print('The ConvNet output channels:', self.output_channels)
    print('Required output dim:', self.output_dim)
    # Map the ConvNet output to the required dimension unless it can be skipped.
    if self.output_dim != self.output_channels or not args.skip_output_mapping:
        self.projection = Linear(
            self.output_channels, self.output_dim,
            dropout=args.prediction_dropout,
        )
    else:
        self.projection = None
    self.prediction_dropout = nn.Dropout(args.prediction_dropout)
    # Vocabulary projection, optionally tied to the input embeddings.
    if self.share_input_output_embed:
        self.prediction = Linear(self.decoder_dim, len(dictionary))
        self.prediction.weight = self.embed_tokens.weight
    else:
        self.prediction = Linear(self.output_dim, len(dictionary))
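
# --- Illustrative sketch (not part of the original module) ---
# The constructor above sets input_channels = encoder_embed_dim +
# decoder_embed_dim, which suggests a pervasive-attention style 2D grid: every
# (target step, source step) cell holds the concatenation of the decoder
# embedding at that target step and the encoder state at that source step.
# This standalone toy shows that grid construction only; the real forward pass
# of the decoder may differ in layout, masking, and dropout.
import torch

def build_joint_grid(encoder_states, decoder_embeddings):
    """encoder_states: (B, Ts, Cs); decoder_embeddings: (B, Tt, Ct).
    Returns a (B, Tt, Ts, Cs + Ct) grid to be fed to the ConvNet."""
    B, Ts, Cs = encoder_states.shape
    _, Tt, Ct = decoder_embeddings.shape
    src = encoder_states.unsqueeze(1).expand(B, Tt, Ts, Cs)
    tgt = decoder_embeddings.unsqueeze(2).expand(B, Tt, Ts, Ct)
    return torch.cat([src, tgt], dim=-1)

if __name__ == '__main__':
    grid = build_joint_grid(torch.randn(2, 7, 512), torch.randn(2, 5, 512))
    print(grid.shape)   # torch.Size([2, 5, 7, 1024])
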
def __init__(self, args, dictionary, embed_tokens):
    super().__init__(dictionary)
    self.share_input_output_embed = args.share_decoder_input_output_embed
    self.decoder_dim = args.decoder_embed_dim
    embed_dim = embed_tokens.embedding_dim
    self.padding_idx = embed_tokens.padding_idx
    self.max_target_positions = args.max_target_positions
    self.embed_tokens = embed_tokens
    self.embed_scale = math.sqrt(embed_dim)
    self.embed_positions = PositionalEmbedding(
        args.max_target_positions, embed_dim, self.padding_idx,
        left_pad=args.left_pad_target,  # False
        learned=args.learned_pos,
    ) if args.add_positional_embeddings else None
    self.ln = lambda x: x
    if args.embeddings_ln:
        self.ln = nn.LayerNorm(embed_dim, elementwise_affine=True)
    self.embedding_dropout = nn.Dropout(args.embeddings_dropout)
    self.input_dropout = nn.Dropout(args.input_dropout)
    self.input_channels = args.encoder_embed_dim + args.decoder_embed_dim
    self.output_dim = args.output_dim
    networks = {
        'resnet': ResNet,
        'resnet2': ResNet2,
        'dilated_resnet': DilatedResnet,
        'dilated_resnet2': DilatedResnet2,
        'expanding_resnet': ExpandingResNet,
        'fav_resnet': FavResNet,
        'resnet3': ResNet3,
        'resnet4': ResNet4,
        'resnet5': ResNet5,
        'resnet6': ResNet6,
        'resnet_renorm': ResNetReNorm,
        'resnet_addup': ResNetAddUp,
        'resnet_addup2': ResNetAddUp2,
        'resnet_addup3': ResNetAddUp3,
        'resnet_addup_nonorm': ResNetAddUpNoNorm,
        'resnet_addup_nonorm2': ResNetAddUpNoNorm2,
        'resnet_addup_nonorm2_rev': ResNetAddUpNoNorm2Rev,
        'resnet_addup_nonorm2_wbias': BiasResNetAddUpNoNorm2,
        'resnet_addup_nonorm2_gated': ResNetAddUpNoNorm2Gated,
        'resnet_addup_nonorm2_all': ResNetAddUpNoNorm2All,
        'resnet_addup_nonorm2_gated_noffn': ResNetAddUpNoNorm2GatedNoFFN,
        'resnet_addup_nonorm3': ResNetAddUpNoNorm3,
        'resnet_addup_nonorm4': ResNetAddUpNoNorm4,
        'densenet_ln': DenseNetLN,
        'densenet_ffn': DenseNetFFN,
        'densenet_ffn_pono': DenseNetFFNPONO,
        'densenet_pono': DenseNetPONO,
        'densenet_pono_cascade': DenseNetPONOCascade,
        'densenet_cascade': DenseNetCascade,
        'densenet_pono_kmax': DenseNetPONOKmax,
        'densenet_bn': DenseNetBN,
        'densenet_nonorm': DenseNetNoNorm,
    }
    if args.network not in networks:
        raise ValueError('Unknown architecture %s' % args.network)
    self.net = networks[args.network](self.input_channels, args)

    self.policy = args.waitk_policy
    self.waitk = args.waitk
    self.output_channels = self.net.output_channels
    if args.waitk_policy == 'area':
        if args.aggregation == 'max':
            self.aggregator = GridMAX(self.output_channels)
        elif args.aggregation == 'max2':
            self.aggregator = GridMAX2(self.output_channels)
        elif args.aggregation == 'gated_max':
            self.aggregator = GridGatedMAX(self.output_channels)
        elif args.aggregation == 'attn':
            self.aggregator = GridATTN(self.output_channels)
        else:
            raise ValueError('Unknown aggregation %s' % args.aggregation)
    elif args.waitk_policy == 'path':
        if args.aggregation == 'max':
            self.aggregator = PathMAX(self.output_channels, self.waitk)
        elif args.aggregation == 'gated_max':
            self.aggregator = PathGatedMAX(self.output_channels, self.waitk)
        elif args.aggregation == 'attn':
            self.aggregator = PathATTN(self.output_channels, self.waitk)
        elif args.aggregation == 'cell':
            self.aggregator = PathCell(self.output_channels, self.waitk)
        elif args.aggregation == 'full':
            self.aggregator = PathMAXFull(self.output_channels, self.waitk)
        else:
            raise ValueError('Unknown aggregation %s' % args.aggregation)
    else:
        raise ValueError('Unknown policy %s' % args.waitk_policy)

    print('Decoder dim:', self.decoder_dim)
    print('The ConvNet output channels:', self.output_channels)
    print('Required output dim:', self.output_dim)
    if self.output_dim != self.output_channels or not args.skip_output_mapping:
        self.projection = Linear(
            self.output_channels, self.output_dim,
            dropout=args.prediction_dropout,
        )
    else:
        self.projection = None
    self.prediction_dropout = nn.Dropout(args.prediction_dropout)
    if self.share_input_output_embed:
        self.prediction = Linear(self.decoder_dim, len(dictionary))
        self.prediction.weight = self.embed_tokens.weight
    else:
        self.prediction = Linear(self.output_dim, len(dictionary))
    self.need_attention_weights = args.need_attention_weights
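
# --- Illustrative sketch (not part of the original module) ---
# The 'path' aggregators above are parameterized by waitk, following the usual
# wait-k schedule: when generating target token t (0-based), only the first
# (waitk + t) source positions should be visible. This toy builds the
# corresponding visibility mask over the (target, source) grid; the real Path*
# modules also perform the pooling/attention itself, and their exact offset
# convention may differ.
import torch

def waitk_visibility_mask(tgt_len, src_len, waitk):
    """Boolean (tgt_len, src_len) mask: True where the source cell is visible."""
    steps = torch.arange(tgt_len).unsqueeze(1) + waitk   # allowed context per target step
    src_pos = torch.arange(src_len).unsqueeze(0)
    return src_pos < steps.clamp(max=src_len)

if __name__ == '__main__':
    print(waitk_visibility_mask(4, 6, waitk=2).int())
    # tensor([[1, 1, 0, 0, 0, 0],
    #         [1, 1, 1, 0, 0, 0],
    #         [1, 1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1, 0]])
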