Example #1
def main():
    '''Main function'''

    parser = argparse.ArgumentParser()

    # Learning
    parser.add_argument(
        '-lr', type=float, default=2.5e-4
    )  # Learning rate: 2.5e-4 for Friends and EmotionPush, 1e-4 for IEMOCAP
    parser.add_argument('-decay', type=float,
                        default=math.pow(0.5,
                                         1 / 40))  # halve lr every 40 epochs
    parser.add_argument('-epochs', type=int, default=200)  # Default: 200 epochs
    parser.add_argument(
        '-patience',
        type=int,
        default=10,  # Patience of early stopping 10 epochs
        help='patience for early stopping')
    parser.add_argument(
        '-save_dir',
        type=str,
        default="../../data/higru_bert_data/models",  # Where the models and results are saved
        help='where to save the models')
    # Data
    parser.add_argument('-dataset',
                        type=str,
                        default='Teaching0',
                        help='dataset')
    parser.add_argument('-data_path',
                        type=str,
                        required=True,
                        help='data path')
    parser.add_argument('-vocab_path',
                        type=str,
                        required=True,
                        help='vocabulary path')
    parser.add_argument('-emodict_path',
                        type=str,
                        required=True,
                        help='emotion label dict path')
    parser.add_argument('-tr_emodict_path',
                        type=str,
                        default=None,
                        help='training set emodict path')
    parser.add_argument(
        '-max_seq_len',
        type=int,
        default=80,  # Pad each utterance to 80 tokens
        help='the sequence length')
    # model
    parser.add_argument(
        '-label_type',
        type=str,
        default='coarse',
        help='type of labels used, i.e. coarse/fine/resistance')

    parser.add_argument(
        '-type',
        type=str,
        default='higru',  # Model type: default HiGRU
        help='choose the low encoder')
    parser.add_argument(
        '-d_word_vec',
        type=int,
        default=300,  # Embeddings size 300
        help='the word embeddings size')
    parser.add_argument(
        '-d_h1',
        type=int,
        default=300,  # Lower-level RNN hidden state size 300
        help='the hidden size of rnn1')
    parser.add_argument(
        '-d_h2',
        type=int,
        default=300,  # Upper-level RNN hidden state size 300
        help='the hidden size of rnn2')
    parser.add_argument(
        '-d_fc',
        type=int,
        default=100,  # FC size 100
        help='the size of fc')
    parser.add_argument(
        '-gpu',
        type=str,
        default=None,  # Specify the GPU for training
        help='which gpu to use, e.g. 0; default None')
    parser.add_argument(
        '-embedding',
        type=str,
        default=None,  # Stored embedding path
        help='filename of embedding pickle')
    parser.add_argument(
        '-report_loss',
        type=int,
        default=720,  # Report loss interval, default the number of dialogues
        help='how many steps to report loss')
    parser.add_argument(
        '-bert',
        type=int,
        default=0,  # Whether to include BERT (0 or 1)
        help='include bert or not')

    # parser.add_argument('-mask', type=str, default='all',	# Choice of mask for ER, EE, or all
    # 					help='include mask type')

    # parser.add_argument('-alpha', type=float, default=0.9,	# proportion of the loss , 0.9 means the loss for the Face acts
    # 					help='include mask type')

    # parser.add_argument('-interpret', type=str, default='single_loss', # combined trainable loss,
    # 					help ='name of the file to be saved')

    # parser.add_argument('-ldm', type=int, default=1, help = 'how many last utterances used for the donor loss contribution') # last donor mask
    # parser.add_argument('-don_model', type=int, default=1, help = 'how to compute the donation probability') # last donor mask

    # parser.add_argument('-thresh_reg', type=float, default=0.0, help = 'how to choose threshold for the models 2 and 3') # last donor mask

    parser.add_argument('-bert_train',
                        type=int,
                        default=0,
                        help='choose 0 or 1')  # whether to fine-tune BERT

    parser.add_argument('-addn_features',
                        type=str,
                        default='all',
                        help='which additional features to use (default: all)'
                        )  # features to be used for training

    parser.add_argument('-seed', type=int, default=100, help='set random seed')

    args = parser.parse_args()
    print(args, '\n')
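    # Illustrative invocation (hypothetical script name and paths; required
    # flags only, everything else uses its default):
    #   python train.py -data_path data/field.pt -vocab_path data/vocab.pt \
    #       -emodict_path data/emodict.pt -type bert-higru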

    seed_everything(args.seed)
    feature_dim_dict = {}  # fallback so the lookup below never hits an undefined name
    if 'resisting' in args.dataset:
        # feature_dim_dict = {'vad_features': 3, 'affect_features': 4, 'emo_features': 10, 'liwc_features': 64, 'sentiment_features': 3, 'face_features': 8, 'norm_er_strategies': 10, 'norm_er_DAs': 17, 'ee_DAs': 23, 'all': 3+4+10+64+3+8+10+17+23}
        feature_dim_dict = {
            'vad_features': 3,
            'affect_features': 4,
            'emo_features': 10,
            'liwc_features': 64,
            'sentiment_features': 3,
            'all': 3 + 4 + 10 + 64 + 3
        }

    elif 'negotiation' in args.dataset:
        feature_dim_dict = {
            'vad_features': 3,
            'affect_features': 4,
            'emo_features': 10,
            'liwc_features': 64,
            'sentiment_features': 3,
            'all': 3 + 4 + 10 + 64 + 3
        }

    feature_dim = 0
    if args.addn_features in feature_dim_dict:
        feature_dim = feature_dim_dict[args.addn_features]
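    # feature_dim stays 0 when -addn_features names no known feature set, so
    # the models below then receive no auxiliary feature input.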

    # Load vocabs
    print("Loading vocabulary...")
    worddict = Utils.loadFrPickle(args.vocab_path)
    print("Loading emotion label dict...")
    emodict = Utils.loadFrPickle(args.emodict_path)
    print("Loading review tr_emodict...")
    tr_emodict = Utils.loadFrPickle(args.tr_emodict_path)

    # Load data field
    print("Loading field...")
    field = Utils.loadFrPickle(args.data_path)

    test_loader = field['test']

    trainable = False
    if args.bert_train == 1:
        trainable = True

    # Initialize word embeddings
    print("Initializing word embeddings...")
    embedding = nn.Embedding(worddict.n_words,
                             args.d_word_vec,
                             padding_idx=Const.PAD)
    if args.d_word_vec == 300:
        if args.embedding is not None and os.path.isfile(args.embedding):
            print("Loading previously saved embeddings")
            np_embedding = Utils.loadFrPickle(args.embedding)
        else:
            np_embedding = Utils.load_pretrain(args.d_word_vec,
                                               worddict,
                                               type='glove')
            if args.embedding is not None:  # cache only when a path was given
                Utils.saveToPickle(args.embedding, np_embedding)
        embedding.weight.data.copy_(torch.from_numpy(np_embedding))
    embedding.weight.requires_grad = trainable
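    # Embeddings stay frozen unless -bert_train 1 set trainable=True above.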
    # Choose the model
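    # Both recognized prefixes, 'bert-' and 'only-', are five characters, so
    # args.type[5:] in the branches below strips the prefix and passes the
    # bare encoder name to the model.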
    if args.type.startswith('bert-higru-basic') or args.type.startswith(
            'only-higru-basic'):
        print("Training the bert basic model")
        model = BERT_HiGRU_basic(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-cnn') or args.type.startswith('only-cnn'):
        print("Training the bert cnn model")
        model = BERT_CNN(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-higru-base') or args.type.startswith(
            'only-higru-base'):
        print("Training the higru baseline model")
        model = BERT_HiGRU_base(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('combo-f'):
        print("Training the combo model")
        model_bin = combo_bin(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type + "_bin",
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
        model_multi = combo_multi(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type + "_multi",
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('combo'):
        print("Training the combo model")
        model_bin = combo_bin(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type + "_bin",
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
        model_multi = combo_multi(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type + "_multi",
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-lstm'):
        print("Word-level BiGRU baseline")
        model = BERT_LSTM(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-sent-conn-mask-mid'):
        print("Training sentence-based masking with mid connect")
        model = BERT_HiGRU_sent_conn_mask_mid(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-higru-sent-attn-mask'):
        print("Training sentence-based masked attention")
        model = BERT_HiGRU_sent_attn_mask(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-higru-sent-conn-mask-turn-noenc'):
        print(
            "Training sentence-based masked connections intraturn without enc")
        model = BERT_HiGRU_sent_conn_mask_turn_noenc(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-higru-sent-conn-mask-interturn'):
        print("Training sentence-based masked connections with enc inter turn")
        model = BERT_HiGRU_sent_conn_mask_interturn(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)
    elif args.type.startswith('bert-higru-sent-conn-mask-turn'):
        print("Training sentence-based masked connections intraturn")
        model = BERT_HiGRU_sent_conn_mask_turn(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-sent-conn-mask'):
        print("Training sentence-based masked connections")
        model = BERT_HiGRU_sent_conn_mask(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-sent-attn-2'):
        print("Training sentence-based attention second")
        model = BERT_HiGRU_sent_attn_2(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-sent-attn'):
        print("Training sentence-based attention")
        model = BERT_HiGRU_sent_attn(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-uttr-attn-2'):
        print("Training utterance-based attention double level")
        model = BERT_HiGRU_uttr_attn_2(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-uttr-attn-3'):
        print("Training utterance-based attention second level only")
        model = BERT_HiGRU_uttr_attn_3(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru-uttr-attn'):
        print("Training utterance-based attention")
        model = BERT_HiGRU_uttr_attn(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('bert-higru') or args.type.startswith(
            'only-higru'):
        model = BERT_HiGRU(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:],
            # bert_flag= args.bert,
            # don_model= args.don_model,
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)
        #speaker_flag= args.sf)

    elif args.type.startswith('higru'):
        model = HiGRU(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type,
            # bert= args.bert,
            # don_model= args.don_model
            feature_dim=feature_dim,
            long_bert=args.bert)

    elif args.type.startswith('bert-bigru'):
        model = BERT_BiGRU(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type[5:].replace('bigru', 'higru'),
            # bert_flag= args.bert,
            # don_model= args.don_model
            trainable=trainable,
            feature_dim=feature_dim,
            long_bert=args.bert)

    elif args.type.startswith('bigru'):
        model = BiGRU(
            d_word_vec=args.d_word_vec,
            d_h1=args.d_h1,
            d_h2=args.d_h2,
            d_fc=args.d_fc,
            emodict=emodict,
            worddict=worddict,
            embedding=embedding,
            type=args.type.replace('bigru', 'higru'),
            # bert= args.bert,
            # don_model= args.don_model
            feature_dim=feature_dim,
            long_bert=args.bert)

    # elif args.type.startswith('bert-gru'):
    # 	model = BERT_BiGRU(d_word_vec=args.d_word_vec,
    # 				  d_h1=args.d_h1,
    # 				  d_h2=args.d_h2,
    # 				  d_fc=args.d_fc,
    # 				  emodict=emodict,
    # 				  worddict=worddict,
    # 				  embedding=embedding,
    # 				  type=args.type[5:].replace('gru','higru'),
    # 				  bert_flag= args.bert,
    # 				  don_model= args.don_model)

    # Choose focused emotions

    focus_emo = []

    # Train the model
    if args.type.startswith('combo'):
        emotrain_combo(model_bin=model_bin,
                       model_multi=model_multi,
                       data_loader=field,
                       tr_emodict=tr_emodict,
                       emodict=emodict,
                       args=args,
                       focus_emo=focus_emo)
    else:
        emotrain(model=model,
                 data_loader=field,
                 tr_emodict=tr_emodict,
                 emodict=emodict,
                 args=args,
                 focus_emo=focus_emo)

    # Load the best model to test
    print("Load best models for testing!")

    file_str = Utils.return_file_path(args)
    # model = model.load_state_dict(args.save_dir+'/'+file_str+'.pt', map_location='cpu')
    if args.type.startswith('combo'):
        model_bin = torch.load(args.save_dir + '/' + file_str +
                               '_model_bin.pt',
                               map_location='cpu')
        model_multi = torch.load(args.save_dir + '/' + file_str +
                                 '_model_multi.pt',
                                 map_location='cpu')
        pAccs, acc, mf1 = emoeval_combo(model_bin=model_bin,
                                        model_multi=model_multi,
                                        data_loader=test_loader,
                                        tr_emodict=tr_emodict,
                                        emodict=emodict,
                                        args=args,
                                        focus_emo=focus_emo)
    else:
        model = torch.load(args.save_dir + '/' + file_str + '_model.pt',
                           map_location='cpu')
        # model = torch.load_state_dict(args.save_dir+'/'+file_str+'.pt', map_location='cpu')
        pAccs, acc, mf1 = emoeval(model=model,
                                  data_loader=test_loader,
                                  tr_emodict=tr_emodict,
                                  emodict=emodict,
                                  args=args,
                                  focus_emo=focus_emo)

    print("Test: ACCs-WA-UWA {}".format(pAccs))
    print("Accuracy = {}, F1 = {}".format(acc, mf1))

    # Save the test results
    record_file = '{}/{}_{}.txt'.format(args.save_dir, args.type,
                                        args.label_type)
    if os.path.isfile(record_file):
        f_rec = open(record_file, "a")
    else:
        f_rec = open(record_file, "w")

    f_rec.write("{} - {} - {}\t:\t{}\n".format(datetime.now(), args.d_h1,
                                               args.lr, pAccs))
    f_rec.close()
Example #2
def main():
    '''Main function'''

    parser = argparse.ArgumentParser()

    # learning
    parser.add_argument('-lr', type=float, default=2e-4)
    parser.add_argument('-decay', type=float, default=0.75)
    parser.add_argument('-batch_size', type=int, default=16)
    parser.add_argument('-epochs', type=int, default=60)
    parser.add_argument('-patience',
                        type=int,
                        default=5,
                        help='patience for early stopping')
    parser.add_argument('-save_dir',
                        type=str,
                        default="snapshot",
                        help='where to save the models')
    # data
    parser.add_argument('-dataset',
                        type=str,
                        default='Friends',
                        help='dataset')
    parser.add_argument('-data_path',
                        type=str,
                        required=True,
                        help='data path')
    parser.add_argument('-vocab_path',
                        type=str,
                        required=True,
                        help='global vocabulary path')
    parser.add_argument('-emodict_path',
                        type=str,
                        required=True,
                        help='emotion label dict path')
    parser.add_argument('-tr_emodict_path',
                        type=str,
                        default=None,
                        help='training set emodict path')
    parser.add_argument(
        '-max_seq_len',
        type=int,
        default=60,  # 60 for emotion
        help='the sequence length')
    # model
    parser.add_argument('-sentEnc',
                        type=str,
                        default='gru2',
                        help='choose the low encoder')
    parser.add_argument('-contEnc',
                        type=str,
                        default='gru',
                        help='choose the mid encoder')
    parser.add_argument('-dec',
                        type=str,
                        default='dec',
                        help='choose the classifier')
    parser.add_argument('-d_word_vec',
                        type=int,
                        default=300,
                        help='the word embeddings size')
    parser.add_argument('-d_hidden_low',
                        type=int,
                        default=300,
                        help='the hidden size of rnn1')
    parser.add_argument('-d_hidden_up',
                        type=int,
                        default=300,
                        help='the hidden size of rnn2')
    parser.add_argument('-layers',
                        type=int,
                        default=1,
                        help='the num of stacked GRU layers')
    parser.add_argument('-d_fc', type=int, default=100, help='the size of fc')
    parser.add_argument('-gpu', type=str, default=None, help='which gpu to use; default None')
    parser.add_argument('-embedding',
                        type=str,
                        default=None,
                        help='filename of embedding pickle')
    parser.add_argument('-report_loss',
                        type=int,
                        default=720,
                        help='how many steps to report loss')
    parser.add_argument('-load_model',
                        action='store_true',
                        help='load the pretrained model')

    args = parser.parse_args()
    print(args, '\n')
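    # Illustrative invocation (hypothetical script name and paths; -load_model
    # pulls in the pretrained OpSub encoders from snapshot/):
    #   python train_finetune.py -data_path data/field.pt -vocab_path data/vocab.pt \
    #       -emodict_path data/emodict.pt -load_model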

    # load vocabs
    print("Loading vocabulary...")
    glob_vocab = Utils.loadFrPickle(args.vocab_path)
    print("Loading emotion label dict...")
    emodict = Utils.loadFrPickle(args.emodict_path)
    print("Loading review tr_emodict...")
    tr_emodict = Utils.loadFrPickle(args.tr_emodict_path)

    # load field
    print("Loading field...")
    field = Utils.loadFrPickle(args.data_path)
    test_loader = field['test']

    # word embedding
    print("Initializing word embeddings...")
    embedding = nn.Embedding(glob_vocab.n_words,
                             args.d_word_vec,
                             padding_idx=Const.PAD)
    if args.d_word_vec == 300:
        if args.embedding is not None and os.path.isfile(args.embedding):
            np_embedding = Utils.loadFrPickle(args.embedding)
        else:
            np_embedding = Utils.load_pretrain(args.d_word_vec,
                                               glob_vocab,
                                               type='glove')
            Utils.saveToPickle("embedding.pt", np_embedding)
        embedding.weight.data.copy_(torch.from_numpy(np_embedding))
    embedding.max_norm = 1.0
    embedding.norm_type = 2.0
    embedding.weight.requires_grad = False
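    # max_norm=1.0 with norm_type=2.0 makes nn.Embedding renormalize any row
    # whose L2 norm exceeds 1 at lookup time; the weights themselves stay frozen.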

    # word to vec
    wordenc = Modules.wordEncoder(embedding=embedding)
    # sent to vec
    sentenc = Modules.sentEncoder(d_input=args.d_word_vec,
                                  d_output=args.d_hidden_low)
    if args.sentEnc == 'gru2':
        print("Utterance encoder: GRU2")
        sentenc = Modules.sentGRUEncoder(d_input=args.d_word_vec,
                                         d_output=args.d_hidden_low)
    if args.layers == 2:
        print("Number of stacked GRU layers: {}".format(args.layers))
        sentenc = Modules.sentGRU2LEncoder(d_input=args.d_word_vec,
                                           d_output=args.d_hidden_low)
    # cont
    contenc = Modules.contEncoder(d_input=args.d_hidden_low,
                                  d_output=args.d_hidden_up)
    # decoder
    emodec = Modules.mlpDecoder(d_input=args.d_hidden_low +
                                args.d_hidden_up * 2,
                                d_output=args.d_fc,
                                n_class=emodict.n_words)
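    # The decoder input width d_hidden_low + 2 * d_hidden_up suggests the
    # utterance feature is concatenated with both directions of a
    # bidirectional context encoder (an assumption about Modules.contEncoder).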

    if args.load_model:
        print('Load in pretrained model...')
        wordenc = torch.load("snapshot/wordenc_OpSub_" +
                             str(args.d_hidden_low) + "_" +
                             str(args.d_hidden_up) + ".pt",
                             map_location='cpu')
        sentenc = torch.load("snapshot/sentenc_OpSub_" +
                             str(args.d_hidden_low) + "_" +
                             str(args.d_hidden_up) + ".pt",
                             map_location='cpu')
        contenc = torch.load("snapshot/contenc_OpSub_" +
                             str(args.d_hidden_low) + "_" +
                             str(args.d_hidden_up) + ".pt",
                             map_location='cpu')
        # freeze the pretrained word-encoder parameters
        for p1 in wordenc.parameters():
            p1.requires_grad = False

    # Choose focused emotions
    focus_emo = Const.four_emo
    args.decay = 0.75
    if args.dataset == 'IEMOCAP4v2':
        focus_emo = Const.four_iem
        args.decay = 0.95
    if args.dataset == 'MELD':
        focus_emo = Const.sev_meld
    if args.dataset == 'EmoryNLP':
        focus_emo = Const.sev_emory
    if args.dataset == 'MOSEI':
        focus_emo = Const.six_mosei
    if args.dataset == 'MOSI':
        focus_emo = Const.two_mosi
    print("Focused emotion labels {}".format(focus_emo))

    emotrain(wordenc=wordenc,
             sentenc=sentenc,
             contenc=contenc,
             dec=emodec,
             data_loader=field,
             tr_emodict=tr_emodict,
             emodict=emodict,
             args=args,
             focus_emo=focus_emo)

    # test
    print("Load best models for testing!")

    wordenc = Utils.revmodel_loader(args.save_dir, 'wordenc', args.dataset,
                                    args.load_model)
    sentenc = Utils.revmodel_loader(args.save_dir, 'sentenc', args.dataset,
                                    args.load_model)
    contenc = Utils.revmodel_loader(args.save_dir, 'contenc', args.dataset,
                                    args.load_model)
    emodec = Utils.revmodel_loader(args.save_dir, 'dec', args.dataset,
                                   args.load_model)
    pAccs = emoeval(wordenc=wordenc,
                    sentenc=sentenc,
                    contenc=contenc,
                    dec=emodec,
                    data_loader=test_loader,
                    tr_emodict=tr_emodict,
                    emodict=emodict,
                    args=args,
                    focus_emo=focus_emo)
    print("Test: ACCs-F1s-WA-UWA-F1-val {}".format(pAccs))

    # record the test results
    record_file = '{}/{}_{}_finetune?{}.txt'.format(args.save_dir, "record",
                                                    args.dataset,
                                                    str(args.load_model))
    if os.path.isfile(record_file):
        f_rec = open(record_file, "a")
    else:
        f_rec = open(record_file, "w")
    f_rec.write("{} - {} - {}\t:\t{}\n".format(datetime.now(),
                                               args.d_hidden_low, args.lr,
                                               pAccs))
    f_rec.close()
Example #3
def main():
	'''Main function'''

	parser = argparse.ArgumentParser()

	# Learning
	parser.add_argument('-lr', type=float, default=2.5e-4)		# Learning rate: 2.5e-4 for Friends and EmotionPush, 1e-4 for IEMOCAP
	parser.add_argument('-decay', type=float, default=math.pow(0.5, 1/20))	# half lr every 20 epochs
	parser.add_argument('-epochs', type=int, default=200)		# Default: 200 epochs
	parser.add_argument('-patience', type=int, default=10,		# Patience of early stopping 10 epochs
	                    help='patience for early stopping')
	parser.add_argument('-save_dir', type=str, default="snapshot",	# Save the model and results in snapshot/
	                    help='where to save the models')
	# Data
	parser.add_argument('-dataset', type=str, default='Friends',	# Default dataset Friends
	                    help='dataset')
	parser.add_argument('-data_path', type=str, required=True,
	                    help='data path')
	parser.add_argument('-vocab_path', type=str, required=True,
	                    help='vocabulary path')
	parser.add_argument('-emodict_path', type=str, required=True,
	                    help='emotion label dict path')
	parser.add_argument('-tr_emodict_path', type=str, default=None,
	                    help='training set emodict path')
	parser.add_argument('-max_seq_len', type=int, default=80,	# Pad each utterance to 80 tokens
	                    help='the sequence length')
	# model
	parser.add_argument('-type', type=str, default='higru', 	# Model type: default HiGRU 
	                    help='choose the low encoder')
	parser.add_argument('-d_word_vec', type=int, default=300,	# Embeddings size 300
	                    help='the word embeddings size')
	parser.add_argument('-d_h1', type=int, default=300,		# Lower-level RNN hidden state size 300
	                    help='the hidden size of rnn1')
	parser.add_argument('-d_h2', type=int, default=300,		# Upper-level RNN hidden state size 300
	                    help='the hidden size of rnn2')
	parser.add_argument('-d_fc', type=int, default=100,		# FC size 100
	                    help='the size of fc')
	parser.add_argument('-gpu', type=str, default=None,		# Specify the GPU for training
	                    help='which gpu to use; default None')
	parser.add_argument('-embedding', type=str, default=None,	# Stored embedding path
	                    help='filename of embedding pickle')
	parser.add_argument('-report_loss', type=int, default=720,	# Report loss interval, default the number of dialogues
	                    help='how many steps to report loss')

	args = parser.parse_args()
	print(args, '\n')
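	# Illustrative invocation (hypothetical script name and paths):
	#   python higru_main.py -data_path data/field.pt -vocab_path data/vocab.pt \
	#       -emodict_path data/emodict.pt -dataset Friends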

	# Load vocabs
	print("Loading vocabulary...")
	worddict = Utils.loadFrPickle(args.vocab_path)
	print("Loading emotion label dict...")
	emodict = Utils.loadFrPickle(args.emodict_path)
	print("Loading review tr_emodict...")
	tr_emodict = Utils.loadFrPickle(args.tr_emodict_path)

	# Load data field
	print("Loading field...")
	field = Utils.loadFrPickle(args.data_path)
	test_loader = field['test']

	# Initialize word embeddings
	print("Initializing word embeddings...")
	embedding = nn.Embedding(worddict.n_words, args.d_word_vec, padding_idx=Const.PAD)
	if args.d_word_vec == 300:
		if args.embedding is not None and os.path.isfile(args.embedding):
			np_embedding = Utils.loadFrPickle(args.embedding)
		else:
			np_embedding = Utils.load_pretrain(args.d_word_vec, worddict, type='word2vec')
			Utils.saveToPickle(args.dataset + '_embedding.pt', np_embedding)
		embedding.weight.data.copy_(torch.from_numpy(np_embedding))
	embedding.weight.requires_grad = False

	# Choose the model
	model = HiGRU(d_word_vec=args.d_word_vec,
	              d_h1=args.d_h1,
	              d_h2=args.d_h2,
	              d_fc=args.d_fc,
	              emodict=emodict,
	              worddict=worddict,
	              embedding=embedding,
	              type=args.type)

	# Choose focused emotions
	focus_emo = Const.four_emo
	if args.dataset == 'IEMOCAP':
		focus_emo = Const.four_iem
	print("Focused emotion labels {}".format(focus_emo))

	# Train the model
	emotrain(model=model,
	         data_loader=field,
	         tr_emodict=tr_emodict,
	         emodict=emodict,
	         args=args,
	         focus_emo=focus_emo)

	# Load the best model to test
	print("Load best models for testing!")
	model = Utils.model_loader(args.save_dir, args.type, args.dataset)
	pAccs = emoeval(model=model,
	                data_loader=test_loader,
	                tr_emodict=tr_emodict,
	                emodict=emodict,
	                args=args,
	                focus_emo=focus_emo)
	print("Test: ACCs-WA-UWA {}".format(pAccs))

	# Save the test results
	record_file = '{}/{}_{}.txt'.format(args.save_dir, args.type, args.dataset)
	if os.path.isfile(record_file):
		f_rec = open(record_file, "a")
	else:
		f_rec = open(record_file, "w")
	f_rec.write("{} - {} - {}\t:\t{}\n".format(datetime.now(), args.d_h1, args.lr, pAccs))
	f_rec.close()
Example #4
def main():
    '''Main function'''

    parser = argparse.ArgumentParser()

    # learning
    parser.add_argument(
        '-lr', type=float,
        default=5e-4)  # 2.5e-4 for Friends and EmotionPush, 1e-4 for IEMOCAP
    parser.add_argument('-decay', type=float, default=0.95)
    parser.add_argument('-maxnorm', type=float, default=5)
    parser.add_argument('-epochs', type=int, default=60)
    parser.add_argument('-batch_size', type=int, default=1)
    parser.add_argument('-patience',
                        type=int,
                        default=10,
                        help='patience for early stopping')
    parser.add_argument('-save_dir',
                        type=str,
                        default="snapshot",
                        help='where to save the models')
    # data
    parser.add_argument('-dataset',
                        type=str,
                        default='Friends',
                        help='dataset')
    parser.add_argument('-data_path',
                        type=str,
                        required=True,
                        help='data path')
    parser.add_argument('-vocab_path',
                        type=str,
                        required=True,
                        help='global vocabulary path')
    parser.add_argument('-emodict_path',
                        type=str,
                        required=True,
                        help='emotion label dict path')
    parser.add_argument('-tr_emodict_path',
                        type=str,
                        default=None,
                        help='training set emodict path')
    parser.add_argument('-max_seq_len',
                        type=int,
                        default=60,
                        help='the sequence length')
    # model
    parser.add_argument('-type',
                        type=str,
                        default='gru',
                        help='the model type: gru')
    parser.add_argument('-d_word_vec',
                        type=int,
                        default=300,
                        help='the word embeddings size')
    parser.add_argument('-d_h1',
                        type=int,
                        default=100,
                        help='the hidden size of rnn1')
    parser.add_argument('-d_h2',
                        type=int,
                        default=100,
                        help='the hidden size of rnn2')
    parser.add_argument('-hops',
                        type=int,
                        default=1,
                        help='the number of hops')
    parser.add_argument('-d_fc', type=int, default=100, help='the size of fc')
    parser.add_argument('-wind1',
                        type=int,
                        default=40,
                        help='the word-level context window')
    parser.add_argument('-gpu', type=str, default=None, help='which gpu to use; default None')
    parser.add_argument('-embedding',
                        type=str,
                        default=None,
                        help='filename of embedding pickle')
    parser.add_argument('-report_loss',
                        type=int,
                        default=720,
                        help='how many steps to report loss')

    args = parser.parse_args()
    print(args, '\n')
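    # Illustrative invocation (hypothetical script name and paths):
    #   python agru_main.py -data_path data/field.pt -vocab_path data/vocab.pt \
    #       -emodict_path data/emodict.pt -type BiF_AGRU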

    # load vocabs
    print("Loading vocabulary...")
    worddict = Utils.loadFrPickle(args.vocab_path)
    print("Loading emotion label dict...")
    emodict = Utils.loadFrPickle(args.emodict_path)
    print("Loading review tr_emodict...")
    tr_emodict = Utils.loadFrPickle(args.tr_emodict_path)

    # load field
    print("Loading field...")
    field = Utils.loadFrPickle(args.data_path)
    test_loader = field['test']

    # word embedding
    print("Initializing word embeddings...")
    embedding = nn.Embedding(worddict.n_words,
                             args.d_word_vec,
                             padding_idx=Const.PAD)
    if args.d_word_vec == 300:
        if args.embedding is not None and os.path.isfile(args.embedding):
            np_embedding = Utils.loadFrPickle(args.embedding)
        else:
            np_embedding = Utils.load_pretrain(args.d_word_vec,
                                               worddict,
                                               type='word2vec')
            Utils.saveToPickle(args.dataset + '_embedding.pt', np_embedding)
        embedding.weight.data.copy_(torch.from_numpy(np_embedding))
    embedding.weight.requires_grad = False

    # model type
    model = BiGRU(emodict=emodict,
                  worddict=worddict,
                  embedding=embedding,
                  args=args)
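    # The BiGRU above is the fallback; each branch below replaces it when
    # args.type names that model, so an unrecognized -type silently falls
    # back to BiGRU.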
    if args.type in ['BiGRU']:
        print("Model type confirmed: BiGRU")
        model = BiGRU(emodict=emodict,
                      worddict=worddict,
                      embedding=embedding,
                      args=args)
    if args.type in ['CNN']:
        print("Model type confirmed: CNN")
        model = CNN(emodict=emodict,
                    worddict=worddict,
                    embedding=embedding,
                    args=args)
    if args.type in ['cLSTM']:
        print("Model type confirmed: cLSTM")
        model = cLSTM(emodict=emodict,
                      worddict=worddict,
                      embedding=embedding,
                      args=args)
    if args.type in ['CMN']:
        print("Model type confirmed: CMN")
        model = CMN(emodict=emodict,
                    worddict=worddict,
                    embedding=embedding,
                    args=args)

    if args.type in ['UniF_Att']:
        print("Model type confirmed: UniF_Att")
        model = UniF_Att(emodict=emodict,
                         worddict=worddict,
                         embedding=embedding,
                         args=args)
    if args.type in ['BiF_Att']:
        print("Model type confirmed: BiF_Att")
        model = BiF_Att(emodict=emodict,
                        worddict=worddict,
                        embedding=embedding,
                        args=args)
    if args.type in ['UniF_AGRU']:
        print("Model type confirmed: UniF_AGRU")
        model = UniF_AGRU(emodict=emodict,
                          worddict=worddict,
                          embedding=embedding,
                          args=args)
    if args.type in ['UniF_BiAGRU']:
        print("Model type confirmed: UniF_BiAGRU")
        model = UniF_BiAGRU(emodict=emodict,
                            worddict=worddict,
                            embedding=embedding,
                            args=args)
    if args.type in ['BiF_AGRU']:
        print("Model type confirmed: BiF_AGRU")
        model = BiF_AGRU(emodict=emodict,
                         worddict=worddict,
                         embedding=embedding,
                         args=args)
    if args.type in ['BiF_BiAGRU']:
        print("Model type confirmed: BiF_BiAGRU")
        model = BiF_BiAGRU(emodict=emodict,
                           worddict=worddict,
                           embedding=embedding,
                           args=args)

    # test utterance reader
    if args.type in ['UniF_BiAGRU_CNN']:
        print("Model type confirmed: UniF_BiAGRU_CNN")
        model = UniF_BiAGRU_CNN(emodict=emodict,
                                worddict=worddict,
                                embedding=embedding,
                                args=args)
    if args.type in ['BiF_AGRU_CNN']:
        print("Model type confirmed: BiF_AGRU_CNN")
        model = BiF_AGRU_CNN(emodict=emodict,
                             worddict=worddict,
                             embedding=embedding,
                             args=args)
    if args.type in ['UniF_BiAGRU_LSTM']:
        print("Model type confirmed: UniF_BiAGRU_LSTM")
        model = UniF_BiAGRU_LSTM(emodict=emodict,
                                 worddict=worddict,
                                 embedding=embedding,
                                 args=args)
    if args.type in ['BiF_AGRU_LSTM']:
        print("Model type confirmed: BiF_AGRU_LSTM")
        # BiF_AGRU_LSTM is assumed to exist as the counterpart of
        # UniF_BiAGRU_LSTM above (not verified in this snippet).
        model = BiF_AGRU_LSTM(emodict=emodict,
                              worddict=worddict,
                              embedding=embedding,
                              args=args)
    if args.type in ['UniF_Att_CNN']:
        print("Model type confirmed: UniF_Att_CNN")
        model = UniF_Att_CNN(emodict=emodict,
                             worddict=worddict,
                             embedding=embedding,
                             args=args)
    if args.type in ['BiF_Att_CNN']:
        print("Model type confirmed: BiF_Att_CNN")
        model = BiF_Att_CNN(emodict=emodict,
                            worddict=worddict,
                            embedding=embedding,
                            args=args)

    # Choose focused emotions
    focus_emo = Const.five_emo
    args.decay = 0.95
    if args.dataset == 'IEMOCAP6':
        focus_emo = Const.six_iem
        args.decay = 0.95
    if args.dataset == 'MELD':
        focus_emo = Const.sev_meld
        args.decay = 0.95
    print("Focused emotion labels {}".format(focus_emo))

    emotrain(model=model,
             data_loader=field,
             tr_emodict=tr_emodict,
             emodict=emodict,
             args=args,
             focus_emo=focus_emo)

    # test
    print("Load best models for testing!")

    model = Utils.model_loader(args.save_dir, args.type, args.dataset)
    Recall, Precision, F1, Avgs, Val_loss = emoeval(model=model,
                                                    data_loader=test_loader,
                                                    tr_emodict=tr_emodict,
                                                    emodict=emodict,
                                                    args=args,
                                                    focus_emo=focus_emo)
    print("Test: val_loss {}\n re {}\n pr {}\n F1 {}\n Av {}\n".format(
        Val_loss, Recall, Precision, F1, Avgs))

    # record the test results
    record_file = '{}/{}_{}.txt'.format(args.save_dir, args.type, args.dataset)
    if os.path.isfile(record_file):
        f_rec = open(record_file, "a")
    else:
        f_rec = open(record_file, "w")
    f_rec.write(
        "{} - {}-{} - {}:\t\n \tre {}\n \tpr {}\n \tF1 {}\n \tAv {}\n\n".
        format(datetime.now(), args.hops, args.wind1, args.lr, Recall,
               Precision, F1, Avgs))
    f_rec.close()