parser.add_argument("--finetune_batch_size", help="batch size when finetune", type=int, default=10)
parser.add_argument("--finetune_lr", help="learning rate when finetune", type=float, default=0.05)
parser.add_argument("--finetune_wd", help="weight decay rate when finetune", type=float, default=1e-5)
parser.add_argument("--finetune_weight", help="loss weight of negative samples", type=float, default=0.2)

# inference batch size
parser.add_argument("--infer_batch_size", help="batch size for inference", type=int, default=0)

# logging and evaluation flags
parser.add_argument("--print_debug", help="print debug information", action="store_true")
parser.add_argument("--eval", help="run evaluation during the snowball process", action="store_true")

args = parser.parse_args()

max_length = 90

# data loaders (vocab taken from the local bert-base-uncased checkpoint)
train_train_data_loader = DataLoader('./data/train_train.json', vocab='./data/bert-base-uncased/vocab.txt', max_length=max_length)
train_val_data_loader = DataLoader('./data/train_val.json', vocab='./data/bert-base-uncased/vocab.txt', max_length=max_length)
val_data_loader = DataLoader('./data/val.json', vocab='./data/bert-base-uncased/vocab.txt', max_length=max_length)
# NOTE: val.json is reused here as the test split; point this at test.json if a held-out test file is available
test_data_loader = DataLoader('./data/val.json', vocab='./data/bert-base-uncased/vocab.txt', max_length=max_length)
distant = DataLoader('./data/distant.json', vocab='./data/bert-base-uncased/vocab.txt', max_length=max_length, distant=True)

framework = nrekit.framework.Framework(train_val_data_loader, val_data_loader, test_data_loader, distant)
# two separate BERT encoders: one for the snowball classifier, one for the siamese network
sentence_encoder = nrekit.sentence_encoder.BERTSentenceEncoder('./data/bert-base-uncased')
sentence_encoder2 = nrekit.sentence_encoder.BERTSentenceEncoder('./data/bert-base-uncased')

# build the siamese network first, then plug it into the snowball model
model2 = models.snowball.Siamese(sentence_encoder2, hidden_size=768)
model = models.snowball.Snowball(sentence_encoder, base_class=train_train_data_loader.rel_tot, siamese_model=model2, hidden_size=768, neg_loader=train_train_data_loader, args=args)

# load pre-trained weights
checkpoint = torch.load('./checkpoint/bert_encoder_on_fewrel.pth.tar')['state_dict']
checkpoint2 = torch.load('./checkpoint/bert_siamese_on_fewrel.pth.tar')['state_dict']
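
# A minimal sketch of applying the loaded weights; the original fragment
# stops after loading, and this assumes the checkpoint keys match the
# model definitions above.
model.load_state_dict(checkpoint)
model2.load_state_dict(checkpoint2)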


# ---------------------------------------------------------------
# Script 2: supervised pre-training of the BERT siamese network
# ---------------------------------------------------------------
import models
import nrekit
import sys
from torch import optim
from nrekit.data_loader_bert import JSONFileDataLoaderBERT as DataLoader
from pytorch_pretrained_bert import BertAdam

max_length = 90
train_data_loader = DataLoader('./data/train_train.json',
                               vocab='./data/bert_vocab.txt',
                               max_length=max_length)
val_data_loader = DataLoader('./data/train_val.json',
                             vocab='./data/bert_vocab.txt',
                             max_length=max_length,
                             rel2id=train_data_loader.rel2id,  # share relation ids with the training split
                             shuffle=False)

framework = nrekit.framework.SuperviseFramework(train_data_loader,
                                                val_data_loader)
sentence_encoder = nrekit.sentence_encoder.BERTSentenceEncoder(
    './data/bert-base-uncased')
model = models.snowball.Siamese(sentence_encoder,
                                hidden_size=768,
                                drop_rate=0.1)

model_name = 'bert_siamese_euc_on_fewrel'

# set optimizer
batch_size = 32
train_iter = 30000
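
# A minimal sketch of the optimizer setup that the "set optimizer"
# comment above announces, using the BertAdam import; the learning rate
# and warmup proportion below are illustrative assumptions, not values
# from the original script.
parameters_to_optimize = [p for p in model.parameters() if p.requires_grad]
optimizer = BertAdam(parameters_to_optimize, lr=2e-5, warmup=0.1, t_total=train_iter)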


# ---------------------------------------------------------------
# Script 3: neural snowball with pre-computed BERT representations
# ---------------------------------------------------------------
import models
import nrekit
import sys
import torch
from torch import optim
from nrekit.data_loader_bert import JSONFileDataLoaderBERT as DataLoader
import argparse
import numpy as np

max_length = 90
train_train_data_loader = DataLoader('./data/train_train.json', vocab='./data/bert_vocab.txt', max_length=max_length)
train_val_data_loader = DataLoader('./data/train_val.json', vocab='./data/bert_vocab.txt', max_length=max_length)
val_data_loader = DataLoader('./data/val.json', vocab='./data/bert_vocab.txt', max_length=max_length)
test_data_loader = DataLoader('./data/test.json', vocab='./data/bert_vocab.txt', max_length=max_length)
# the distant supervision data is not needed here, so skip the slow loading step
# distant = DataLoader('./data/distant.json', vocab='./data/bert_vocab.txt', max_length=max_length, distant=True)
distant = None

framework = nrekit.framework.Framework(train_val_data_loader, val_data_loader, test_data_loader, distant)
sentence_encoder = nrekit.sentence_encoder.BERTSentenceEncoder('./data/bert-base-uncased')
sentence_encoder2 = nrekit.sentence_encoder.BERTSentenceEncoder('./data/bert-base-uncased')

# pre-computed sentence representations speed up inference; the encoder
# representations are cached, while the siamese ones are computed on the fly
bert_encoder_repre = torch.from_numpy(np.load('./_repre/bert_encoder_on_fewrel.npy')).cuda()
bert_siamese_repre = None
# bert_siamese_repre = torch.from_numpy(np.load('./_repre/bert_siamese_on_fewrel.npy')).cuda()

model2 = models.snowball.Siamese(sentence_encoder2, hidden_size=768, pre_rep=bert_siamese_repre)
model = models.snowball.Snowball(sentence_encoder, base_class=train_train_data_loader.rel_tot, siamese_model=model2, hidden_size=768, neg_loader=train_train_data_loader, pre_rep=bert_encoder_repre)

# load pre-trained weights
checkpoint = torch.load('./checkpoint/bert_encoder_on_fewrel.pth.tar.bak')['state_dict']
checkpoint2 = torch.load('./checkpoint/bert_siamese_on_fewrel.pth.tar.bak')['state_dict']
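
# As in the first script, a hedged sketch of restoring the weights and
# moving the models to the GPU (the cached representations above already
# live there); assumes the .bak state dicts match the model definitions.
model.load_state_dict(checkpoint)
model2.load_state_dict(checkpoint2)
model.cuda()
model2.cuda()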