Example 1
    def __init__(self, args, num_features):
        super().__init__()
        self.args = args
        self.remove_writer_dropout = args.control_remove_writer_dropout

        self.final_ln = LayerNorm(num_features, elementwise_affine=True) \
                if args.control_add_final_ln else None

        if args.control_aggregation == 'max':
            self.aggregator = GridMAX(num_features)
        elif args.control_aggregation == 'cell':
            self.aggregator = lambda x: (x, None)
        else:
            raise ValueError('Unknown aggregation for the controller: %s'
                             % args.control_aggregation)

        self.net = nn.Sequential()
        # Stack of linear layers, each halving the feature width.
        for i in range(args.control_num_layers):
            self.net.add_module('lin%d' % i,
                                Linear(num_features, num_features // 2))
            num_features = num_features // 2

        # Oracle:
        self.oracle = SimulTransOracle(args.control_oracle_penalty)

        # Agent : Observation >> Binary R/W decision
        self.gate = nn.Linear(num_features, 1, bias=True)
        nn.init.normal_(self.gate.weight, 0, 1 / num_features)
        nn.init.constant_(self.gate.bias, 0)
        self.write_right = args.control_write_right
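
The gate above maps the aggregated observation to a single logit for the binary read/write decision, but the forward pass is not shown in the snippet. The standalone sketch below is therefore only an assumption about how such a gate is commonly used: a sigmoid over the logit followed by a 0.5 threshold. The ReadWriteGate class name, the threshold, and the decision rule are illustrative; only the gate layer and its initialization mirror the code above.

# Minimal, self-contained sketch (assumptions: module name, threshold, and the
# sigmoid-over-logit decision rule; only the gate layer and its initialization
# come from the snippet above).
import torch
import torch.nn as nn


class ReadWriteGate(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.gate = nn.Linear(num_features, 1, bias=True)
        nn.init.normal_(self.gate.weight, 0, 1 / num_features)
        nn.init.constant_(self.gate.bias, 0)

    def forward(self, observation):
        # observation: (batch, num_features) aggregated grid features.
        logit = self.gate(observation).squeeze(-1)   # (batch,)
        p_write = torch.sigmoid(logit)               # probability of WRITE
        return p_write > 0.5                         # True = write, False = read


gate = ReadWriteGate(num_features=64)
decision = gate(torch.randn(2, 64))
print(decision)  # e.g. tensor([ True, False])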
Example 2
    def __init__(self, args, dictionary, embed_tokens, left_pad=False):
        super().__init__(dictionary)
        self.share_input_output_embed = args.share_decoder_input_output_embed

        self.decoder_dim = args.decoder_embed_dim
        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, 
            embed_dim, self.padding_idx,
            left_pad=args.left_pad_target,
            learned=args.learned_pos,
        ) if args.add_positional_embeddings else None

        self.embedding_dropout = nn.Dropout(args.embeddings_dropout)
        self.input_channels = args.encoder_embed_dim + args.decoder_embed_dim
        self.output_dim = args.output_dim

        print('Input channels:', self.input_channels)
        if args.network == 'resnet_addup_nonorm2':
            self.net = ResNetAddUpNoNorm2(self.input_channels, args)
        elif args.network == 'resnet_addup_nonorm':
            self.net = ResNetAddUpNoNorm(self.input_channels, args)
        else:
            raise ValueError('Unknown architecture %s' % args.network)

        self.output_channels = self.net.output_channels
        
        self.aggregator = GridMAX(self.output_channels)
        
        print('Decoder dim:', self.decoder_dim)
        print('The ConvNet output channels:', self.output_channels)
        print('Required output dim:', self.output_dim)

        if self.output_dim != self.output_channels or not args.skip_output_mapping:
            self.projection = Linear(
                self.output_channels,
                self.output_dim,
                dropout=args.prediction_dropout
            )
        else:
            self.projection = None
        self.prediction_dropout = nn.Dropout(args.prediction_dropout)
        if self.share_input_output_embed:
            self.prediction = Linear(
                self.decoder_dim,
                len(dictionary)
            )
            self.prediction.weight = self.embed_tokens.weight
        else:
            self.prediction = Linear(
                self.output_dim,
                len(dictionary)
            )
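
When share_decoder_input_output_embed is set, the code above ties the output projection to the token embedding matrix, so the final Linear has shape (vocab, decoder_dim) and reuses the embedding weights instead of allocating its own. Below is a minimal standalone sketch of this weight-tying pattern using plain torch.nn modules rather than the repository's Linear helper; the variable names and sizes are illustrative.

import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 128

embed_tokens = nn.Embedding(vocab_size, embed_dim)          # weight: (vocab, dim)
prediction = nn.Linear(embed_dim, vocab_size, bias=False)   # weight: (vocab, dim)

# Tie the parameters: both modules now share one (vocab, dim) matrix, so the
# output projection adds no extra parameters of its own.
prediction.weight = embed_tokens.weight

tokens = torch.randint(0, vocab_size, (2, 5))
hidden = embed_tokens(tokens)   # (2, 5, dim), stand-in for decoder states
logits = prediction(hidden)     # (2, 5, vocab)
assert prediction.weight.data_ptr() == embed_tokens.weight.data_ptr()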
Example 3
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.share_input_output_embed = args.share_decoder_input_output_embed

        self.decoder_dim = args.decoder_embed_dim
        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions,
            embed_dim,
            self.padding_idx,
            left_pad=args.left_pad_target,  # False
            learned=args.learned_pos,
        ) if args.add_positional_embeddings else None
        if args.embeddings_ln:
            self.ln = nn.LayerNorm(embed_dim, elementwise_affine=True)
        else:
            self.ln = nn.Identity()

        self.embedding_dropout = nn.Dropout(args.embeddings_dropout)
        self.input_dropout = nn.Dropout(args.input_dropout)
        self.input_channels = args.encoder_embed_dim + args.decoder_embed_dim
        self.output_dim = args.output_dim

        # Map each architecture name to its constructor; every network is
        # built with the same signature (input_channels, args).
        networks = {
            'resnet': ResNet,
            'resnet2': ResNet2,
            'resnet3': ResNet3,
            'resnet4': ResNet4,
            'resnet5': ResNet5,
            'resnet6': ResNet6,
            'dilated_resnet': DilatedResnet,
            'dilated_resnet2': DilatedResnet2,
            'expanding_resnet': ExpandingResNet,
            'fav_resnet': FavResNet,
            'resnet_renorm': ResNetReNorm,
            'resnet_addup': ResNetAddUp,
            'resnet_addup2': ResNetAddUp2,
            'resnet_addup3': ResNetAddUp3,
            'resnet_addup_nonorm': ResNetAddUpNoNorm,
            'resnet_addup_nonorm2': ResNetAddUpNoNorm2,
            'resnet_addup_nonorm2_rev': ResNetAddUpNoNorm2Rev,
            'resnet_addup_nonorm2_wbias': BiasResNetAddUpNoNorm2,
            'resnet_addup_nonorm2_gated': ResNetAddUpNoNorm2Gated,
            'resnet_addup_nonorm2_gated_noffn': ResNetAddUpNoNorm2GatedNoFFN,
            'resnet_addup_nonorm2_all': ResNetAddUpNoNorm2All,
            'resnet_addup_nonorm3': ResNetAddUpNoNorm3,
            'resnet_addup_nonorm4': ResNetAddUpNoNorm4,
            'densenet_ln': DenseNetLN,
            'densenet_bn': DenseNetBN,
            'densenet_nonorm': DenseNetNoNorm,
            'densenet_ffn': DenseNetFFN,
            'densenet_ffn_pono': DenseNetFFNPONO,
            'densenet_pono': DenseNetPONO,
            'densenet_pono_cascade': DenseNetPONOCascade,
            'densenet_pono_kmax': DenseNetPONOKmax,
            'densenet_cascade': DenseNetCascade,
        }
        if args.network not in networks:
            raise ValueError('Unknown architecture %s' % args.network)
        self.net = networks[args.network](self.input_channels, args)

        self.policy = args.waitk_policy
        self.waitk = args.waitk

        self.output_channels = self.net.output_channels
        if args.waitk_policy == 'area':
            if args.aggregation == 'max':
                self.aggregator = GridMAX(self.output_channels)
            elif args.aggregation == 'max2':
                self.aggregator = GridMAX2(self.output_channels)
            elif args.aggregation == 'gated_max':
                self.aggregator = GridGatedMAX(self.output_channels)
            elif args.aggregation == 'attn':
                self.aggregator = GridATTN(self.output_channels)
            else:
                raise ValueError('Unknown aggregation %s' % args.aggregation)
        elif args.waitk_policy == 'path':
            if args.aggregation == 'max':
                self.aggregator = PathMAX(self.output_channels, self.waitk)
            elif args.aggregation == 'gated_max':
                self.aggregator = PathGatedMAX(self.output_channels,
                                               self.waitk)
            elif args.aggregation == 'attn':
                self.aggregator = PathATTN(self.output_channels, self.waitk)
            elif args.aggregation == 'cell':
                self.aggregator = PathCell(self.output_channels, self.waitk)
            elif args.aggregation == 'full':
                self.aggregator = PathMAXFull(self.output_channels, self.waitk)
            else:
                raise ValueError('Unknown aggregation %s' % args.aggregation)
        else:
            raise ValueError('Unknown policy %s' % args.waitk_policy)

        print('Decoder dim:', self.decoder_dim)
        print('The ConvNet output channels:', self.output_channels)
        print('Required output dim:', self.output_dim)

        if self.output_dim != self.output_channels or not args.skip_output_mapping:
            self.projection = Linear(self.output_channels,
                                     self.output_dim,
                                     dropout=args.prediction_dropout)
        else:
            self.projection = None
        self.prediction_dropout = nn.Dropout(args.prediction_dropout)
        if self.share_input_output_embed:
            self.prediction = Linear(self.decoder_dim, len(dictionary))
            self.prediction.weight = self.embed_tokens.weight
        else:
            self.prediction = Linear(self.output_dim, len(dictionary))
        self.need_attention_weights = args.need_attention_weights
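
Both decoder examples build the same target-embedding pipeline: token embeddings scaled by sqrt(embed_dim), optional positional embeddings, an optional LayerNorm, and embedding dropout. Their forward passes are not shown, so the composition order in the sketch below is an assumption modeled on standard fairseq-style decoders; TargetEmbedder and its plain learned positional table are illustrative stand-ins for the repository's PositionalEmbedding helper.

# Hedged sketch of the target-embedding path implied by the modules initialized
# above (embed_scale, embed_positions, ln, embedding_dropout). The composition
# order is an assumption, not the repository's actual forward().
import math
import torch
import torch.nn as nn


class TargetEmbedder(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_positions, dropout=0.1, use_ln=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim)
        self.embed_scale = math.sqrt(embed_dim)
        # Plain learned positions stand in for the repo's PositionalEmbedding.
        self.embed_positions = nn.Embedding(max_positions, embed_dim)
        self.ln = nn.LayerNorm(embed_dim) if use_ln else nn.Identity()
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens):
        # tokens: (batch, tgt_len)
        positions = torch.arange(tokens.size(1), device=tokens.device)
        x = self.embed_scale * self.embed_tokens(tokens)  # scale by sqrt(d)
        x = x + self.embed_positions(positions)           # add positional signal
        x = self.ln(x)                                    # optional LayerNorm
        return self.dropout(x)


emb = TargetEmbedder(vocab_size=1000, embed_dim=64, max_positions=256)
out = emb(torch.randint(0, 1000, (2, 7)))
print(out.shape)  # torch.Size([2, 7, 64])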
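
Example 3 additionally selects a wait-k policy (args.waitk_policy, args.waitk) that controls how much source context each target step may read. The repository's PathMAX/GridMAX aggregators are not reproduced here; the sketch below only illustrates the generic wait-k reading schedule from the simultaneous translation literature, which is an assumption about what self.waitk parameterizes.

# Generic wait-k schedule: when generating target token t (1-indexed), the
# model may read at most t + k - 1 source tokens. Illustrative only; not the
# repository's PathMAX/GridMAX implementation.
def waitk_context(t, k, src_len):
    """Number of source tokens visible when writing target token t."""
    return min(src_len, t + k - 1)


# With k=3 and a 6-token source, the visible-context sizes per target step are:
print([waitk_context(t, k=3, src_len=6) for t in range(1, 8)])
# [3, 4, 5, 6, 6, 6, 6]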