Code example #1
    def __init__(self, args, vocab):
        '''
        Seq2Seq agent.

        On top of the base agent this adds a high-level vocabulary over the
        submodule names, a shared bi-LSTM language encoder with two
        self-attention heads (submodules / controller), a modular
        frame-mask decoder, dropout layers and unreduced reconstruction
        losses.

        args  -- hyperparameter namespace (demb, dhid, dropout rates, loss
                 weights, ...)
        vocab -- dict of Vocab objects; a 'high_level' entry is added here
        '''
        super().__init__(args, vocab)

        self.args = args

        # Add high level vocab built from the submodule names.
        # NOTE(review): assumes self.submodule_names is provided by the base
        # class constructor -- confirm.
        self.vocab['high_level'] = Vocab()
        self.vocab['high_level'].word2index(self.submodule_names, train=True)

        # encoder and self-attention for starting state modules and high-level controller.
        self.enc = nn.LSTM(args.demb,
                           args.dhid,
                           bidirectional=True,
                           batch_first=True)
        self.enc_att = nn.ModuleList([
            vnn.SelfAttn(args.dhid * 2) for _ in range(2)
        ])  # One for submodules and one for controller.

        # subgoal monitoring: enabled when either auxiliary loss has a
        # positive weight.
        self.subgoal_monitoring = (self.args.pm_aux_loss_wt > 0
                                   or self.args.subgoal_aux_loss_wt > 0)

        # frame mask decoder (progress-monitoring variant when enabled)
        decoder = vnn.ConvFrameMaskDecoderProgressMonitor if self.subgoal_monitoring else vnn.ConvFrameMaskDecoderModular
        self.dec = decoder(self.emb_action_low,
                           args.dframe,
                           2 * args.dhid,
                           pframe=args.pframe,
                           attn_dropout=args.attn_dropout,
                           hstate_dropout=args.hstate_dropout,
                           actor_dropout=args.actor_dropout,
                           input_dropout=args.input_dropout,
                           teacher_forcing=args.dec_teacher_forcing)

        # dropouts: visual features, language encodings (in place), decoder input
        self.vis_dropout = nn.Dropout(args.vis_dropout)
        self.lang_dropout = nn.Dropout(args.lang_dropout, inplace=True)
        self.input_dropout = nn.Dropout(args.input_dropout)

        # internal per-step decoding states, cleared by reset()
        self.state_t = None
        self.e_t = None
        self.test_mode = False

        # unreduced losses for mask reconstruction / auxiliary regression
        self.bce_with_logits = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss(reduction='none')

        # paths
        self.root_path = os.getcwd()

        # reset model to a clean internal state
        self.reset()
Code example #2
File: seq2seq_im_mask.py — project: kolbytn/alfred
    def __init__(self, args, vocab, manager=None):
        '''
        Seq2Seq agent.

        Builds a bidirectional LSTM language encoder with self-attention
        pooling, a convolutional frame-mask decoder, dropout layers and
        unreduced reconstruction losses.
        '''
        super().__init__(args, vocab, manager)

        # Language encoder: bi-LSTM over word embeddings, followed by
        # self-attention pooling over the 2*dhid hidden states.
        self.enc = nn.LSTM(args.demb,
                           args.dhid,
                           bidirectional=True,
                           batch_first=True)
        self.enc_att = vnn.SelfAttn(2 * args.dhid)

        # Subgoal/progress monitoring is on when either auxiliary loss has
        # a positive weight.
        monitoring = (self.args.pm_aux_loss_wt > 0
                      or self.args.subgoal_aux_loss_wt > 0)
        self.subgoal_monitoring = monitoring

        # Decoder variant depends on whether monitoring is enabled.
        if monitoring:
            decoder_cls = vnn.ConvFrameMaskDecoderProgressMonitor
        else:
            decoder_cls = vnn.ConvFrameMaskDecoder
        self.dec = decoder_cls(self.emb_action_low,
                               args.dframe,
                               2 * args.dhid,
                               pframe=args.pframe,
                               attn_dropout=args.attn_dropout,
                               hstate_dropout=args.hstate_dropout,
                               actor_dropout=args.actor_dropout,
                               input_dropout=args.input_dropout,
                               teacher_forcing=args.dec_teacher_forcing)

        # Dropouts for visual features, language encodings (in place) and
        # decoder inputs.
        self.vis_dropout = nn.Dropout(args.vis_dropout)
        self.lang_dropout = nn.Dropout(args.lang_dropout, inplace=True)
        self.input_dropout = nn.Dropout(args.input_dropout)

        # Per-step decoding state, cleared by reset().
        self.state_t = None
        self.e_t = None
        self.test_mode = False

        # Unreduced losses for mask reconstruction / auxiliary regression.
        self.bce_with_logits = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss(reduction='none')

        # Filesystem locations.
        self.root_path = os.getcwd()
        self.feat_pt = 'feat_conv.pt'

        # Cap on the number of subgoals per task.
        self.max_subgoals = 25

        # Start from a clean internal state.
        self.reset()
