Example #1
    def __init__(self, params, net_path=None, **kargs):
        super(TrackerSiamRPNBIG, self).__init__(name='SiamRPN',
                                                is_deterministic=True)
        '''setup model'''
        self.net = SiamRPN()
        self.data_loader = TrainDataLoader(self.net, params)
        '''setup GPU device if available'''
        self.cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.cuda else 'cpu')
        if self.cuda:
            self.net.cuda()

        if net_path is not None:
            self.net.load_state_dict(
                torch.load(net_path,
                           map_location=lambda storage, loc: storage))

        self.net.eval()
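
Note: a minimal usage sketch for the tracker above. The contents of params and the
weight path are placeholders, not from the original repo; the name/is_deterministic
keywords suggest a got10k-style Tracker base class.

params = {}  # hyper-parameters consumed by TrainDataLoader (assumed)
tracker = TrackerSiamRPNBIG(params, net_path='weights/siamrpn_big.pth')  # hypothetical path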
Example #2
def train():
    data_loader = TrainDataLoader()
    net = Net(student_n, exer_n, knowledge_n)

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.002)
    print('training model...')

    loss_function = nn.NLLLoss()
    for epoch in range(epoch_n):
        data_loader.reset()
        running_loss = 0.0
        batch_count = 0
        while not data_loader.is_end():
            batch_count += 1
            input_stu_ids, input_exer_ids, input_knowledge_embs, labels = data_loader.next_batch()
            input_stu_ids = input_stu_ids.to(device)
            input_exer_ids = input_exer_ids.to(device)
            input_knowledge_embs = input_knowledge_embs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            output_1 = net(input_stu_ids, input_exer_ids, input_knowledge_embs)
            output_0 = torch.ones_like(output_1) - output_1
            output = torch.cat((output_0, output_1), 1)

            # grad_penalty = 0
            loss = loss_function(torch.log(output), labels)
            loss.backward()
            optimizer.step()
            net.apply_clipper()

            running_loss += loss.item()
            if batch_count % 200 == 199:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, batch_count + 1, running_loss / 200))
                running_loss = 0.0

        # validate and save current model every epoch
        rmse, auc = validate(net, epoch)
        save_snapshot(net, 'model/model_epoch' + str(epoch + 1))
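
Note: the two-column trick in train() above (concatenating 1 - p and p, then feeding
log-probabilities to nn.NLLLoss) is binary cross-entropy written by hand. A standalone
sketch of the equivalence, not part of the original code:

import torch
import torch.nn as nn

p = torch.tensor([[0.9], [0.2]])        # model outputs in (0, 1)
labels = torch.tensor([1, 0])           # class indices
two_col = torch.cat((1 - p, p), dim=1)  # columns: P(y=0), P(y=1)
nll = nn.NLLLoss()(torch.log(two_col), labels)
bce = nn.BCELoss()(p.squeeze(1), labels.float())
assert torch.allclose(nll, bce)         # identical losses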
Example #3
def test(args):
    if args.load_var:
        test_utterances, test_labels, word_dict = read_data(
            load_var=args.load_var, input_=None, mode='test')
    else:
        test_utterances, test_labels, word_dict = read_data(
            load_var=args.load_var,
            input_=os.path.join(constant.data_path, "entangled_{}.json".format(args.mode)),
            mode='test')

    if args.save_input:
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "{}_utterances.pk".format(args.mode)),
            rw='w', input_obj=test_utterances)
        utils.save_or_read_input(
            os.path.join(constant.save_input_path, "{}_labels.pk".format(args.mode)),
            rw='w', input_obj=test_labels)

    current_time = re.findall(r'.*model_(.+?)/.*', args.model_path)[0]
    step_cnt = re.findall(r'.step_(.+?)\.pkl', args.model_path)[0]

    test_dataloader = TrainDataLoader(test_utterances,
                                      test_labels,
                                      word_dict,
                                      name='test',
                                      batch_size=4)

    ensemble_model = EnsembleModel(word_dict,
                                   word_emb=None,
                                   bidirectional=False)
    if torch.cuda.is_available():
        ensemble_model.cuda()

    supervised_trainer = SupervisedTrainer(args,
                                           ensemble_model,
                                           current_time=current_time)

    supervised_trainer.test(test_dataloader,
                            args.model_path,
                            step_cnt=step_cnt)
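
Note: the two re.findall patterns in test() imply a checkpoint layout like
.../model_<timestamp>/...step_<n>.pkl. A hypothetical path that satisfies both:

# args.model_path = 'ckpt/model_2020-01-01/epoch_3.step_1500.pkl'  (illustrative only)
# -> current_time == '2020-01-01', step_cnt == '1500'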
Example #4
def main():
    """ train dataloader """
    args = parser.parse_args()
    data_loader = TrainDataLoader(args.train_path, check=True)
    if not os.path.exists(args.weight_dir):
        os.makedirs(args.weight_dir)
    """ compute max_batches """
    for root, dirs, files in os.walk(args.train_path):
        for dirname in dirs:
            dir_path = os.path.join(root, dirname)
            args.max_batches += len(os.listdir(dir_path))
    inter = args.max_batches // 10
    print('Max batches: {} in one epoch'.format(args.max_batches))
    """ Model on gpu """
    model = SiameseRPN()
    model = model.cuda()
    cudnn.benchmark = True
    """ loss and optimizer """
    criterion = MultiBoxLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    """ load weights """
    init_weights(model)
    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not a valid checkpoint_path'.format(
                args.checkpoint_path)
        try:
            checkpoint = torch.load(args.checkpoint_path)
            start = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception:  # fall back to fresh weights if the checkpoint is unusable
            start = 0
            init_weights(model)
    else:
        start = 0
    """ train phase """
    closses, rlosses, tlosses = AverageMeter(), AverageMeter(), AverageMeter()
    for epoch in range(start, args.max_epoches):
        cur_lr = adjust_learning_rate(args.lr, optimizer, epoch, gamma=0.1)
        index_list = range(len(data_loader))
        #for example in range(args.max_batches):
        for example in range(900):
            ret = data_loader.__get__(random.choice(index_list))
            template = ret['template_tensor'].cuda()
            detection = ret['detection_tensor'].cuda()
            pos_neg_diff = (ret['pos_neg_diff_tensor'].cuda()
                            if ret['pos_neg_diff_tensor'] is not None else None)

            cout, rout = model(template, detection)

            predictions = (cout, rout)
            targets = pos_neg_diff

            if pos_neg_diff is None:
                continue
            area = ret['area_target_in_resized_detection']
            num_pos = int((pos_neg_diff == 1).sum().item())  # count of positive anchors
            if area == 0 or num_pos == 0:
                continue

            closs, rloss, loss, reg_pred, reg_target, pos_index, neg_index = criterion(
                predictions, targets)

            # debug for class
            cout = cout.squeeze().permute(1, 2, 0).reshape(-1, 2)
            cout = cout.cpu().detach().numpy()
            print(cout.shape)
            score = 1 / (1 + np.exp(cout[:, 0] - cout[:, 1]))
            print(score[pos_index])
            print(score[neg_index])
            #time.sleep(1)

            # debug for reg
            tmp_dir = '/home/song/srpn/tmp/visualization/7_train_debug_pos_anchors'
            if not os.path.exists(tmp_dir):
                os.makedirs(tmp_dir)
            detection = ret['detection_cropped_resized'].copy()
            draw = ImageDraw.Draw(detection)
            pos_anchors = ret['pos_anchors'].copy()

            # regression of the pos anchors (predicted boxes)
            x = pos_anchors[:, 0] + pos_anchors[:, 2] * reg_pred[
                pos_index, 0].cpu().detach().numpy()
            y = pos_anchors[:, 1] + pos_anchors[:, 3] * reg_pred[
                pos_index, 1].cpu().detach().numpy()
            w = pos_anchors[:, 2] * np.exp(reg_pred[pos_index,
                                                    2].cpu().detach().numpy())
            h = pos_anchors[:, 3] * np.exp(reg_pred[pos_index,
                                                    3].cpu().detach().numpy())
            x1s, y1s, x2s, y2s = x - w // 2, y - h // 2, x + w // 2, y + h // 2
            for i in range(2):
                x1, y1, x2, y2 = x1s[i], y1s[i], x2s[i], y2s[i]
                draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)],
                          width=1,
                          fill='red')  #predict

            # the ground truth the pos anchors should regress to
            x = pos_anchors[:, 0] + pos_anchors[:, 2] * reg_target[
                pos_index, 0].cpu().detach().numpy()
            y = pos_anchors[:, 1] + pos_anchors[:, 3] * reg_target[
                pos_index, 1].cpu().detach().numpy()
            w = pos_anchors[:, 2] * np.exp(
                reg_target[pos_index, 2].cpu().detach().numpy())
            h = pos_anchors[:, 3] * np.exp(
                reg_target[pos_index, 3].cpu().detach().numpy())
            x1s, y1s, x2s, y2s = x - w // 2, y - h // 2, x + w // 2, y + h // 2
            for i in range(2):
                x1, y1, x2, y2 = x1s[i], y1s[i], x2s[i], y2s[i]
                draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)],
                          width=1,
                          fill='green')  #gt

            # pick the top-scoring anchors
            m_indexs = np.argsort(score)[::-1][:5]
            for m_index in m_indexs:
                diff = reg_pred[m_index].cpu().detach().numpy()
                anc = ret['anchors'][m_index]
                x = anc[0] + anc[0] * diff[0]
                y = anc[1] + anc[1] * diff[1]
                w = anc[2] * np.exp(diff[2])
                h = anc[3] * np.exp(diff[3])
                x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
                draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)],
                          width=2,
                          fill='black')

            save_path = osp.join(
                tmp_dir,
                'epoch_{:04d}_{:04d}_{:02d}.jpg'.format(epoch, example, i))
            detection.save(save_path)

            closs_ = closs.cpu().item()
            if np.isnan(closs_):
                sys.exit(0)

            #loss = closs + rloss
            closses.update(closs.cpu().item())
            rlosses.update(rloss.cpu().item())
            tlosses.update(loss.cpu().item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #time.sleep(1)

            print(
                "Epoch:{:04d}\texample:{:08d}/{:08d}({:.2f})\tlr:{:.7f}\tcloss:{:.4f}\trloss:{:.4f}\ttloss:{:.4f}"
                .format(epoch, example + 1, args.max_batches,
                        100 * (example + 1) / args.max_batches, cur_lr,
                        closses.avg, rlosses.avg, tlosses.avg))

        if epoch % 5 == 0:
            file_path = os.path.join(
                args.weight_dir, 'epoch_{:04d}_weights.pth.tar'.format(epoch))
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, file_path)
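
Note: several of these examples rely on an AverageMeter helper that is never shown.
A minimal sketch of what they appear to assume (latest value plus running average):

class AverageMeter(object):
    """Tracks the most recent value and the running average."""

    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count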
Example #5
    def init(self, image, box):
        """ dataloader """
        self.data_loader = TrainDataLoader(image, box)
        self.box = box
Example #6
def main():
    args = parser.parse_args()

    """ train dataloader """
    data_loader = TrainDataLoader("C:\\Users\\sport\\Desktop\\SiamMask-Pytorch\\DAVIS-4\\JPEGImages\\480p")

    print('-')

    """ compute max_batches """
    for root, dirs, files in os.walk(args.train_path):
        for dirname in dirs:
            dir_path = os.path.join(root, dirname)
            args.max_batches += len(os.listdir(dir_path))

    # Setup Model
    cfg = load_config(args)
    from experiments.siammask.custom import Custom
    model = Custom(anchors=cfg['anchors'])
    if args.resume:
        assert isfile(args.resume), '{} is not a valid file'.format(args.resume)
        model = load_pretrain(model, args.resume)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.eval().to(device)
    cudnn.benchmark = True

    """ loss and optimizer """
    criterion = MultiBoxLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    """ train phase """
    closses, rlosses, tlosses = AverageMeter(), AverageMeter(), AverageMeter()
    steps = 0
    start = 0
    for epoch in range(start, args.max_epoches):
        cur_lr = adjust_learning_rate(args.lr, optimizer, epoch, gamma=0.1)
        index_list = range(len(data_loader))
        for example in range(args.max_batches):
            ret = data_loader.__get__(random.choice(index_list)) 
            template = ret['template_tensor'].to(device)
            detection = ret['detection_tensor'].to(device)
            mask_target = ret['mask_template_tensor'].to(device)
            pos_neg_diff = ret['pos_neg_diff_tensor'].to(device)
            cout, rout, mask = model(template, detection)
            predictions, targets = (cout, rout, mask), pos_neg_diff
            closs, rloss, mloss, loss, reg_pred, reg_target, pos_index, neg_index = criterion(predictions, targets, mask_target)
            closs_ = closs.cpu().item()

            if np.isnan(closs_):
                sys.exit(0)

            closses.update(closs.cpu().item())
            rlosses.update(rloss.cpu().item())
            tlosses.update(loss.cpu().item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            steps += 1

            cout = cout.cpu().detach().numpy()
            # score = 1/(1 + np.exp(cout[:,0]-cout[:,1]))
            print("Epoch:{:04d}\texample:{:06d}/{:06d}({:.2f})%\tsteps:{:010d}\tlr:{:.7f}\tcloss:{:.4f}\trloss:{:.4f}\ttloss:{:.4f}".format(epoch, example+1, args.max_batches, 100*(example+1)/args.max_batches, steps, cur_lr, closses.avg, rlosses.avg, tlosses.avg))

        if epoch % 5 == 0:
            file_path = os.path.join(args.weight_dir, 'epoch_{:04d}_weights.pth.tar'.format(epoch))
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, file_path)
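
Note: adjust_learning_rate is called throughout but never defined in these listings.
Given the gamma=0.1 argument, a plausible step-decay sketch (the step size is a guess,
not the original implementation):

def adjust_learning_rate(base_lr, optimizer, epoch, gamma=0.1, step_size=10):
    # decay the learning rate by gamma every step_size epochs
    lr = base_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr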
Example #7
def main():
    """ train dataloader """
    args = parser.parse_args()
    data_loader = TrainDataLoader(args.train_path, check=args.debug)
    if not os.path.exists(args.weight_dir):
        os.makedirs(args.weight_dir)

    """ compute max_batches """
    for root, dirs, files in os.walk(args.train_path):
        for dirname in dirs:
            dir_path = os.path.join(root, dirname)
            args.max_batches += len(os.listdir(dir_path))
    print('max_batches: {}'.format(args.max_batches))
    """ Model on gpu """
    model = SiameseRPN()
    model = model.cuda()
    cudnn.benchmark = True

    """ loss and optimizer """
    criterion = MultiBoxLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    """ load weights """
    init_weights(model)
    if args.checkpoint_path is not None:
        assert os.path.isfile(args.checkpoint_path), '{} is not a valid checkpoint_path'.format(args.checkpoint_path)
        try:
            checkpoint = torch.load(args.checkpoint_path)
            start = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception:  # fall back to fresh weights if the checkpoint is unusable
            start = 0
            init_weights(model)
    else:
        start = 0

    """ train phase """
    closses, rlosses, tlosses = AverageMeter(), AverageMeter(), AverageMeter()
    steps = 0
    #print('data_loader length: {}'.format(len(data_loader)))
    for epoch in range(start, args.max_epoches):
        cur_lr = adjust_learning_rate(args.lr, optimizer, epoch, gamma=0.1)
        index_list = range(len(data_loader))
        example_index = 0
        for example in range(args.max_batches):
            ret = data_loader.__get__(random.choice(index_list))
            template = ret['template_tensor'].cuda()
            detection = ret['detection_tensor'].cuda()
            pos_neg_diff = ret['pos_neg_diff_tensor'].cuda()
            cout, rout = model(template, detection)
            predictions, targets = (cout, rout), pos_neg_diff
            closs, rloss, loss, reg_pred, reg_target, pos_index, neg_index = criterion(predictions, targets)
            closs_ = closs.cpu().item()

            if np.isnan(closs_):
                sys.exit(0)

            closses.update(closs.cpu().item())
            rlosses.update(rloss.cpu().item())
            tlosses.update(loss.cpu().item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            steps += 1

            cout = cout.cpu().detach().numpy()
            score = 1/(1 + np.exp(cout[:,0]-cout[:,1]))

            # ++++++++++++ post process below just for debug ++++++++++++++++++++++++
            # ++++++++++++++++++++ v1.0 add penalty +++++++++++++++++++++++++++++++++
            if ret['pos_anchors'] is not None:
                penalty_k = 0.055
                tx, ty, tw, th = ret['template_target_xywh'].copy()
                tw *= ret['template_cropprd_resized_ratio']
                th *= ret['template_cropprd_resized_ratio']

                anchors = ret['anchors'].copy()
                w = anchors[:,2] * np.exp(reg_pred[:, 2].cpu().detach().numpy())
                h = anchors[:,3] * np.exp(reg_pred[:, 3].cpu().detach().numpy())

                eps = 1e-2
                change_w = np.maximum(w/(tw+eps), tw/(w+eps))
                change_h = np.maximum(h/(th+eps), th/(h+eps))
                penalty = np.exp(-(change_w + change_h - 1) * penalty_k)
                pscore = score * penalty
            else:
                pscore = score

            # +++++++++++++++++++ v1.0 add window default cosine ++++++++++++++++++++++
            score_size = 17
            window_influence = 0.42
            window = (np.outer(np.hanning(score_size), np.hanning(score_size)).reshape(17,17,1) + np.zeros((1, 1, 5))).reshape(-1)
            pscore = pscore * (1 - window_influence) + window * window_influence
            score_old = score
            score = pscore #from 0.2 - 0.7

            # ++++++++++++++++++++ debug for class ++++++++++++++++++++++++++++++++++++
            if example_index%1000 == 0:
                print(score[pos_index])  # this should tend to be 1
                print(score[neg_index])  # this should tend to be 0


            # ++++++++++++++++++++ debug for reg ++++++++++++++++++++++++++++++++++++++
            tmp_dir = '/home/ly/chz/Siamese-RPN-pytorch/code_v1.0/tmp/visualization/7_check_train_phase_debug_pos_anchors'
            if not os.path.exists(tmp_dir):
                os.makedirs(tmp_dir)
            detection = ret['detection_cropped_resized'].copy()
            draw = ImageDraw.Draw(detection)
            pos_anchors = ret['pos_anchors'].copy() if ret['pos_anchors'] is not None else None

            if pos_anchors is not None:
                # draw pos anchors
                x = pos_anchors[:, 0]
                y = pos_anchors[:, 1]
                w = pos_anchors[:, 2]
                h = pos_anchors[:, 3]
                x1s, y1s, x2s, y2s = x - w//2, y - h//2, x + w//2, y + h//2
                for i in range(16):
                    x1, y1, x2, y2 = x1s[i], y1s[i], x2s[i], y2s[i]
                    draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)], width=1, fill='white') # pos anchor

                # pos anchor transform to red box after prediction
                x = pos_anchors[:,0] + pos_anchors[:, 2] * reg_pred[pos_index, 0].cpu().detach().numpy()
                y = pos_anchors[:,1] + pos_anchors[:, 3] * reg_pred[pos_index, 1].cpu().detach().numpy()
                w = pos_anchors[:,2] * np.exp(reg_pred[pos_index, 2].cpu().detach().numpy())
                h = pos_anchors[:,3] * np.exp(reg_pred[pos_index, 3].cpu().detach().numpy())
                x1s, y1s, x2s, y2s = x - w//2, y - h//2, x + w//2, y + h//2
                for i in range(16):
                    x1, y1, x2, y2 = x1s[i], y1s[i], x2s[i], y2s[i]
                    draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)], width=1, fill='red')  # predict(white -> red)

                # pos anchor should be transformed to green gt box, if red and green is same, it is overfitting
                x = pos_anchors[:,0] + pos_anchors[:, 2] * reg_target[pos_index, 0].cpu().detach().numpy()
                y = pos_anchors[:,1] + pos_anchors[:, 3] * reg_target[pos_index, 1].cpu().detach().numpy()
                w = pos_anchors[:,2] * np.exp(reg_target[pos_index, 2].cpu().detach().numpy())
                h = pos_anchors[:,3] * np.exp(reg_target[pos_index, 3].cpu().detach().numpy())
                x1s, y1s, x2s, y2s = x - w//2, y-h//2, x + w//2, y + h//2
                for i in range(16):
                    x1, y1, x2, y2 = x1s[i], y1s[i], x2s[i], y2s[i]
                    draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)], width=1, fill='green') # gt  (white -> green)
                x1, y1, x3, y3 = x1s[0], y1s[0], x2s[0], y2s[0]
            else:
                x1, y1, x3, y3 = 0, 0, 0, 0
            # top1 proposal after nms (white)
            if example_index%1000 == 0:
                save_path = osp.join(tmp_dir, 'epoch_{:010d}_{:010d}_anchor_pred.jpg'.format(epoch, example))
                detection.save(save_path)
            example_index = example_index+1

            # +++++++++++++++++++ v1.0 restore ++++++++++++++++++++++++++++++++++++++++
            ratio = ret['detection_cropped_resized_ratio']
            detection_cropped = ret['detection_cropped'].copy()
            detection_cropped_resized = ret['detection_cropped_resized'].copy()
            original = Image.open(ret['detection_img_path'])
            x_, y_ = ret['detection_tlcords_of_original_image']
            draw = ImageDraw.Draw(original)
            w, h = original.size
            """ un resized """
            x1, y1, x3, y3 = x1/ratio, y1/ratio, x3/ratio, y3/ratio

            """ un cropped """
            x1 = np.clip(x_ + x1, 0, w-1).astype(np.int32) # uncropped #target_of_original_img
            y1 = np.clip(y_ + y1, 0, h-1).astype(np.int32)
            x3 = np.clip(x_ + x3, 0, w-1).astype(np.int32)
            y3 = np.clip(y_ + y3, 0, h-1).astype(np.int32)

            draw.line([(x1, y1), (x3, y1), (x3, y3), (x1, y3), (x1, y1)], width=3, fill='yellow')
            #save_path = osp.join(tmp_dir, 'epoch_{:010d}_{:010d}_restore.jpg'.format(epoch, example))
            #original.save(save_path)

            print("Epoch:{:04d}\texample:{:06d}/{:06d}({:.2f})%\tsteps:{:010d}\tlr:{:.7f}\tcloss:{:.4f}\trloss:{:.4f}\ttloss:{:.4f}".format(epoch, example+1, args.max_batches, 100*(example+1)/args.max_batches, steps, cur_lr, closses.avg, rlosses.avg, tlosses.avg ))

        if steps % 1 == 0:  # always true: a checkpoint is written after every epoch
            file_path = os.path.join(args.weight_dir, 'weights-{:07d}.pth.tar'.format(steps))
            state = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, file_path)
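
Note: the box arithmetic repeated in Examples #4, #7, #9 and #12 is standard anchor
decoding: x = xa + wa*dx, y = ya + ha*dy, w = wa*exp(dw), h = ha*exp(dh). A compact
helper capturing it (a sketch; the originals inline these expressions):

import numpy as np

def decode_boxes(anchors, deltas):
    # anchors and deltas are (N, 4) arrays holding (x, y, w, h) and (dx, dy, dw, dh)
    x = anchors[:, 0] + anchors[:, 2] * deltas[:, 0]
    y = anchors[:, 1] + anchors[:, 3] * deltas[:, 1]
    w = anchors[:, 2] * np.exp(deltas[:, 2])
    h = anchors[:, 3] * np.exp(deltas[:, 3])
    return np.stack((x, y, w, h), axis=1)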
Example #8
def main(_):
    # Setting up tensorflow
    tf.logging.set_verbosity(tf.logging.INFO)
    sess = tf.InteractiveSession()
    # Create preprocessor and data loader
    preprocessor = Preprocessor(feature_count=40,
                                window_size_ms=20,
                                window_stride_ms=10)
    dl = TrainDataLoader(FLAGS.data_dir, preprocessor)
    label_count = len(dl.classes)

    # Parse experiment settings
    settings = pd.read_csv(FLAGS.settings)

    # Create input and call models to create the output tensors
    fingerprint_input = tf.placeholder(tf.float32,
                                       [None, preprocessor.fingerprint_size],
                                       name='fingerprint_input')
    fingerprint_input_4d = tf.reshape(
        fingerprint_input,
        [-1, preprocessor.feature_count, preprocessor.window_number, 1])
    logits, dropout_prob = create_model(FLAGS.architecture,
                                        fingerprint_input_4d,
                                        {'label_count': label_count},
                                        is_training=True)

    # Create following tensors
    ground_truth_input = tf.placeholder(tf.float32, [None, label_count],
                                        name='groundtruth_input')

    # Cross entropy with summary
    with tf.name_scope('cross_entropy'):
        cross_entropy_mean = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=ground_truth_input,
                                                    logits=logits))
    tf.summary.scalar('cross_entropy', cross_entropy_mean)

    # Declare optimizer and learning rate
    learning_rate_input = tf.placeholder(tf.float32, [],
                                         name='learning_rate_input')
    if FLAGS.optimizer == 'gd':
        train_step = tf.train.GradientDescentOptimizer(
            learning_rate_input).minimize(cross_entropy_mean)
    elif FLAGS.optimizer == 'adam':
        train_step = tf.train.AdamOptimizer(learning_rate_input).minimize(
            cross_entropy_mean)
    else:
        raise Exception('Optimizer not recognized')

    predicted_indices = tf.argmax(logits, 1)
    expected_indices = tf.argmax(ground_truth_input, 1)
    correct_prediction = tf.equal(predicted_indices, expected_indices)
    confusion_matrix = tf.confusion_matrix(expected_indices,
                                           predicted_indices,
                                           num_classes=label_count)
    evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.summary.scalar('accuracy', evaluation_step)

    # Create global step
    global_step = tf.train.get_or_create_global_step()
    increment_global_step = tf.assign(global_step, global_step + 1)

    # Create saver to save model
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=40)

    # Merge summaries and write them to dir
    merged_summaries = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(
        FLAGS.summaries_dir + '/' + FLAGS.log_alias + '_train', sess.graph)
    validation_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/' +
                                              FLAGS.log_alias + '_validation')

    # Init global variables and start step
    tf.global_variables_initializer().run()
    start_step = 1

    # Check if there is a checkpoint to restore
    if FLAGS.start_checkpoint:
        _saver = tf.train.Saver(tf.global_variables())
        _saver.restore(sess, FLAGS.start_checkpoint)
        start_step = global_step.eval(session=sess)
    tf.logging.info('Training from step: %d ', start_step)

    # Saving graph
    tf.train.write_graph(sess.graph_def, FLAGS.train_dir,
                         FLAGS.architecture + '.pbtxt')

    # Total steps
    total_steps = np.sum(list(settings.steps))

    for training_step in range(start_step, total_steps + 1):
        # Get the current learning rate
        training_steps_sum = 0
        for i in range(len(list(settings.steps))):
            training_steps_sum += list(settings.steps)[i]
            if training_step <= training_steps_sum:
                # Get the settings
                current_settings = settings.iloc[i]
                break
        # Get the data
        samples, labels = dl.get_train_data(
            FLAGS.batch_size,
            unknown_percentage=0.1,
            silence_percentage=0.1,
            noise_volume=current_settings.background_volume,
            noise_frequency=current_settings.background_frequency_train,
            time_shift_samples=current_settings.time_shift_samples,
            time_shift_frequency=current_settings.time_shift_frequency_train)
        train_summary, train_accuracy, cross_entropy_value, _, _ = sess.run(
            [
                merged_summaries, evaluation_step, cross_entropy_mean,
                train_step, increment_global_step
            ],
            feed_dict={
                fingerprint_input: samples,
                ground_truth_input: labels,
                learning_rate_input: current_settings.learning_rate,
                dropout_prob: 0.8
            })
        train_writer.add_summary(train_summary, training_step)
        tf.logging.info(
            'Step #%d: rate %f, accuracy %.1f%%, cross entropy %f' %
            (training_step, current_settings.learning_rate,
             train_accuracy * 100, cross_entropy_value))

        # Check if we need to check validation now
        is_last_step = (training_step == total_steps)
        if (training_step % FLAGS.evaluation_step == 0) or is_last_step:
            validation_size = dl.get_validation_size()
            total_accuracy = 0
            total_conf_matrix = None
            for i in range(0, validation_size, FLAGS.batch_size):
                validation_samples, validation_labels = dl.get_validation_data(
                    FLAGS.batch_size,
                    i,
                    unknown_percentage=0.1,
                    silence_percentage=0.1,
                    noise_volume=current_settings.background_volume,
                    noise_frequency=current_settings.background_frequency_validation,
                    time_shift_samples=current_settings.time_shift_samples,
                    time_shift_frequency=current_settings.time_shift_frequency_validation)
                validation_summary, validation_accuracy, conf_matrix = sess.run(
                    [merged_summaries, evaluation_step, confusion_matrix],
                    feed_dict={
                        fingerprint_input: validation_samples,
                        ground_truth_input: validation_labels,
                        dropout_prob: 1.0
                    })
                validation_writer.add_summary(validation_summary,
                                              training_step)
                batch_size = min(FLAGS.batch_size, validation_size - i)
                total_accuracy += (validation_accuracy *
                                   batch_size) / validation_size
                if total_conf_matrix is None:
                    total_conf_matrix = conf_matrix
                else:
                    total_conf_matrix += conf_matrix
            tf.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix))
            tf.logging.info(
                'Step %d: Validation accuracy = %.1f%% (N=%d)' %
                (training_step, total_accuracy * 100, validation_size))
        # Save the model checkpoint periodically.
        if (training_step % FLAGS.save_step_interval == 0
                or training_step == total_steps):
            checkpoint_path = os.path.join(FLAGS.train_dir,
                                           FLAGS.architecture + '.ckpt')
            tf.logging.info('Saving to "%s-%d"', checkpoint_path,
                            training_step)
            saver.save(sess, checkpoint_path, global_step=training_step)
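
Note: the TensorFlow example above drives its schedule from FLAGS.settings via pandas.
The columns referenced in the training loop imply a CSV roughly like the following;
the values are illustrative only:

# settings.csv (hypothetical)
# steps,learning_rate,background_volume,background_frequency_train,time_shift_samples,time_shift_frequency_train,background_frequency_validation,time_shift_frequency_validation
# 10000,0.001,0.1,0.8,1600,0.8,0.8,0.8
# 5000,0.0001,0.1,0.8,1600,0.8,0.8,0.8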
Example #9
def main():
    """ dataloader """
    args = parser.parse_args()
    data_loader = TrainDataLoader(args.train_path, check=False)
    """ Model on gpu """
    model = SiameseRPN()
    model = model.cuda()
    cudnn.benchmark = True
    """ loss and optimizer """
    criterion = MultiBoxLoss()
    """ load weights """
    init_weights(model)
    if args.checkpoint_path is None:
        sys.exit('please input trained model')
    else:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not a valid checkpoint_path'.format(
                args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path)
        start = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
    """ test phase """
    index_list = range(len(data_loader))
    for example in range(args.max_batches):
        ret = data_loader.__get__(random.choice(index_list))
        template = ret['template_tensor'].cuda()
        detection = ret['detection_tensor'].cuda()
        pos_neg_diff = ret['pos_neg_diff_tensor'].cuda()
        cout, rout = model(template,
                           detection)  #[1, 10, 17, 17], [1, 20, 17, 17]

        cout = cout.reshape(-1, 2)
        rout = rout.reshape(-1, 4)
        cout = cout.cpu().detach().numpy()
        print('cout size: {}'.format(cout.shape))
        score = 1 / (1 + np.exp(cout[:, 1] - cout[:, 0]))  # note: sigmoid(c0 - c1); the other examples use the opposite sign
        print('score: {}, size: {}'.format(score, score.shape))
        diff = rout.cpu().detach().numpy()  #1445

        num_proposals = 15
        score_64_index = np.argsort(score)[::-1][:num_proposals]

        print('score_64_index: {}, size: {}'.format(score_64_index,
                                                    score_64_index.shape))
        score64 = score[score_64_index]
        print('score: {}'.format(score64))
        diffs64 = diff[score_64_index, :]
        anchors64 = ret['anchors'][score_64_index]
        proposals_x = (anchors64[:, 0] +
                       anchors64[:, 2] * diffs64[:, 0]).reshape(-1, 1)
        proposals_y = (anchors64[:, 1] +
                       anchors64[:, 3] * diffs64[:, 1]).reshape(-1, 1)
        proposals_w = (anchors64[:, 2] * np.exp(diffs64[:, 2])).reshape(-1, 1)
        proposals_h = (anchors64[:, 3] * np.exp(diffs64[:, 3])).reshape(-1, 1)
        proposals = np.hstack(
            (proposals_x, proposals_y, proposals_w, proposals_h))

        d = os.path.join(ret['tmp_dir'], '6_pred_proposals')
        if not os.path.exists(d):
            os.makedirs(d)

        detection = ret['detection_cropped_resized']
        save_path = os.path.join(ret['tmp_dir'], '6_pred_proposals',
                                 '{:04d}_1_detection.jpg'.format(example))
        detection.save(save_path)

        template = ret['template_cropped_resized']
        save_path = os.path.join(ret['tmp_dir'], '6_pred_proposals',
                                 '{:04d}_0_template.jpg'.format(example))
        template.save(save_path)
        """ 可视化 """
        draw = ImageDraw.Draw(detection)
        for i in range(num_proposals):
            x, y, w, h = proposals_x[i], proposals_y[i], proposals_w[i], proposals_h[i]
            x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
            draw.line([(x1, y1), (x2, y1), (x2, y2), (x1, y2), (x1, y1)],
                      width=1,
                      fill='red')
        """ save detection template proposals"""
        save_path = os.path.join(ret['tmp_dir'], '6_pred_proposals',
                                 '{:04d}_2_proposals.jpg'.format(example))
        detection.save(save_path)

        print('save at {}'.format(save_path))
        # restore
        """
Example #10
def main():
    """parameter initialization"""
    args = parser.parse_args()
    exp_name_dir = experiment_name_dir(args.experiment_name)
    """Load the parameters from json file"""
    json_path = os.path.join(exp_name_dir, 'parameters.json')
    assert os.path.isfile(json_path), (
        "No json configuration file found at {}".format(json_path))
    with open(json_path) as data_file:
        params = json.load(data_file)
    """ train dataloader """
    data_loader = TrainDataLoader(args.train_path)
    """ compute max_batches """
    for root, dirs, files in os.walk(args.train_path):
        for dirname in dirs:
            dir_path = os.path.join(root, dirname)
            args.max_batches += len(os.listdir(dir_path))
    """ Model on gpu """
    model = TrackerSiamRPN(params)
    #model = model.cuda()
    cudnn.benchmark = True
    """ load weights """
    init_weights(model)
    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not a valid checkpoint_path'.format(
                args.checkpoint_path)
        try:
            checkpoint = torch.load(args.checkpoint_path)
            start = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception:  # note: 'optimizer' is undefined in this example, so this always falls back
            start = 0
            init_weights(model)
    else:
        start = 0
    """ train phase """
    closses, rlosses, tlosses = AverageMeter(), AverageMeter(), AverageMeter()
    steps = 0
    for epoch in range(start, args.max_epoches):
        #cur_lr = adjust_learning_rate(params["lr"], optimizer, epoch, gamma=0.1)
        index_list = range(len(data_loader))
        for example in tqdm(range(1000)):  # args.max_batches
            ret = data_loader.__get__(random.choice(index_list))

            closs, rloss, loss, reg_pred, reg_target, pos_index, neg_index, cur_lr = model.step(
                ret, epoch, backward=True)

            closs_ = closs.cpu().item()

            if np.isnan(closs_):
                sys.exit(0)

            closses.update(closs.cpu().item())
            rlosses.update(rloss.cpu().item())
            tlosses.update(loss.cpu().item())
            steps += 1

            if example % 1000 == 0:
                print(
                    "Epoch:{:04d}\texample:{:06d}/{:06d}({:.2f})%\tlr:{:.7f}\tcloss:{:.4f}\trloss:{:.4f}\ttloss:{:.4f}"
                    .format((epoch + 1), steps, args.max_batches,
                            100 * (steps) / args.max_batches, cur_lr,
                            closses.avg, rlosses.avg, tlosses.avg))
        """save model"""
        model_save_dir_pth = '{}/model'.format(exp_name_dir)
        if not os.path.exists(model_save_dir_pth):
            os.makedirs(model_save_dir_pth)
        net_path = os.path.join(model_save_dir_pth,
                                'model_e%d.pth' % (epoch + 1))
        torch.save(model.net.state_dict(), net_path)
Example #11
    # (tail of create_adversarial_pattern; the start of this listing is missing)
    label_batch = np.reshape(label_batch, newshape=prediction.shape)
    sample_loss = loss_object(label_batch, prediction)
    sample_loss = tf.convert_to_tensor(sample_loss)
    # Get the gradients of the loss w.r.t to the sample batch.
    gradient = tape.gradient(sample_loss,
                             sample_batch_tf,
                             unconnected_gradients='zero')

    # Get the sign of the gradients to create the perturbation
    signed_grad = [tf.sign(gradient[0]), tf.sign(gradient[1])]
    # signed_grad = gradient
    return signed_grad, prediction


train_data_loader = TrainDataLoader()
sample_batch, label_batch = train_data_loader[0]

perturbation, prediction = create_adversarial_pattern(sample_batch,
                                                      label_batch)
adv_sample_batch = [
    sample_batch[0] + H.epsilon * perturbation[0],
    sample_batch[1] + H.epsilon * perturbation[1]
]
adv_prediction = model.predict(adv_sample_batch)

prediction = [0 if pred < 0.5 else 1 for pred in prediction]
adv_prediction = [0 if pred < 0.5 else 1 for pred in adv_prediction]
misclassification_rate = np.logical_xor(
    adv_prediction, prediction).sum() / sample_batch[0].shape[0]
print("misclassification rate: {}".format(misclassification_rate))
Example #12
def main():
    """ dataloader """
    args = parser.parse_args()
    data_loader = TrainDataLoader(args.test_path, out_feature=25, check=False)
    """ Model on gpu """
    model = SiameseRPN()
    model = model.cuda()
    cudnn.benchmark = True
    """ loss and optimizer """
    criterion = MultiBoxLoss()
    """ load weights """
    init_weights(model)
    if args.checkpoint_path is None:
        sys.exit('please input trained model')
    else:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not a valid checkpoint_path'.format(
                args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path)
        start = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
    """ test phase """
    index_list = range(len(data_loader))
    threshold = 50
    precision = []
    precision_c = []
    average_error = []
    average_error_c = []
    iou = []
    iou_c = []
    for example in range(args.max_batches):
        ret = data_loader.__get__(random.choice(index_list))
        template = ret['template_tensor'].cuda()
        detection = ret['detection_tensor'].cuda()
        pos_neg_diff = ret['pos_neg_diff_tensor'].cuda()
        cout, rout = model(template,
                           detection)  #[1, 10, 17, 17], [1, 20, 17, 17]
        template_img = ret['template_cropped_transformed']
        detection_img = ret['detection_cropped_transformed']

        cout = cout.reshape(-1, 2)
        rout = rout.reshape(-1, 4)
        cout = cout.cpu().detach().numpy()
        score = 1 / (1 + np.exp(cout[:, 0] - cout[:, 1]))
        diff = rout.cpu().detach().numpy()  #1445

        num_proposals = 1
        score_64_index = np.argsort(score)[::-1][:num_proposals]

        score64 = score[score_64_index]
        diffs64 = diff[score_64_index, :]
        anchors64 = ret['anchors'][score_64_index]
        proposals_x = (anchors64[:, 0] +
                       anchors64[:, 2] * diffs64[:, 0]).reshape(-1, 1)
        proposals_y = (anchors64[:, 1] +
                       anchors64[:, 3] * diffs64[:, 1]).reshape(-1, 1)
        proposals_w = (anchors64[:, 2] * np.exp(diffs64[:, 2])).reshape(-1, 1)
        proposals_h = (anchors64[:, 3] * np.exp(diffs64[:, 3])).reshape(-1, 1)
        proposals = np.hstack(
            (proposals_x, proposals_y, proposals_w, proposals_h))

        d = os.path.join(ret['tmp_dir'], '6_pred_proposals')
        if not os.path.exists(d):
            os.makedirs(d)

        template = ret['template_cropped_transformed']
        save_path = os.path.join(ret['tmp_dir'], '6_pred_proposals',
                                 '{:04d}_0_template.jpg'.format(example))
        template.save(save_path)
        """traditional correlation match method"""
        template_img = cv2.cvtColor(np.asarray(template_img),
                                    cv2.COLOR_RGB2BGR)
        detection_img = cv2.cvtColor(np.asarray(detection_img),
                                     cv2.COLOR_RGB2BGR)
        res = cv2.matchTemplate(detection_img, template_img,
                                cv2.TM_CCOEFF_NORMED)
        min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
        x1_c = max_loc[0]
        y1_c = max_loc[1]
        """ visualization """
        ratio = ret['detection_cropped_resized_ratio']
        original = Image.open(ret['detection_img_path'])
        origin_w, origin_h = original.size
        x_, y_ = ret['detection_tlcords_of_original_image']
        draw = ImageDraw.Draw(original)
        for i in range(num_proposals):
            x, y, w, h = proposals_x[i], proposals_y[i], proposals_w[i], proposals_h[i]
            x1, y1, x3, y3 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
            """ un resized """
            x1, y1, x3, y3 = x1 / ratio, y1 / ratio, x3 / ratio, y3 / ratio
            x1_c, y1_c = x1_c / ratio, y1_c / ratio
            """ un cropped """
            x1_g, y1_g, w, h = ret['template_target_x1y1wh']
            x3_g = x1_g + w
            y3_g = y1_g + h

            x1 = np.clip(x_ + x1, 0, origin_w - 1).astype(
                np.int32)  # uncropped #target_of_original_img
            y1 = np.clip(y_ + y1, 0, origin_h - 1).astype(np.int32)
            x3 = np.clip(x_ + x3, 0, origin_w - 1).astype(np.int32)
            y3 = np.clip(y_ + y3, 0, origin_h - 1).astype(np.int32)

            x1_c = np.clip(x_ + x1_c, 0, origin_w - 1).astype(np.int32)
            y1_c = np.clip(y_ + y1_c, 0, origin_h - 1).astype(np.int32)
            x3_c = x1_c + ret['template_target_xywh'][2]
            y3_c = y1_c + ret['template_target_xywh'][3]

            draw.line([(x1, y1), (x3, y1), (x3, y3), (x1, y3), (x1, y1)],
                      width=3,
                      fill='yellow')
            draw.line([(x1_g, y1_g), (x3_g, y1_g), (x3_g, y3_g), (x1_g, y3_g),
                       (x1_g, y1_g)],
                      width=3,
                      fill='blue')
            draw.line([(x1_c, y1_c), (x3_c, y1_c), (x3_c, y3_c), (x1_c, y3_c),
                       (x1_c, y1_c)],
                      width=3,
                      fill='red')

        save_path = os.path.join(ret['tmp_dir'], '6_pred_proposals',
                                 '{:04d}_1_restore.jpg'.format(example))
        original.save(save_path)
        print('save at {}'.format(save_path))
        """compute iou"""
        s1 = np.array([x1, y1, x3, y1, x3, y3, x1, y3, x1, y1])
        s2 = np.array(
            [x1_g, y1_g, x3_g, y1_g, x3_g, y3_g, x1_g, y3_g, x1_g, y1_g])
        s3 = np.array(
            [x1_c, y1_c, x3_c, y1_c, x3_c, y3_c, x1_c, y3_c, x1_c, y1_c])
        iou.append(intersection(s1, s2))
        iou_c.append(intersection(s3, s2))
        """compute average error"""
        cx = (x1 + x3) / 2
        cy = (y1 + y3) / 2
        cx_g = (x1_g + x3_g) / 2
        cy_g = (y1_g + y3_g) / 2
        cx_c = (x1_c + x3_c) / 2
        cy_c = (y1_c + y3_c) / 2
        error = math.sqrt(math.pow(cx - cx_g, 2) + math.pow(cy - cy_g, 2))
        error_c = math.sqrt(math.pow(cx - cx_c, 2) + math.pow(cy - cy_c, 2))
        average_error.append(error)
        average_error_c.append(error_c)
        if error <= threshold:
            precision.append(1)
        else:
            precision.append(0)
        if error_c <= threshold:
            precision_c.append(1)
        else:
            precision_c.append(0)

    iou_mean = np.mean(np.array(iou))
    error_mean = np.mean(np.array(average_error))
    iou_mean_c = np.mean(np.array(iou_c))
    error_mean_c = np.mean(np.array(average_error_c))
    precision = np.mean(np.array(precision))
    precision_c = np.mean(np.array(precision_c))
    print('average iou: {:.4f}'.format(iou_mean))
    print('average error: {:.4f}'.format(error_mean))
    print('average iou for traditional method: {:.4f}'.format(iou_mean_c))
    print('average error for traditional method: {:.4f}'.format(error_mean_c))
    print('precision: {:.4f} @ threshold {:02d}'.format(precision, threshold))
    print('precision for traditional method: {:.4f} @ threshold {:02d}'.format(
        precision_c, threshold))
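
Note: the precision reported above is the usual center-error metric from tracking
benchmarks: the fraction of frames whose predicted center lies within threshold
(here 50) pixels of the ground-truth center.

# precision@T = mean( sqrt((cx - cx_g)**2 + (cy - cy_g)**2) <= T )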
Example #13
def main():
    parser = argparse.ArgumentParser()
    # Path parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The raw data dir.")
    parser.add_argument("--vocab_path",
                        default=None,
                        type=str,
                        required=True,
                        help="bert vocab path")
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--model_output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        help="The param init of pretrain or finetune")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")
    # Data Process Parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")
    # Model Paramters
    parser.add_argument("--sop",
                        action='store_true',
                        help="whether use sop task.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")

    # Train Eval Test Paramters

    parser.add_argument("--checkpoint_steps",
                        required=True,
                        type=int,
                        help="save model eyery checkpoint_steps")

    parser.add_argument("--total_steps",
                        required=True,
                        type=int,
                        help="all steps of training model")

    parser.add_argument("--max_checkpoint",
                        required=True,
                        type=int,
                        help="max saved model in model_output_dir")

    parser.add_argument(
        "--examples_size_once",
        type=int,
        default=1000,
        help="read how many examples every time in pretrain or finetune")

    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="process rank in local")
    parser.add_argument("--local_debug",
                        action='store_true',
                        help="whether debug")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--fine_tune",
                        action='store_true',
                        help="Whether to run fine_tune.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates   accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=str,
        default='dynamic',
        help=
        '(float or str, optional, default="dynamic"): Optional loss scale override. '
        'If passed as a string, it must represent a number, e.g. "128.0", or the string "dynamic".'
    )
    parser.add_argument(
        '--opt_level',
        type=str,
        default='O1',
        help=
        '(str, optional, default="O1"): Pure or mixed precision optimization level. '
        'Accepted values are "O0", "O1", "O2", and "O3", as documented by apex.amp.'
    )
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )

    # Other Parameters
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--rank',
                        type=int,
                        default=0,
                        help="global rank of current process")
    parser.add_argument("--world_size",
                        default=2,
                        type=int,
                        help="Number of process(显卡)")

    args = parser.parse_args()
    cur_env = os.environ
    args.rank = int(cur_env.get('RANK', -1))
    args.world_size = int(cur_env.get('WORLD_SIZE', -1))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)
    assert args.train_batch_size >= 1, 'train_batch_size must be >= 1 after dividing by gradient_accumulation_steps'

    # How many examples one parameter update consumes
    examples_per_update = args.world_size * args.train_batch_size * args.gradient_accumulation_steps
    args.examples_size_once = args.examples_size_once // examples_per_update * examples_per_update
    if args.fine_tune:
        args.examples_size_once = examples_per_update
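    # Worked example with hypothetical numbers: world_size=2, train_batch_size=8,
    # gradient_accumulation_steps=4 -> one update consumes 2 * 8 * 4 = 64 examples,
    # and examples_size_once=1000 is rounded down to 1000 // 64 * 64 = 960 so each
    # chunk of data holds a whole number of updates.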

    os.makedirs(args.model_output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.model_output_dir, 'unilm_config.json'),
                   'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = torch.cuda.device_count()
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.rank)
    logger.info(
        "world_size:{}, rank:{}, device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}"
        .format(args.world_size, args.rank, device, n_gpu,
                bool(args.world_size > 1), args.fp16))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.fine_tune and not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.vocab_path,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    if args.local_rank == 0:
        dist.barrier()
    bi_uni_pipeline = [
        Preprocess4Seq2seq(args.max_pred,
                           args.mask_prob,
                           list(tokenizer.vocab.keys()),
                           tokenizer.convert_tokens_to_ids,
                           args.max_seq_length,
                           new_segment_ids=args.new_segment_ids,
                           truncate_config={
                               'max_len_a': args.max_len_a,
                               'max_len_b': args.max_len_b,
                               'trunc_seg': args.trunc_seg,
                               'always_truncate_tail':
                               args.always_truncate_tail
                           },
                           mask_source_words=args.mask_source_words,
                           skipgram_prb=args.skipgram_prb,
                           skipgram_size=args.skipgram_size,
                           mask_whole_word=args.mask_whole_word,
                           mode="s2s",
                           has_oracle=args.has_sentence_oracle,
                           num_qkv=args.num_qkv,
                           s2s_special_token=args.s2s_special_token,
                           s2s_add_segment=args.s2s_add_segment,
                           s2s_share_segment=args.s2s_share_segment,
                           pos_shift=args.pos_shift,
                           fine_tune=args.fine_tune)
    ]
    file_oracle = None
    if args.has_sentence_oracle:
        file_oracle = os.path.join(args.data_dir, 'train.oracle')

    # t_total is the total number of parameter updates
    # t_total = args.train_steps
    # Prepare model
    recover_step = _get_max_epoch_model(args.model_output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
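    # i.e. 7 segment-type ids when new_segment_ids and s2s_add_segment are both
    # set, 6 with new_segment_ids alone, and the standard 2 otherwise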
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.model_output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == index of the last saved checkpoint
            global_step = math.floor(recover_step * args.checkpoint_steps)
        # For pretraining, initialize the model parameters from an existing
        # checkpoint, e.g. chinese-bert-base
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb,
            local_debug=args.local_debug)

    total_trainable_params = sum(p.numel() for p in model.parameters()
                                 if p.requires_grad)
    logger.info("模型参数: {}".format(total_trainable_params))
    if args.local_rank == 0:
        dist.barrier()

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # Exclude biases and LayerNorm weights from weight decay, per standard BERT
    # practice; apply the configured decay rate to everything else.
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=args.total_steps)
    if args.amp and args.fp16:
        from apex import amp
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale=args.loss_scale)
        from apex.parallel import DistributedDataParallel as DDP
        model = DDP(model)
    else:
        from torch.nn.parallel import DistributedDataParallel as DDP
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
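    # Note: apex's DistributedDataParallel picks up the default CUDA device set
    # above, while torch's native DDP needs device_ids/output_device explicitly;
    # find_unused_parameters=True tolerates parameters that receive no gradient
    # in a given forward pass.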

    if recover_step:
        logger.info("** ** * Recover optimizer: %d * ** **", recover_step)
        optim_recover = torch.load(os.path.join(
            args.model_output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.fp16 and args.amp:
            amp_recover = torch.load(os.path.join(
                args.model_output_dir, "amp.{0}.bin".format(recover_step)),
                                     map_location='cpu')
            logger.info("** ** * Recover amp: %d * ** **", recover_step)
            amp.load_state_dict(amp_recover)
    logger.info("** ** * CUDA.empty_cache() * ** **")
    torch.cuda.empty_cache()

    if args.rank == 0:
        writer = SummaryWriter(log_dir=args.log_dir)
    logger.info("***** Running training *****")
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Param Update Num = %d", args.total_steps)
    model.train()

    PRE = "rank{},local_rank {},".format(args.rank, args.local_rank)
    step = 1
    start = time.time()
    train_data_loader = TrainDataLoader(
        bi_uni_pipline=bi_uni_pipeline,
        examples_size_once=args.examples_size_once,
        world_size=args.world_size,
        train_batch_size=args.train_batch_size,
        num_workers=args.num_workers,
        data_dir=args.data_dir,
        tokenizer=tokenizer,
        max_len=args.max_seq_length)
    best_result = -float('inf')
    for global_step, batch in enumerate(train_data_loader, start=global_step):
        batch = [t.to(device) if t is not None else None for t in batch]
        if args.has_sentence_oracle:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label, oracle_pos, oracle_weights, oracle_labels = batch
        else:
            input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, task_idx, sop_label = batch
            oracle_pos, oracle_weights, oracle_labels = None, None, None
        if not args.sop:
            # SOP auxiliary task is disabled
            sop_label = None
        loss_tuple = model(input_ids,
                           segment_ids,
                           input_mask,
                           masked_lm_labels=lm_label_ids,
                           next_sentence_label=sop_label,
                           masked_pos=masked_pos,
                           masked_weights=masked_weights,
                           task_idx=task_idx,
                           masked_pos_2=oracle_pos,
                           masked_weights_2=oracle_weights,
                           masked_labels_2=oracle_labels,
                           mask_qkv=mask_qkv)
        masked_lm_loss, next_sentence_loss = loss_tuple
        # mean() to average on multi-gpu.
        if n_gpu > 1:
            masked_lm_loss = masked_lm_loss.mean()
            next_sentence_loss = next_sentence_loss.mean()
        # ensure that accumulated gradients are normalized
        if args.gradient_accumulation_steps > 1:
            masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps
            next_sentence_loss = next_sentence_loss / args.gradient_accumulation_steps
        if not args.sop:
            loss = masked_lm_loss
        else:
            loss = masked_lm_loss + next_sentence_loss
        if args.fp16 and args.amp:
            with amp.scale_loss(loss, optimizer) as scale_loss:
                scale_loss.backward()
        else:
            loss.backward()
        if (global_step + 1) % args.gradient_accumulation_steps == 0:
            if args.rank == 0:
                writer.add_scalar('unilm/mlm_loss', masked_lm_loss,
                                  global_step)
                writer.add_scalar('unilm/sop_loss', next_sentence_loss,
                                  global_step)
            lr_this_step = args.learning_rate * warmup_linear(
                global_step / args.total_steps, args.warmup_proportion)
            if args.fp16:
                # modify learning rate with special warm up BERT uses
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
            optimizer.step()
            optimizer.zero_grad()
            #global_step += 1
            # Time spent on one parameter update, in seconds
            cost_time_per_update = time.time() - start
            # Estimated time to finish the remaining updates, in hours
            need_time = cost_time_per_update * (args.total_steps -
                                                global_step) / 3600.0
            cost_time_per_checkpoint = cost_time_per_update * args.checkpoint_steps / 3600.0
            start = time.time()
            if args.local_rank in [-1, 0]:
                INFO = PRE + 'step/checkpoint_steps/total: {}/{}/{}, loss {}/{}, {}s per update, ' \
                             '{}h per checkpoint, {}h until training finishes\n'.format(
                                 global_step, args.checkpoint_steps, args.total_steps,
                                 round(masked_lm_loss.item(), 5),
                                 round(next_sentence_loss.item(), 5), round(cost_time_per_update, 4),
                                 round(cost_time_per_checkpoint, 3), round(need_time, 3))
                print(INFO)
        # Save a trained model
        if (global_step + 1) % args.checkpoint_steps == 0:
            # // not %: (global_step + 1) % checkpoint_steps is always 0 here,
            # so modulo would name every checkpoint file "0"
            checkpoint_index = (global_step + 1) // args.checkpoint_steps
            if args.rank >= 0:
                train_data_loader.train_sampler.set_epoch(checkpoint_index)
            # if args.eval:
            #     # For pretraining, validate MLM; for fine-tuning, validate the task metric
            #     result = None
            # if best_result < result and _get_checkpoint_num(args.model_output_num):
            if args.rank in [0, -1]:
                logger.info("** ** * Saving  model and optimizer * ** **")

                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.model_output_dir,
                    "model.{0}.bin".format(checkpoint_index))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.model_output_dir,
                    "optim.{0}.bin".format(checkpoint_index))
                torch.save(optimizer.state_dict(), output_optim_file)
                if args.fp16 and args.amp:
                    logger.info("** ** * Saving  amp state  * ** **")
                    output_amp_file = os.path.join(
                        args.model_output_dir,
                        "amp.{0}.bin".format(checkpoint_index))
                    torch.save(amp.state_dict(), output_amp_file)
                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
    if args.rank == 0:
        writer.close()
        print('** ** * train finished * ** **')
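
The manual LR update above calls a warmup_linear helper that is not shown in this listing. A minimal sketch of the usual linear warmup-then-linear-decay schedule from pytorch_pretrained_bert.optimization (treat the exact decay shape as an assumption if your version differs):

def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed (global_step / total_steps).
    # Ramp the LR multiplier from 0 to 1 over the first `warmup` fraction
    # of training, then decay it linearly back toward 0.
    if x < warmup:
        return x / warmup
    return 1.0 - x

With warmup_proportion=0.1, this returns 0.5 both at 5% of training (halfway through warmup) and at 50% of training (halfway through the decay).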
Beispiel #14
0
def train(method):
    seed = 0
    client_list = [0, 1]

    Nets = []
    random.seed(seed)
    path = 'data/'
    for client in client_list:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        data_loader = TrainDataLoader(client, path)
        val_loader = ValTestDataLoader(client, path)
        net = Net(data_loader.student_n, data_loader.exer_n,
                  data_loader.knowledge_n)
        net = net.to(device)
        Nets.append([
            client, data_loader, net,
            copy.deepcopy(net.state_dict()), val_loader
        ])

    global_model1 = Nets[0][3]
    loss_function = nn.MSELoss(reduction='mean')

    Gauc = 0
    for i in range(Epoch):
        AUC = []
        ACC = []
        for index in range(len(Nets)):
            school = Nets[index][0]
            net = Nets[index][2]
            data_loader = Nets[index][1]
            val_loader = Nets[index][4]
            optimizer = optim.Adam(net.parameters(), lr=0.001)
            print('training model...' + str(school))
            best = 0
            best_epoch = 0
            best_knowauc = None
            best_indice = 0
            for epoch in range(epoch_n):
                metric, _, _, know_auc, know_acc = validate(
                    net, epoch, school, path, val_loader)
                auc = metric[1]
                rmse = metric[0]
                indice = metric[1]
                if auc > best:
                    best = auc
                    best_knowauc = know_auc
                    best_indice = indice
                    best_epoch = epoch
                    best_knowacc = know_acc
                    Nets[index][3] = copy.deepcopy(net.state_dict())
                if epoch - best_epoch >= 5:
                    break

                data_loader.reset()
                running_loss = 0.0
                batch_count = 0
                know_distribution = torch.zeros((data_loader.knowledge_n))
                while not data_loader.is_end():

                    batch_count += 1
                    input_stu_ids, input_exer_ids, input_knowledge_embs, labels = data_loader.next_batch(
                    )

                    know_distribution += torch.sum(input_knowledge_embs, 0)
                    input_stu_ids, input_exer_ids, input_knowledge_embs, labels = input_stu_ids.to(
                        device), input_exer_ids.to(
                            device), input_knowledge_embs.to(
                                device), labels.to(device)
                    optimizer.zero_grad()
                    output = net.forward(input_stu_ids, input_exer_ids,
                                         input_knowledge_embs)
                    loss = loss_function(output, labels)
                    loss.backward()
                    optimizer.step()

            net.load_state_dict(Nets[index][3])
            Nets[index][2] = net
            distribution = know_distribution * best_knowacc
            distribution[distribution == 0] = 0.001
            Nets[index].append(distribution.unsqueeze(1).to(device))
            print('Best AUC:', best)
            AUC.append([best_indice, best_knowacc])
            ACC.append(best_indice)

        l_school = [item[0] for item in Nets]
        l_weights = [len(item[1].data) for item in Nets]
        l_know = [item[5] for item in Nets]
        l_net = [item[3] for item in Nets]
        metric0 = []
        metric1 = []
        metric2 = []
        global_model2, student_group, question_group, _ = Fedknow(
            l_net, l_weights, l_know, AUC, method)
        print('global test ===========')
        for k in range(len(Nets)):
            metric2.append(
                validate(Nets[k][2], i, l_school[k], path, Nets[k][4]))
        globalauc = total(metric2)

        for k in range(len(Nets)):
            Apply(copy.deepcopy(global_model2), Nets[k][2], AUC[k],
                  student_group, question_group, method)
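
Fedknow and Apply are defined elsewhere, so the aggregation they perform is not shown here. Purely as orientation, a plain FedAvg-style weighted average over the collected client state dicts (l_net, weighted by l_weights) could look like the hypothetical sketch below; the actual Fedknow additionally uses the knowledge distributions (l_know) and per-client metrics to form student and question groups.

import copy

def fedavg(state_dicts, weights):
    # Hypothetical FedAvg baseline, not the actual Fedknow logic:
    # parameter-wise weighted average of the client models.
    total = float(sum(weights))
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        avg[key] = sum(sd[key] * (w / total)
                       for sd, w in zip(state_dicts, weights))
    return avg

# e.g. global_model = fedavg(l_net, l_weights)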
def train(args):
    utils.make_all_dirs(current_time)
    if args.load_var:
        all_utterances, labels, word_dict = read_data(load_var=args.load_var,
                                                      input_=None,
                                                      mode='train')
        dev_utterances, dev_labels, _ = read_data(load_var=args.load_var,
                                                  input_=None,
                                                  mode='dev')
    else:
        all_utterances, labels, word_dict = read_data(load_var=args.load_var, \
                input_=os.path.join(constant.data_path, "entangled_train.json"), mode='train')
        dev_utterances, dev_labels, _ = read_data(load_var=args.load_var, \
                input_=os.path.join(constant.data_path, "entangled_dev.json"), mode='dev')

    word_emb = build_embedding_matrix(word_dict, glove_loc=args.glove_loc, \
                    emb_loc=os.path.join(constant.save_input_path, "word_emb.pk"), load_emb=False)

    if args.save_input:
        utils.save_or_read_input(os.path.join(constant.save_input_path, "train_utterances.pk"), \
                                    rw='w', input_obj=all_utterances)
        utils.save_or_read_input(os.path.join(constant.save_input_path, "train_labels.pk"), \
                                    rw='w', input_obj=labels)
        utils.save_or_read_input(os.path.join(constant.save_input_path, "word_dict.pk"), \
                                    rw='w', input_obj=word_dict)
        utils.save_or_read_input(os.path.join(constant.save_input_path, "word_emb.pk"), \
                                    rw='w', input_obj=word_emb)
        utils.save_or_read_input(os.path.join(constant.save_input_path, "dev_utterances.pk"), \
                                    rw='w', input_obj=dev_utterances)
        utils.save_or_read_input(os.path.join(constant.save_input_path, "dev_labels.pk"), \
                                    rw='w', input_obj=dev_labels)

    train_dataloader = TrainDataLoader(all_utterances, labels, word_dict)
    if args.add_noise:
        noise_train_dataloader = TrainDataLoader(all_utterances,
                                                 labels,
                                                 word_dict,
                                                 add_noise=True)
    else:
        noise_train_dataloader = None
    dev_dataloader = TrainDataLoader(dev_utterances,
                                     dev_labels,
                                     word_dict,
                                     name='dev')

    logger_name = os.path.join(constant.log_path,
                               "{}.txt".format(current_time))
    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT,
                        level=logging.INFO,
                        filename=logger_name,
                        filemode='w')
    logger = logging.getLogger()
    global log_head
    log_head = log_head + "Training Model: {}; ".format(args.model)
    if args.add_noise:
        log_head += "Add Noise: True; "
    logger.info(log_head)

    if args.model == 'T':
        ensemble_model_bidirectional = EnsembleModel(word_dict,
                                                     word_emb=word_emb,
                                                     bidirectional=True)
    elif args.model == 'TS':
        ensemble_model_bidirectional = EnsembleModel(word_dict,
                                                     word_emb=None,
                                                     bidirectional=True)
    else:
        ensemble_model_bidirectional = None
    if args.model == 'TS':
        ensemble_model_bidirectional.load_state_dict(
            torch.load(args.model_path))
    ensemble_model = EnsembleModel(word_dict,
                                   word_emb=word_emb,
                                   bidirectional=False)

    if torch.cuda.is_available():
        ensemble_model.cuda()
        if args.model == 'T' or args.model == 'TS':
            ensemble_model_bidirectional.cuda()

    supervised_trainer = SupervisedTrainer(args, ensemble_model, teacher_model=ensemble_model_bidirectional, \
                                                logger=logger, current_time=current_time)

    supervised_trainer.train(train_dataloader, noise_train_dataloader,
                             dev_dataloader)
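
When args.model == 'TS', the bidirectional model loaded from args.model_path serves as a fixed teacher for the unidirectional student inside SupervisedTrainer. The trainer's loss is not shown in this listing; a generic distillation-style objective, given only as a hypothetical illustration of the pattern (the temperature T and mixing weight alpha are assumed names, not values from this codebase):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Hard-label cross entropy on the ground truth...
    hard = F.cross_entropy(student_logits, labels)
    # ...blended with KL divergence against the softened teacher outputs.
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    F.softmax(teacher_logits.detach() / T, dim=-1),
                    reduction='batchmean') * (T * T)
    return alpha * hard + (1.0 - alpha) * soft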