Code Example #1
    def get_losses_and_metrics(self):
        """Get loss and metric meters to log them during experiments."""
        out_model_cfg = self.cfg[MODEL][OUTPUTS]
        post_proc_cfg = self.cfg[POSTPROCS]

        self.train_loss_meters = loss_meters(out_model_cfg)
        self.val_loss_meters = loss_meters(out_model_cfg)
        self.val_metrics = metrics(out_model_cfg)
        self.post_proc_metrics = metrics(post_proc_cfg)
Code Example #2
def _run_train(args, model, criterion, optimizer, dataloader):

    model.train()

    total_loss, acc, f1 = 0, None, None
    for idx, (ids, labels, masks) in enumerate(dataloader):
        b, l = ids.shape

        optimizer.zero_grad()

        predict = model(ids.cuda())

        loss = criterion(predict, labels.type_as(predict).cuda())
        # loss = torch.masked_select(loss, masks.expand(b, l, 20).cuda())  # needs criterion(reduction='none'); see sketch below
        loss = loss.mean()
        loss.backward()

        optimizer.step()

        total_loss += loss.item() * b
        acc, f1 = metrics(predict, labels, acc, f1, idx)
        print("\t[{}/{}] train loss:{:.3f} acc:{:.2f} f1:{:.2f}".format(
            idx + 1, len(dataloader), total_loss / (idx + 1) / b, acc, f1),
              end='   \r')

    return {'loss': total_loss / len(dataloader.dataset), 'acc': acc, 'f1': f1}
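
Note on the commented-out `masked_select` line above: it only works if the criterion is built with `reduction='none'` so the loss keeps its per-element shape. A minimal sketch of that masking pattern, with the sizes (batch b, sequence length l, 20 labels) and the BCE-style criterion assumed from the snippet:

import torch
import torch.nn as nn

b, l, n_labels = 4, 16, 20                                 # hypothetical sizes
predict = torch.randn(b, l, n_labels)                      # raw logits
labels = torch.randint(0, 2, (b, l, n_labels)).float()
masks = torch.rand(b, l, 1) > 0.2                          # True on real tokens, False on padding

criterion = nn.BCEWithLogitsLoss(reduction='none')         # per-element loss, required for masking
loss = criterion(predict, labels)                          # shape (b, l, n_labels)
loss = torch.masked_select(loss, masks.expand(b, l, n_labels))  # keep non-padded positions only
loss = loss.mean()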
Code Example #3
    def test_metrics_collected(self):
        m = metrics()
        self.assertEqual(self.decorated, self.decorated_func())
        self.assertIn(self.metrics_name, m.data)

        for m_type in [Constants.TIMER, Constants.COUNTER]:
            self.assertIn(m_type, m.data[self.metrics_name])
            self.assertGreater(m.data[self.metrics_name][m_type], 0)
Code Example #4
def compute_dice(prediction, mask):
    """
    Args:
    prediction: cuda tensor - model prediction after log softmax (depth, channels, height, width)
    mask: cuda long tensor - ground truth (depth, height, width)

    Returns:
    dice score - ndarray
    concrete prediction - ndarray (depth height width) uint8
    mask - ndarray (depth height width) uint8
    """
    prediction = prediction.cpu()
    mask = mask.cpu().numpy().astype(np.uint8)
    classes = list(range(1, prediction.shape[1]))
    # take the index of the maximum over the channel dimension (class map)
    prediction = prediction.argmax(1)
    prediction = prediction.numpy().astype(np.uint8)
    dice = metrics.metrics(mask, prediction, classes=classes)
    return dice, prediction, mask
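
A minimal usage sketch for `compute_dice`, assuming a CUDA device is available and the project's `metrics.metrics` helper is importable; the sizes are hypothetical and the shapes follow the docstring:

import torch
import torch.nn.functional as F

depth, n_classes, height, width = 8, 4, 64, 64             # hypothetical volume
logits = torch.randn(depth, n_classes, height, width).cuda()
prediction = F.log_softmax(logits, dim=1)                  # "prediction after log softmax"
mask = torch.randint(0, n_classes, (depth, height, width)).cuda().long()

dice, pred_np, mask_np = compute_dice(prediction, mask)    # dice over classes 1..n_classes-1 (0 = background)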
Code Example #5
def _run_eval(args, model, criterion, dataloader):

    model.eval()

    with torch.no_grad():
        total_loss, acc, f1 = 0, None, None
        for idx, (_input, _label) in enumerate(dataloader):
            b = _input.shape[0]

            _predict = model(_input.cuda())

            loss = criterion(_predict, _label.type_as(_predict).cuda())

            total_loss += loss.item() * b
            acc, f1 = metrics(_predict, _label, acc, f1, idx)
            print("\t[{}/{}] valid loss:{:.3f} acc:{:.2f} f1:{:.2f}".format(
                idx + 1, len(dataloader), total_loss / (idx + 1) / b, acc, f1),
                  end='   \r')

    return {'loss': total_loss / len(dataloader.dataset), 'acc': acc, 'f1': f1}
Code Example #6
def _run_train(args, model, criterion, optimizer, dataloader):

    model.train()

    total_loss, acc, f1 = 0, None, None
    for idx, (_input, _label) in enumerate(dataloader):
        b = _input.shape[0]

        optimizer.zero_grad()

        _predict = model(_input.cuda())

        loss = criterion(_predict, _label.type_as(_predict).cuda())
        loss.backward()

        optimizer.step()

        total_loss += loss.item() * b
        acc, f1 = metrics(_predict, _label, acc, f1, idx)
        print("\t[{}/{}] train loss:{:.3f} acc:{:.2f} f1:{:.2f}".format(
            idx + 1, len(dataloader), total_loss / (idx + 1) / b, acc, f1),
              end='   \r')

    return {'loss': total_loss / len(dataloader.dataset), 'acc': acc, 'f1': f1}
Code Example #7
File: train.py  Project: vvvityaaa/mcmrirecon
def train_net(params):
    # Initialize Parameters
    params = DotDict(params)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    verbose = {k: [] for k in ('loss_train', 'loss_valid', 'psnr_train', 'psnr_valid',
                               'ssim_train', 'ssim_valid', 'vif_train', 'vif_valid')}

    log_metrics = True
    ssim_module = SSIM()
    msssim_module = MSSSIM()
    vifLoss = VIFLoss(sigma_n_sq=0.4, data_range=1.)
    msssimLoss = MultiScaleSSIMLoss(data_range=1.)
    best_validation_metrics = 100

    train_generator, val_generator = data_loaders(params)
    loaders = {"train": train_generator, "valid": val_generator}

    wnet_identifier = params.mask_URate[0:2] + "WNet_dense=" + str(int(params.dense)) + "_" + params.architecture + "_" \
                      + params.lossFunction + '_lr=' + str(params.lr) + '_ep=' + str(params.num_epochs) + '_complex=' \
                      + str(int(params.complex_net)) + '_' + 'edgeModel=' + str(int(params.edge_model)) \
                      + '(' + str(params.num_edge_slices) + ')_date=' + (datetime.now()).strftime("%d-%m-%Y_%H-%M-%S")

    if not os.path.isdir(params.model_save_path):
        os.mkdir(params.model_save_path)
    print("\n\nModel will be saved at:\n", params.model_save_path)
    print("WNet ID: ", wnet_identifier)

    wnet, optimizer, best_validation_loss, preTrainedEpochs = generate_model(
        params, device)

    # data = (iter(train_generator)).next()

    # Adding writer for tensorboard. Also start tensorboard, which tries to access logs in the runs directory
    writer = init_tensorboard(iter(train_generator), wnet, wnet_identifier,
                              device)

    for epoch in trange(preTrainedEpochs, params.num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                wnet.train()
            else:
                wnet.eval()

            for i, data in enumerate(loaders[phase]):

                # for i in range(10000):
                x, y_true, _, _, fname, slice_num = data
                x, y_true = x.to(device, dtype=torch.float), y_true.to(
                    device, dtype=torch.float)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    y_pred = wnet(x)
                    if params.lossFunction == 'mse':
                        loss = F.mse_loss(y_pred, y_true)
                    elif params.lossFunction == 'l1':
                        loss = F.l1_loss(y_pred, y_true)
                    elif params.lossFunction == 'ssim':
                        # standard SSIM
                        loss = 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (
                            1 - ssim_module(y_pred, y_true))
                    elif params.lossFunction == 'msssim':
                        # loss = 0.16 * F.l1_loss(y_pred, y_true) + 0.84 * (1 - msssim_module(y_pred, y_true))
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2]))
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2]))
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = msssimLoss(prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2]))
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2]))
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = vifLoss(prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'mse+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.15 * F.mse_loss(
                            prediction_abs_flat,
                            target_abs_flat) + 0.85 * vifLoss(
                                prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'l1+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.146 * F.l1_loss(
                            y_pred, y_true) + 0.854 * vifLoss(
                                prediction_abs_flat, target_abs_flat)
                    elif params.lossFunction == 'msssim+vif':
                        prediction_abs = torch.sqrt(
                            torch.square(y_pred[:, 0::2]) +
                            torch.square(y_pred[:, 1::2])).to(device)
                        target_abs = torch.sqrt(
                            torch.square(y_true[:, 0::2]) +
                            torch.square(y_true[:, 1::2])).to(device)
                        prediction_abs_flat = (torch.flatten(
                            prediction_abs, start_dim=0,
                            end_dim=1)).unsqueeze(1)
                        target_abs_flat = (torch.flatten(
                            target_abs, start_dim=0, end_dim=1)).unsqueeze(1)
                        loss = 0.66 * msssimLoss(
                            prediction_abs_flat,
                            target_abs_flat) + 0.33 * vifLoss(
                                prediction_abs_flat, target_abs_flat)

                    # avoid logging nan/spike values
                    if not math.isnan(loss.item()) and loss.item() < 2 * best_validation_loss:
                        verbose['loss_' + phase].append(loss.item())
                        writer.add_scalar(
                            'Loss/' + phase + '_epoch_' + str(epoch),
                            loss.item(), i)

                    if log_metrics and (
                        (i % params.verbose_gap == 0) or
                        (phase == 'valid' and epoch > params.verbose_delay)):
                        y_true_copy = y_true.detach().cpu().numpy()
                        y_pred_copy = y_pred.detach().cpu().numpy()
                        # recombine interleaved channels into complex-valued arrays
                        y_true_copy = y_true_copy[:, ::2, :, :] + 1j * y_true_copy[:, 1::2, :, :]
                        y_pred_copy = y_pred_copy[:, ::2, :, :] + 1j * y_pred_copy[:, 1::2, :, :]
                        if params.architecture[-1] == 'k':
                            # transform kspace to image domain
                            y_true_copy = np.fft.ifft2(y_true_copy,
                                                       axes=(2, 3))
                            y_pred_copy = np.fft.ifft2(y_pred_copy,
                                                       axes=(2, 3))

                        # Sum of squares
                        sos_true = np.sqrt(
                            (np.abs(y_true_copy)**2).sum(axis=1))
                        sos_pred = np.sqrt(
                            (np.abs(y_pred_copy)**2).sum(axis=1))
                        # Normalization variants kept for reference:
                        # (a) per extract_challenge_metrics.ipynb, normalize both ref and rec
                        #     by the max of ref (this is what the live code below does);
                        # (b) normalize ref by max_ref and rec by max_rec, respectively:
                        #     sos_pred_max = sos_pred.max(axis=(1, 2), keepdims=True)
                        #     sos_pred_mod = sos_pred / sos_pred_max
                        #     sos_true_mod = sos_true / sos_true.max(axis=(1, 2), keepdims=True)
                        #     ssim_mod, psnr_mod, vif_mod = metrics(sos_pred_mod, sos_true_mod)
                        # (c) standardize each volume by its own mean and std:
                        #     sos_pred_std = (sos_pred - sos_pred.mean(axis=(1, 2), keepdims=True)) / sos_pred.std(axis=(1, 2), keepdims=True)
                        #     sos_true_std = (sos_true - sos_true.mean(axis=(1, 2), keepdims=True)) / sos_true.std(axis=(1, 2), keepdims=True)
                        sos_true_max = sos_true.max(axis=(1, 2), keepdims=True)
                        sos_true_org = sos_true / sos_true_max
                        sos_pred_org = sos_pred / sos_true_max

                        ssim, psnr, vif = metrics(sos_pred, sos_true)
                        ssim_normed, psnr_normed, vif_normed = metrics(
                            sos_pred_org, sos_true_org)

                        verbose['ssim_' + phase].append(np.mean(ssim_normed))
                        verbose['psnr_' + phase].append(np.mean(psnr_normed))
                        verbose['vif_' + phase].append(np.mean(vif_normed))
                        # Optional debug prints for the normalization variants above:
                        # print("SSIM: ", verbose['ssim_' + phase][-1])
                        # print("PSNR: ", verbose['psnr_' + phase][-1])
                        # print("VIF: ", verbose['vif_' + phase][-1])
                        # print("SSIM_mod: ", np.mean(ssim_mod))
                        # print("PSNR_mod: ", np.mean(psnr_mod))
                        # print("VIF_mod: ", np.mean(vif_mod))
                        print("Epoch: ", epoch)
                        print("SSIM: ", np.mean(ssim))
                        print("PSNR: ", np.mean(psnr))
                        print("VIF: ", np.mean(vif))

                        print("SSIM_normed: ", verbose['ssim_' + phase][-1])
                        print("PSNR_normed: ", verbose['psnr_' + phase][-1])
                        print("VIF_normed: ", verbose['vif_' + phase][-1])
                        # Optional visual check (disabled): compare ground truth vs. reconstruction
                        # if verbose['vif_' + phase][-1] < 0.4:
                        #     plt.figure(figsize=(9, 6), dpi=150)
                        #     gs1 = gridspec.GridSpec(3, 2)
                        #     gs1.update(wspace=0.002, hspace=0.1)
                        #     plt.subplot(gs1[0]); plt.imshow(sos_true[0], cmap="gray"); plt.axis("off")
                        #     plt.subplot(gs1[1]); plt.imshow(sos_pred[0], cmap="gray"); plt.axis("off")
                        #     plt.show()
                        writer.add_scalar(
                            'SSIM/' + phase + '_epoch_' + str(epoch),
                            verbose['ssim_' + phase][-1], i)
                        writer.add_scalar(
                            'PSNR/' + phase + '_epoch_' + str(epoch),
                            verbose['psnr_' + phase][-1], i)
                        writer.add_scalar(
                            'VIF/' + phase + '_epoch_' + str(epoch),
                            verbose['vif_' + phase][-1], i)

                    print('Loss ' + phase + ': ', loss.item())

                    if phase == 'train':
                        if loss.item() < 2 * best_validation_loss:
                            loss.backward()
                            optimizer.step()

        # Calculate Averages
        psnr_mean = np.mean(verbose['psnr_valid'])
        ssim_mean = np.mean(verbose['ssim_valid'])
        vif_mean = np.mean(verbose['vif_valid'])
        validation_metrics = 0.2 * psnr_mean + 0.4 * ssim_mean + 0.4 * vif_mean

        valid_avg_loss_of_current_epoch = np.mean(verbose['loss_valid'])
        writer.add_scalar('AvgLoss/+train_epoch_' + str(epoch),
                          np.mean(verbose['loss_train']), epoch)
        writer.add_scalar('AvgLoss/+valid_epoch_' + str(epoch),
                          np.mean(verbose['loss_valid']), epoch)
        writer.add_scalar('AvgSSIM/+train_epoch_' + str(epoch),
                          np.mean(verbose['ssim_train']), epoch)
        writer.add_scalar('AvgSSIM/+valid_epoch_' + str(epoch), ssim_mean,
                          epoch)
        writer.add_scalar('AvgPSNR/+train_epoch_' + str(epoch),
                          np.mean(verbose['psnr_train']), epoch)
        writer.add_scalar('AvgPSNR/+valid_epoch_' + str(epoch), psnr_mean,
                          epoch)
        writer.add_scalar('AvgVIF/+train_epoch_' + str(epoch),
                          np.mean(verbose['vif_train']), epoch)
        writer.add_scalar('AvgVIF/+valid_epoch_' + str(epoch), vif_mean, epoch)

        # reset the per-epoch logs
        for key in verbose:
            verbose[key] = []

        # Save Networks/Checkpoints
        is_best = validation_metrics < best_validation_metrics
        if is_best:
            best_validation_metrics = validation_metrics
            best_validation_loss = valid_avg_loss_of_current_epoch
        save_checkpoint(
            wnet, params.model_save_path, wnet_identifier, {
                'epoch': epoch + 1,
                'state_dict': wnet.state_dict(),
                'best_validation_loss': best_validation_loss,
                'optimizer': optimizer.state_dict(),
            }, is_best)
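
Every loss branch above repeats the same recipe: recombine the interleaved real/imaginary channels (the `0::2` / `1::2` slices) into per-coil magnitude images, then flatten the coils into the batch dimension for the single-channel SSIM/VIF losses. Factored out, that pattern would look like this sketch (the interleaved channel layout is inferred from the slicing, not stated in the source):

import torch

def to_magnitude_slices(x: torch.Tensor) -> torch.Tensor:
    # (batch, 2*coils, H, W) with interleaved re/im channels
    # -> (batch*coils, 1, H, W) magnitude images
    mag = torch.sqrt(torch.square(x[:, 0::2]) + torch.square(x[:, 1::2]))
    return torch.flatten(mag, start_dim=0, end_dim=1).unsqueeze(1)

With such a helper, e.g. the 'vif' branch would reduce to loss = vifLoss(to_magnitude_slices(y_pred), to_magnitude_slices(y_true)).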
Code Example #8
def validation(queries_loader, gallery_loader, model, query_test_count):
    model.eval()
    widgets = ["processing: ", progressbar.Percentage(),
               " ", progressbar.ETA(),
               " ", progressbar.FileTransferSpeed(),
               ]
    bar = progressbar.ProgressBar(widgets=widgets,
                                  max_value=len(queries_loader) + len(gallery_loader)).start()

    """
    1. calculate all the queries feature and get corresponding label 
    2. calculate gallery feature and get corresponding label
    3. 
    """
    queries_feature = []
    queries_label = []
    gallery_feature = []
    gallery_label = []

    for i, batch in enumerate(queries_loader):
        bar.update(i)
        mix_batch_data, batch_label = batch

        batch_data, batch_seq_data = preprocess_batch_data(mix_batch_data,
                                                           seq_len=SEQ_LEN)

        # Variable/volatile is legacy (pre-0.4) PyTorch; see the sketch after this example
        batch_data = Variable(batch_data, volatile=True).type(torch.FloatTensor).cuda()
        batch_seq_data = Variable(batch_seq_data).type(torch.FloatTensor).cuda()

        batch_label = Variable(batch_label).type(torch.LongTensor).cuda()
        logits = model(batch_data, batch_seq_data)
        queries_feature.append(logits)
        queries_label.append(batch_label)

    # queries_feature: [num_queries, feature_size]
    queries_feature = torch.cat(queries_feature, dim=0)

    # queries_label: [num_queries]
    queries_label = torch.cat(queries_label, dim=0)

    for i, batch in enumerate(gallery_loader):
        bar.update(len(queries_loader) + i)
        mix_batch_data, batch_label = batch

        batch_data, batch_seq_data = preprocess_batch_data(mix_batch_data,
                                                           seq_len=SEQ_LEN)

        batch_data = Variable(batch_data, volatile=True).type(torch.FloatTensor).cuda()
        batch_seq_data = Variable(batch_seq_data).type(torch.FloatTensor).cuda()

        batch_label = Variable(batch_label).type(torch.LongTensor).cuda()
        logits = model(batch_data, batch_seq_data)
        gallery_feature.append(logits)
        gallery_label.append(batch_label)

    bar.finish()

    gallery_feature = torch.cat(gallery_feature, dim=0)
    gallery_label = torch.cat(gallery_label, dim=0)

    print("num_queries", len(queries_label))
    print("num gallery samples: ", len(gallery_label))

    print("calculating metrics")
    metrics.metrics(queries=queries_feature, gallery_features=gallery_feature,
                    queries_label=queries_label, gallery_label=gallery_label,
                    query_test_count=query_test_count, n=10)
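
`Variable(..., volatile=True)` dates this snippet to pre-0.4 PyTorch; on current versions `Variable` is a no-op and `volatile` has no effect. A sketch of the equivalent inference pass today, reusing `model`, `queries_loader`, `preprocess_batch_data`, and `SEQ_LEN` from the snippet:

import torch

queries_feature, queries_label = [], []
with torch.no_grad():                       # replaces volatile=True
    for mix_batch_data, batch_label in queries_loader:
        batch_data, batch_seq_data = preprocess_batch_data(mix_batch_data, seq_len=SEQ_LEN)
        logits = model(batch_data.float().cuda(), batch_seq_data.float().cuda())
        queries_feature.append(logits)
        queries_label.append(batch_label.long().cuda())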
Code Example #9
File: eval.py  Project: vvvityaaa/mcmrirecon
        print(
            "Number of files incompatible with the number of files in the test set!"
        )
        continue

    m = np.zeros((nfiles * nslices, 3))
    for (jj, ii) in enumerate(rec_path):
        print(ii)
        with h5py.File(ii, 'r') as f:
            rec = f['reconstruction'][()]

        name = ii.split("\\")[-1]
        with h5py.File(os.path.join(refs[i], name), 'r') as f:
            ref = f['reconstruction'][()]

        ref_max = ref.max(axis=(1, 2), keepdims=True)
        ref = ref / ref_max

        rec = rec / ref_max

        ssim, psnr, vif = metrics(rec, ref)
        m[jj * nslices:(jj + 1) * nslices, 0] = ssim
        m[jj * nslices:(jj + 1) * nslices, 1] = psnr
        m[jj * nslices:(jj + 1) * nslices, 2] = vif

    res_dic[dic_keys[i]] = m

with open(os.path.join(results_path, 'Metrics', team_name + '.pickle'),
          'wb') as handle:
    pickle.dump(res_dic, handle, protocol=pickle.HIGHEST_PROTOCOL)
Code Example #10
    all_probs += probs

print('max prob: ', max(all_probs))
_, _, _, _, best_thres = tune_thres_new(
    dev_batch.gold(), all_probs)  # , start=0.0, end=0.002, fold=1001)
print('Best thres (dev): %.8f' % best_thres)

all_probs = []
for i, b in enumerate(test_batch):
    _, probs = model.predict(b, thres=0.0)
    all_probs += probs

if not os.path.exists(opt['model_save_dir']):
    os.mkdir(opt['model_save_dir'])
preds = get_preds(all_probs, best_thres)
accuracy, precision, recall, f1 = metrics(test_batch.gold(), preds)
auc, _, _, _, _ = tune_thres_new(dev_batch.gold(), all_probs, opt)
print('Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f' %
      (accuracy, precision, recall, f1))

# thres_to_test = [0.0, 0.00001, 0.0005, 0.001]
# for thres in thres_to_test:
#     preds = get_preds(all_probs, thres)
#     accuracy, precision, recall, f1 = metrics(test_batch.gold(), preds)
#
#     print('Accuracy: %.4f, Precision: %.4f, Recall: %.4f, F1: %.4f' % (accuracy, precision, recall, f1))
# print('Tunning on test...')
# print('max prob: ', max(all_probs))
# _, _, _, _, best_thres = tune_thres(test_batch.gold(), all_probs, start=0.0, end=0.002, fold=1001)
# print('Best thres (dev): %.8f' % best_thres)
#
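
`get_preds` is not shown in this snippet; a plausible minimal version, assuming it simply binarizes probabilities at the tuned threshold:

def get_preds(probs, thres):
    # hypothetical helper: threshold probabilities into 0/1 predictions
    return [int(p > thres) for p in probs]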
Code Example #11
# train_batch = DataLoader(os.path.join(opt['data_dir'], 'train.csv'),
#                          opt['batch_size'],
#                          opt,
#                          weibo2embid=weibo2embid,
#                          evaluation=False)
dev_batch = DataLoader(os.path.join(opt['data_dir'], 'dev.csv'),
                       opt['batch_size'],
                       opt,
                       weibo2embid=weibo2embid,
                       evaluation=True)

model = ModelWrapper(opt, weibo2embid, eva=True)
model.load(os.path.join(opt['model_save_dir'], 'best_model.pt'))

all_probs = []
all_preds = []
for i, b in enumerate(dev_batch):
    preds, probs, _ = model.predict(b, thres=0.5)
    all_probs += probs
    all_preds += preds

acc, prec, rec, dev_f1 = metrics(dev_batch.gold(), all_preds)
print('acc: {}, prec: {}, rec: {}, f1: {}\n'.format(acc, prec, rec, dev_f1))

auc, prec, rec, f1, best_thres = tune_thres_new(dev_batch.gold(), all_probs)
print('auc: {}, prec: {}, rec: {}, f1: {}, best_thres: {}'.format(auc, prec, rec, f1, best_thres))

with open('./log.txt', 'a+') as fout:
    fout.write('\n' + time.asctime(time.localtime(time.time())))
    fout.write(' '.join(sys.argv))
    fout.write('\nauc: {}, prec: {}, rec: {}, f1: {}, best_thres: {}\n'.format(auc, prec, rec, f1, best_thres))