Code example #3
    def __init__(self, args, vocab):
        '''
        Seq2Seq agent.

        Language is encoded either by a pretrained BERT model or by a pair
        of bidirectional LSTMs (goal vs. instructions), each followed by
        self-attention pooling; a convolutional frame-mask decoder produces
        low-level actions. Also sets up dropouts, losses and a ResNet.
        '''
        super().__init__(args, vocab)

        # encoder and self-attention
        if args.use_bert:
            # Pretrained transformer shared by goal and instruction text.
            self.bert = AutoModel.from_pretrained(args.bert_model)
            self.max_length = args.max_length
            bert_config = AutoConfig.from_pretrained(args.bert_model)

            # Overwrite dhid so downstream modules sized as 2*dhid line up
            # with BERT's hidden size.
            args.dhid = bert_config.hidden_size // 2
        else:
            # Separate bi-LSTMs for the goal and the step instructions.
            self.enc_goal = nn.LSTM(args.demb,
                                    args.dhid,
                                    bidirectional=True,
                                    batch_first=True)
            self.enc_instr = nn.LSTM(args.demb,
                                     args.dhid,
                                     bidirectional=True,
                                     batch_first=True)
        self.enc_att_goal = vnn.SelfAttn(2 * args.dhid)
        self.enc_att_instr = vnn.SelfAttn(2 * args.dhid)

        # Subgoal/progress monitoring is on when either auxiliary loss has
        # a positive weight.
        monitoring = (self.args.pm_aux_loss_wt > 0
                      or self.args.subgoal_aux_loss_wt > 0)
        self.subgoal_monitoring = monitoring

        # Decoder variant depends on whether monitoring is enabled.
        if monitoring:
            decoder_cls = vnn.ConvFrameMaskDecoderProgressMonitor
        else:
            decoder_cls = vnn.ConvFrameMaskDecoder
        self.dec = decoder_cls(self.emb_action_low,
                               args.dframe,
                               2 * args.dhid,
                               pframe=args.pframe,
                               attn_dropout=args.attn_dropout,
                               hstate_dropout=args.hstate_dropout,
                               actor_dropout=args.actor_dropout,
                               input_dropout=args.input_dropout,
                               teacher_forcing=args.dec_teacher_forcing)
        self.use_bert = args.use_bert

        # Dropouts for visual features, language encodings (in place) and
        # decoder inputs.
        self.vis_dropout = nn.Dropout(args.vis_dropout)
        self.lang_dropout = nn.Dropout(args.lang_dropout, inplace=True)
        self.input_dropout = nn.Dropout(args.input_dropout)

        # Per-step decoding state, cleared by reset().
        self.state_t = None
        self.e_t = None
        self.test_mode = False

        # Losses: unreduced BCE/MSE for reconstruction, mean CE for
        # classification.
        self.bce_with_logits = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss(reduction='none')
        self.ce_loss = torch.nn.CrossEntropyLoss()

        # Filesystem locations.
        self.root_path = os.getcwd()
        self.feat_pt = 'feat_conv.pt'

        # Cap on the number of subgoals per task.
        self.max_subgoals = 25

        # Start from a clean internal state.
        self.reset()

        # ResNet-18 backbone -- presumably used to featurize frames; verify
        # against callers.
        args.visual_model = 'resnet18'
        self.resnet = Resnet(args)
Code example #4
    def __init__(self, args, vocab):
        '''
        Modular Seq2Seq agent.

        Builds one Seq2SeqIM network per submodule, a high-level action
        vocabulary and embedding, a bi-LSTM language encoder with
        self-attention, and a convolutional frame decoder (no masks) over
        high-level actions.
        '''
        super().__init__(args, vocab)

        # Subgoal/progress monitoring is on when either auxiliary loss has
        # a positive weight.
        monitoring = (self.args.pm_aux_loss_wt > 0
                      or self.args.subgoal_aux_loss_wt > 0)
        self.subgoal_monitoring = monitoring

        # One independent network per submodule.
        # NOTE(review): 8 submodules vs. 10 names below (PAD/NoOp presumably
        # have no network) -- confirm the intended correspondence.
        self.submodules = nn.ModuleList(
            [Seq2SeqIM(args, vocab) for _ in range(8)])

        # High-level action names; position in the list is the index.
        self.submodule_names = [
            'PAD', 'GotoLocation', 'PickupObject', 'PutObject', 'CoolObject',
            'HeatObject', 'CleanObject', 'SliceObject', 'ToggleObject', 'NoOp'
        ]

        self.high_vocab = Vocab(self.submodule_names)

        # Embedding table for high-level actions.
        self.emb_action_high = nn.Embedding(len(self.high_vocab), args.demb)

        # Index of the terminating high-level action.
        self.stop_token = self.high_vocab.word2index("NoOp", train=False)

        # Per-step decoding state, cleared by reset().
        self.state_t = None
        self.e_t = None
        self.test_mode = False

        # Language encoder: bi-LSTM plus self-attention pooling.
        self.enc = nn.LSTM(args.demb,
                           args.dhid,
                           bidirectional=True,
                           batch_first=True)
        self.enc_att = vnn.SelfAttn(2 * args.dhid)

        # Frame decoder (no masks); variant depends on monitoring.
        if monitoring:
            decoder_cls = vnn.ConvFrameDecoderProgressMonitor
        else:
            decoder_cls = vnn.ConvFrameDecoder

        self.dec = decoder_cls(self.emb_action_high,
                               args.dframe,
                               2 * args.dhid,
                               pframe=args.pframe,
                               attn_dropout=args.attn_dropout,
                               hstate_dropout=args.hstate_dropout,
                               actor_dropout=args.actor_dropout,
                               input_dropout=args.input_dropout,
                               teacher_forcing=args.dec_teacher_forcing)

        # Dropouts for visual features, language encodings (in place) and
        # decoder inputs.
        self.vis_dropout = nn.Dropout(args.vis_dropout)
        self.lang_dropout = nn.Dropout(args.lang_dropout, inplace=True)
        self.input_dropout = nn.Dropout(args.input_dropout)

        # Unreduced losses for reconstruction / auxiliary regression.
        self.bce_with_logits = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.mse_loss = torch.nn.MSELoss(reduction='none')

        # Filesystem locations.
        self.root_path = os.getcwd()
        self.feat_pt = 'feat_conv.pt'

        # Cap on the number of subgoals per task.
        self.max_subgoals = 25

        # Start from a clean internal state.
        self.reset()