Exemplo n.º 1
0
def _dump_embeddings(batch_sample, anc_embed, pos_embed, neg_embed, pos_cls):
    """Append this batch's embeddings and their metadata to CSV files.

    Writes raw embedding rows to ./embeddings.csv and one metadata row per
    image (type/id/class/train-set/val-set) to ./embeddings_info.csv, both
    in append mode with no header.
    """
    train_set = args.train_csv_name.split('.')[0]
    val_set = args.valid_csv_name.split('.')[0]
    pos_cls_cpu = pos_cls.cpu()
    # NOTE(review): t.numpy() fails on tensors that require grad, so this
    # dump is only safe under no-grad (validation) — confirm callers.
    for embed_batch in (anc_embed, pos_embed, neg_embed):
        pd.DataFrame([t.numpy() for t in embed_batch.cpu()
                      ]).to_csv("./embeddings.csv", mode='a', header=None)
    # NOTE(review): the original labeled all three rows (including "neg")
    # with the positive class; preserved as-is — verify intent.
    for kind, id_key in (("anc", 'anc_id'), ("pos", 'pos_id'),
                         ("neg", 'neg_id')):
        pd.DataFrame({
            'type': kind,
            'id': batch_sample[id_key],
            'class': pos_cls_cpu,
            'train_set': train_set,
            'val_set': val_set,
        }).to_csv("./embeddings_info.csv", mode='a', header=None)


def train_valid(model, optimizer, scheduler, epoch, dataloaders, data_size,
                start_time):
    """Run one epoch of triplet-loss training and/or validation.

    Either phase can be skipped via args.pure_validation / args.pure_training.
    During training only margin-violating ("hard") triplets contribute to the
    loss; validation keeps every triplet. Anchor-positive and anchor-negative
    distances are collected for pair-accuracy evaluation, then a per-epoch
    log file is written and either a checkpoint (train) or a ROC plot (valid)
    is saved.

    Args:
        model: network exposing forward() (embeddings) and forward_classifier().
        optimizer: optimizer stepped during the train phase.
        scheduler: LR scheduler stepped once at the start of the train phase.
        epoch: current epoch index, used in log/checkpoint file names.
        dataloaders: dict of phase -> DataLoader yielding dicts with
            'anc_img'/'pos_img'/'neg_img' tensors and class/id entries.
        data_size: dict of phase -> dataset size, used as the loss divisor.
        start_time: epoch start timestamp for duration logging.
    """
    for phase in ['train', 'valid']:
        if args.pure_validation and phase == 'train':
            continue
        if args.pure_training and phase == 'valid':
            continue
        labels, distances = [], []
        triplet_loss_sum = 0.0

        if phase == 'train':
            scheduler.step()
            model.train()
        else:
            model.eval()

        for batch_idx, batch_sample in enumerate(dataloaders[phase]):
            try:
                # Batches flagged by the dataset with an 'exception' key are
                # skipped entirely.
                if 'exception' in batch_sample:
                    continue
                anc_img = batch_sample['anc_img'].to(device)
                pos_img = batch_sample['pos_img'].to(device)
                neg_img = batch_sample['neg_img'].to(device)

                pos_cls = batch_sample['pos_class'].to(device)

                # Gradients are only tracked while training.
                with torch.set_grad_enabled(phase == 'train'):
                    # Embeddings of anchor, positive and negative images.
                    anc_embed, pos_embed, neg_embed = model(
                        anc_img), model(pos_img), model(neg_img)

                    if args.num_valid_triplets <= 100:
                        _dump_embeddings(batch_sample, anc_embed, pos_embed,
                                         neg_embed, pos_cls)

                    pos_dist = l2_dist.forward(anc_embed, pos_embed)
                    neg_dist = l2_dist.forward(anc_embed, neg_embed)

                    # Mask of margin-violating triplets (renamed from `all`
                    # to avoid shadowing the builtin).
                    violation_mask = (neg_dist - pos_dist <
                                      args.margin).cpu().numpy().flatten()
                    if phase == 'train':
                        # Train only on the hard triplets.
                        hard_triplets = np.where(violation_mask == 1)
                        if len(hard_triplets[0]) == 0:
                            continue
                    else:
                        # Validation keeps every triplet.
                        hard_triplets = np.where(violation_mask >= 0)

                    anc_hard_embed = anc_embed[hard_triplets].to(device)
                    pos_hard_embed = pos_embed[hard_triplets].to(device)
                    neg_hard_embed = neg_embed[hard_triplets].to(device)

                    anc_hard_img = anc_img[hard_triplets].to(device)
                    pos_hard_img = pos_img[hard_triplets].to(device)
                    neg_hard_img = neg_img[hard_triplets].to(device)

                    # Classifier forward passes; outputs were never used but
                    # the calls are kept for parity with the original flow.
                    model.forward_classifier(anc_hard_img)
                    model.forward_classifier(pos_hard_img)
                    model.forward_classifier(neg_hard_img)

                    triplet_loss = TripletLoss(args.margin).forward(
                        anc_hard_embed, pos_hard_embed,
                        neg_hard_embed).to(device)

                    if phase == 'train':
                        optimizer.zero_grad()
                        triplet_loss.backward()
                        optimizer.step()

                    # Collect distances for pair-accuracy evaluation; the
                    # distances computed above are reused instead of a second
                    # identical l2_dist pass. Positive pairs -> label 1,
                    # negative pairs -> label 0.
                    distances.append(pos_dist.data.cpu().numpy())
                    labels.append(np.ones(pos_dist.size(0)))
                    distances.append(neg_dist.data.cpu().numpy())
                    labels.append(np.zeros(neg_dist.size(0)))

                    triplet_loss_sum += triplet_loss.item()
            except Exception:
                # Log and keep going: one bad batch must not kill the epoch.
                # (Was a bare `except:`, which also swallowed KeyboardInterrupt;
                # the old handler could itself raise KeyError on the shapes.)
                print("traceback: ", traceback.format_exc())
                shapes = {
                    key: getattr(batch_sample.get(key), 'shape', None)
                    for key in ('anc_img', 'pos_img', 'neg_img')
                } if isinstance(batch_sample, dict) else None
                print("something went wrong with batch_idx: ", batch_idx,
                      ", batch_sample:", batch_sample, ", image sizes: ",
                      shapes)

        avg_triplet_loss = triplet_loss_sum / data_size[phase]
        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for dist in distances for subdist in dist])

        # evaluate() needs a minimum number of pairs to produce folds.
        nrof_pairs = min(len(labels), len(distances))
        if nrof_pairs >= 10:
            tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
            print('  {} set - Triplet Loss       = {:.8f}'.format(
                phase, avg_triplet_loss))
            print('  {} set - Accuracy           = {:.8f}'.format(
                phase, np.mean(accuracy)))
            duration = time.time() - start_time

            with open('{}/{}_log_epoch{}.txt'.format(log_dir, phase, epoch),
                      'w') as f:
                f.write(
                    str(epoch) + '\t' + str(np.mean(accuracy)) + '\t' +
                    str(avg_triplet_loss) + '\t' + str(duration))

            if phase == 'train':
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict()
                }, '{}/checkpoint_epoch{}.pth'.format(save_dir, epoch))
            else:
                plot_roc(fpr,
                         tpr,
                         figure_name='{}/roc_valid_epoch_{}.png'.format(
                             log_dir, epoch))
Exemplo n.º 2
0
def train_valid(model, optimizer, triploss, scheduler, epoch, dataloaders, data_size):
    """Run one epoch of triplet training and validation (DataParallel model).

    Mines semi-hard (margin-violating) triplets during training, evaluates
    pair accuracy from anchor-positive / anchor-negative distances, appends
    a row to log/{phase}.csv, and saves last/best checkpoints after the
    validation phase or a ROC plot after the training phase.

    Args:
        model: nn.DataParallel-wrapped network (forward_classifier is reached
            via model.module).
        optimizer: optimizer stepped during the train phase.
        triploss: triplet-loss module with a forward(anc, pos, neg) method.
        scheduler: step-LR scheduler stepped once per train phase.
        epoch: current epoch index used in logs and checkpoints.
        dataloaders: dict of phase -> DataLoader yielding triplet dicts.
        data_size: dict of phase -> dataset size (loss divisor).
    """
    for phase in ['train', 'valid']:

        labels, distances = [], []
        triplet_loss_sum = 0.0

        if phase == 'train':
            scheduler.step()
            # Announce LR decays on the scheduler's step boundary.
            if scheduler.last_epoch % scheduler.step_size == 0:
                print("LR decayed to:", ', '.join(map(str, scheduler.get_lr())))
            model.train()
        else:
            model.eval()

        for batch_idx, batch_sample in enumerate(dataloaders[phase]):

            anc_img = batch_sample['anc_img'].to(device)
            pos_img = batch_sample['pos_img'].to(device)
            neg_img = batch_sample['neg_img'].to(device)

            # Gradients are only tracked while training.
            with torch.set_grad_enabled(phase == 'train'):

                # Embeddings of anchor, positive and negative images.
                anc_embed, pos_embed, neg_embed = model(anc_img), model(pos_img), model(neg_img)

                pos_dist = l2_dist.forward(anc_embed, pos_embed)
                neg_dist = l2_dist.forward(anc_embed, neg_embed)

                # Margin-violation mask (renamed from `all` to avoid shadowing
                # the builtin); training uses only these semi-hard triplets.
                violation_mask = (neg_dist - pos_dist < args.margin).cpu().numpy().flatten()
                if phase == 'train':
                    hard_triplets = np.where(violation_mask == 1)
                    if len(hard_triplets[0]) == 0:
                        continue
                else:
                    # Validation keeps every triplet.
                    hard_triplets = np.where(violation_mask >= 0)

                anc_hard_embed = anc_embed[hard_triplets]
                pos_hard_embed = pos_embed[hard_triplets]
                neg_hard_embed = neg_embed[hard_triplets]

                anc_hard_img = anc_img[hard_triplets]
                pos_hard_img = pos_img[hard_triplets]
                neg_hard_img = neg_img[hard_triplets]

                # Classifier forward passes; results are discarded but the
                # calls are kept for parity with the original procedure.
                model.module.forward_classifier(anc_hard_img)
                model.module.forward_classifier(pos_hard_img)
                model.module.forward_classifier(neg_hard_img)

                triplet_loss = triploss.forward(anc_hard_embed, pos_hard_embed, neg_hard_embed)

                if phase == 'train':
                    optimizer.zero_grad()
                    triplet_loss.backward()
                    optimizer.step()

                # Positive pairs labeled 1, negative pairs labeled 0, for
                # threshold-based pair-accuracy evaluation below.
                distances.append(pos_dist.data.cpu().numpy())
                labels.append(np.ones(pos_dist.size(0)))

                distances.append(neg_dist.data.cpu().numpy())
                labels.append(np.zeros(neg_dist.size(0)))

                triplet_loss_sum += triplet_loss.item()

        avg_triplet_loss = triplet_loss_sum / data_size[phase]
        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array([subdist for dist in distances for subdist in dist])

        tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
        print('  {} set - Triplet Loss       = {:.8f}'.format(phase, avg_triplet_loss))
        print('  {} set - Accuracy           = {:.8f}'.format(phase, np.mean(accuracy)))

        time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        lr = '_'.join(map(str, scheduler.get_lr()))
        layers = '+'.join(args.unfreeze.split(','))
        write_csv(f'log/{phase}.csv', [time, epoch, np.mean(accuracy), avg_triplet_loss, layers, args.batch_size, lr])

        if phase == 'valid':
            # Always persist the latest state, and separately track the best
            # model by mean validation accuracy.
            save_last_checkpoint({'epoch': epoch,
                                  'state_dict': model.module.state_dict(),
                                  'optimizer_state': optimizer.state_dict(),
                                  'accuracy': np.mean(accuracy),
                                  'loss': avg_triplet_loss
                                  })
            save_if_best({'epoch': epoch,
                          'state_dict': model.module.state_dict(),
                          'optimizer_state': optimizer.state_dict(),
                          'accuracy': np.mean(accuracy),
                          'loss': avg_triplet_loss
                          }, np.mean(accuracy))
        else:
            plot_roc(fpr, tpr, figure_name='./log/roc_valid_epoch_{}.png'.format(epoch))
Exemplo n.º 3
0
import numpy as np
from IPython import embed
# Single import line: the original imported `evaluate` twice.
from eval_metrics import evaluate, plot_roc, plot_acc

# Smoke-test evaluate() and plot_roc() on random data.
# NOTE(review): targets are drawn from 1..9, not {0, 1}; evaluate() is
# normally fed binary same/different labels — confirm it tolerates this.
targets = np.random.randint(1, 10, 100)
distances = np.random.randn(100)
tpr, fpr, acc, val, val_std, far = evaluate(distances, targets)
embed()  # drop into an interactive IPython shell to inspect the results
plot_roc(fpr, tpr, figure_name='./test.png')
Exemplo n.º 4
0
def train_valid(model, optimizer, scheduler, epoch, dataloaders, data_size):
    """Run one epoch of triplet-loss training and validation.

    Mines margin-violating ("hard") triplets during training, evaluates
    pair accuracy on anchor-positive / anchor-negative distances, writes a
    per-epoch log file under ./log, and saves a checkpoint (train) or a
    ROC plot (valid).

    Args:
        model: network exposing forward() (embeddings) and forward_classifier().
        optimizer: optimizer stepped during the train phase.
        scheduler: LR scheduler stepped once at the start of the train phase.
        epoch: current epoch index used in log/checkpoint names.
        dataloaders: dict of phase -> DataLoader yielding triplet dicts.
        data_size: dict of phase -> dataset size (loss divisor).
    """
    for phase in ['train', 'valid']:

        labels, distances = [], []
        triplet_loss_sum = 0.0

        if phase == 'train':
            scheduler.step()
            model.train()
        else:
            model.eval()

        for batch_idx, batch_sample in enumerate(dataloaders[phase]):

            anc_img = batch_sample['anc_img'].to(device)
            pos_img = batch_sample['pos_img'].to(device)
            neg_img = batch_sample['neg_img'].to(device)

            # Gradients are only tracked while training.
            with torch.set_grad_enabled(phase == 'train'):

                # Embeddings of anchor, positive and negative images.
                anc_embed, pos_embed, neg_embed = model(anc_img), model(
                    pos_img), model(neg_img)

                pos_dist = l2_dist.forward(anc_embed, pos_embed)
                neg_dist = l2_dist.forward(anc_embed, neg_embed)

                # Margin-violation mask (renamed from `all` to avoid shadowing
                # the builtin); training uses only the hard triplets.
                violation_mask = (neg_dist - pos_dist <
                                  args.margin).cpu().numpy().flatten()
                if phase == 'train':
                    hard_triplets = np.where(violation_mask == 1)
                    if len(hard_triplets[0]) == 0:
                        continue
                else:
                    # Validation keeps every triplet.
                    hard_triplets = np.where(violation_mask >= 0)

                anc_hard_embed = anc_embed[hard_triplets].to(device)
                pos_hard_embed = pos_embed[hard_triplets].to(device)
                neg_hard_embed = neg_embed[hard_triplets].to(device)

                anc_hard_img = anc_img[hard_triplets].to(device)
                pos_hard_img = pos_img[hard_triplets].to(device)
                neg_hard_img = neg_img[hard_triplets].to(device)

                # Classifier forward passes; outputs were never used (the
                # class labels in the batch are likewise unused) but the
                # calls are kept for parity with the original flow.
                model.forward_classifier(anc_hard_img)
                model.forward_classifier(pos_hard_img)
                model.forward_classifier(neg_hard_img)

                triplet_loss = TripletLoss(args.margin).forward(
                    anc_hard_embed, pos_hard_embed, neg_hard_embed).to(device)

                if phase == 'train':
                    optimizer.zero_grad()
                    triplet_loss.backward()
                    optimizer.step()

                # Reuse the distances computed above instead of running a
                # second identical l2_dist pass. Positive pairs -> label 1,
                # negative pairs -> label 0.
                distances.append(pos_dist.data.cpu().numpy())
                labels.append(np.ones(pos_dist.size(0)))

                distances.append(neg_dist.data.cpu().numpy())
                labels.append(np.zeros(neg_dist.size(0)))

                triplet_loss_sum += triplet_loss.item()

        avg_triplet_loss = triplet_loss_sum / data_size[phase]
        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for dist in distances for subdist in dist])

        tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
        print('  {} set - Triplet Loss       = {:.8f}'.format(
            phase, avg_triplet_loss))
        print('  {} set - Accuracy           = {:.8f}'.format(
            phase, np.mean(accuracy)))

        with open('./log/{}_log_epoch{}.txt'.format(phase, epoch), 'w') as f:
            f.write(
                str(epoch) + '\t' + str(np.mean(accuracy)) + '\t' +
                str(avg_triplet_loss))

        if phase == 'train':
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict()
            }, './log/checkpoint_epoch{}.pth'.format(epoch))
        else:
            plot_roc(fpr,
                     tpr,
                     figure_name='./log/roc_valid_epoch_{}.png'.format(epoch))
Exemplo n.º 5
0
def train_valid(model, optimizer, scheduler, epoch, dataloaders, data_size):
    """Run one epoch of joint triplet + cross-entropy training/validation.

    For each phase this mines margin-violating triplets, combines a scaled
    triplet loss with a classification cross-entropy loss, prints running
    metrics every 10 batches, logs per-epoch scalars to TensorBoard, writes
    a per-epoch log file, and saves a checkpoint (train) or ROC plot (valid).
    Relies on module-level globals: device, cfg, l2_dist, TL_loss, CE_loss,
    TestData, writer, evaluate, plot_roc.
    """
    for phase in ['train', 'valid']:
        # One step for train or valid
        labels, distances = [], []
        triplet_loss_sum = 0.0
        crossentropy_loss_sum = 0.0
        accuracy_sum = 0.0
        # *_sigma accumulators are running sums reset every 10 batches,
        # used only for the periodic progress printout below.
        triplet_loss_sigma = 0.0
        crossentropy_loss_sigma = 0.0
        accuracy_sigma = 0.0

        if phase == 'train':
            scheduler.step()
            model.train()
        else:
            model.eval()

        for batch_idx, batch_sample in enumerate(dataloaders[phase]):
            anc_img = batch_sample['anc_img'].to(device)
            pos_img = batch_sample['pos_img'].to(device)
            neg_img = batch_sample['neg_img'].to(device)
            # Skip ragged (non-full) batches entirely.
            if (anc_img.shape[0] != cfg.batch_size
                    or pos_img.shape[0] != cfg.batch_size
                    or neg_img.shape[0] != cfg.batch_size):
                print("Batch Size Not Equal")
                continue

            pos_cls = batch_sample['pos_class'].to(device)
            neg_cls = batch_sample['neg_class'].to(device)

            # Gradients are only tracked while training.
            with torch.set_grad_enabled(phase == 'train'):
                try:
                    # anc_embed, pos_embed and neg_embed are encoding(embedding) of image
                    anc_embed, pos_embed, neg_embed = model(anc_img), model(
                        pos_img), model(neg_img)

                    # choose the hard negatives only for "training"
                    pos_dist = l2_dist.forward(anc_embed, pos_embed)
                    neg_dist = l2_dist.forward(anc_embed, neg_embed)

                    # Mask of margin-violating triplets (shadows the builtin
                    # `all`; kept byte-identical here).
                    all = (neg_dist - pos_dist <
                           cfg.margin).cpu().numpy().flatten()
                    if phase == 'train':
                        hard_triplets = np.where(all == 1)
                        if len(hard_triplets[0]) == 0:
                            continue
                    else:
                        # Validation keeps every triplet, but still skips
                        # batches that somehow yield none.
                        hard_triplets = np.where(all >= 0)
                        if len(hard_triplets[0]) == 0:
                            continue

                    anc_hard_embed = anc_embed[hard_triplets].to(device)
                    pos_hard_embed = pos_embed[hard_triplets].to(device)
                    neg_hard_embed = neg_embed[hard_triplets].to(device)

                    anc_hard_img = anc_img[hard_triplets].to(device)
                    pos_hard_img = pos_img[hard_triplets].to(device)
                    neg_hard_img = neg_img[hard_triplets].to(device)

                    pos_hard_cls = pos_cls[hard_triplets].to(device)
                    neg_hard_cls = neg_cls[hard_triplets].to(device)

                    # Classifier logits for each member of the triplet.
                    anc_img_pred = model.forward_classifier(anc_hard_img).to(
                        device)
                    pos_img_pred = model.forward_classifier(pos_hard_img).to(
                        device)
                    neg_img_pred = model.forward_classifier(neg_hard_img).to(
                        device)

                    # Weighted triplet loss plus classification loss; both
                    # are combined for the backward pass.
                    triplet_loss = TL_loss.forward(anc_hard_embed,
                                                   pos_hard_embed,
                                                   neg_hard_embed).to(device)
                    triplet_loss *= cfg.triplet_lambuda
                    # Anchor shares the positive's class label.
                    predicted_labels = torch.cat(
                        [anc_img_pred, pos_img_pred, neg_img_pred])
                    true_labels = torch.cat(
                        [pos_hard_cls, pos_hard_cls, neg_hard_cls]).squeeze()
                    crossentropy_loss = CE_loss(predicted_labels,
                                                true_labels).to(device)
                    loss = triplet_loss + crossentropy_loss

                    if phase == 'train':
                        optimizer.zero_grad()
                        # triplet_loss.backward()
                        loss.backward()
                        optimizer.step()
                    if phase == 'valid':
                        # Log classifier predictions on a fixed test set of
                        # images to TensorBoard, grouped by predicted person.
                        pic_array, _ = TestData.get_data()
                        for i, pic in enumerate(pic_array):
                            pred = model.forward_classifier(
                                pic.unsqueeze(0).to(device)).to(device)
                            pred = torch.argmax(pred, 1).cpu().numpy()
                            # print(pred)
                            writer.add_image("Person {}/{}".format(pred[0], i),
                                             pic, epoch)

                    # Classification accuracy over the 3 * n_hard predictions.
                    _, predicted = torch.max(predicted_labels, 1)
                    correct = (predicted == true_labels).cpu().squeeze().sum(
                    ).numpy() / (len(hard_triplets[0]) * 3)

                    # Positive pairs labeled 1, negative pairs labeled 0,
                    # for the pair-accuracy evaluation after the loop.
                    dists = l2_dist.forward(anc_embed, pos_embed)
                    distances.append(dists.data.cpu().numpy())
                    labels.append(np.ones(dists.size(0)))

                    dists = l2_dist.forward(anc_embed, neg_embed)
                    distances.append(dists.data.cpu().numpy())
                    labels.append(np.zeros(dists.size(0)))

                    triplet_loss_sum += triplet_loss.item()
                    crossentropy_loss_sum += crossentropy_loss.item()
                    accuracy_sum += correct

                    triplet_loss_sigma += triplet_loss.item()
                    crossentropy_loss_sigma += crossentropy_loss.item()
                    accuracy_sigma += correct
                    # Periodic progress report, then reset the 10-batch sums.
                    if batch_idx % 10 == 0 and batch_idx != 0:
                        print(
                            '{} Inter {:4d}/{:4d} - Triplet Loss = {:.5f} - CrossEntropy Loss = {:.5f} - All Loss = {:.5f} - Accuaracy = {:.5f} len:{}'
                            .format(phase, batch_idx, len(dataloaders[phase]),
                                    triplet_loss_sigma / 10,
                                    crossentropy_loss_sigma / 10,
                                    (triplet_loss_sigma +
                                     crossentropy_loss_sigma) / 10,
                                    accuracy_sigma / 10,
                                    len(hard_triplets[0])))
                        triplet_loss_sigma = 0
                        crossentropy_loss_sigma = 0
                        accuracy_sigma = 0
                except Exception as e:
                    # Best-effort: print the error and move on to the next
                    # batch rather than aborting the epoch.
                    print(e)
                    pass
        # Averages are per full batch (batches, not samples, are the divisor).
        avg_triplet_loss = triplet_loss_sum / int(
            data_size[phase] / cfg.batch_size)
        avg_crossentropy_loss = crossentropy_loss_sum / int(
            data_size[phase] / cfg.batch_size)
        labels = np.array([sublabel for label in labels for sublabel in label])
        distances = np.array(
            [subdist for dist in distances for subdist in dist])

        tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
        print('  {} set - Triplet Loss       = {:.8f}'.format(
            phase, avg_triplet_loss))
        print('  {} set - CrossEntropy Loss  = {:.8f}'.format(
            phase, avg_crossentropy_loss))
        print('  {} set - All Loss           = {:.8f}'.format(
            phase, avg_triplet_loss + avg_crossentropy_loss))
        print('  {} set - Accuracy           = {:.8f}'.format(
            phase, np.mean(accuracy)))

        # Log per-epoch training losses to TensorBoard.
        # (NOTE(review): the '...Group'.format(phase) calls are no-ops — the
        # tag strings contain no placeholders.)
        writer.add_scalars('Loss/Triplet Loss Group'.format(phase),
                           {'{} triplet loss'.format(phase): avg_triplet_loss},
                           epoch)
        writer.add_scalars(
            'Loss/Crossentropy Loss Group'.format(phase),
            {'{} crossentropy loss'.format(phase): avg_crossentropy_loss},
            epoch)
        writer.add_scalars('Loss/All Loss Group'.format(phase), {
            '{} loss'.format(phase):
            avg_triplet_loss + avg_crossentropy_loss
        }, epoch)
        writer.add_scalars('Accuracy_group'.format(phase),
                           {'{} accuracy'.format(phase): np.mean(accuracy)},
                           epoch)
        # Log the current learning rate.
        writer.add_scalar('learning rate', scheduler.get_lr()[0], epoch)

        with open('./log/{}_log_epoch{}.txt'.format(phase, epoch), 'w') as f:
            f.write(
                str(epoch) + '\t' + str(np.mean(accuracy)) + '\t' +
                str(avg_triplet_loss) + '\t' + str(avg_crossentropy_loss) +
                '\t' + str(avg_triplet_loss + avg_crossentropy_loss))

        if phase == 'train':
            torch.save({
                'epoch': epoch,
                'state_dict': model.state_dict()
            }, './log/checkpoint_epoch{}.pth'.format(epoch))
        else:
            plot_roc(fpr,
                     tpr,
                     figure_name='./log/roc_valid_epoch_{}.png'.format(epoch))
Exemplo n.º 6
0
def train_model():
    """Train a classifier with center + cross-entropy loss and validate on pairs.

    Runs args.epochs epochs. The 'train' phase optimizes
    center_loss * args.center_loss_weight + cross_entropy; the 'val' phase
    computes pairwise L2 distances between embeddings and evaluates
    verification accuracy, tracking the best weights seen. Finally plots
    the loss/accuracy curves and saves a checkpoint with the centers.

    Relies on module-level globals: model, optimizer, criterion, args,
    device, data_loaders, dataset_sizes, l2_dist, log_file_path, evaluate,
    plot_roc, plot_acc, AverageMeter.
    """
    step = 0
    log_file = open(log_file_path, 'w')
    iteration = 0
    lr = args.lr  # only used for logging; never updated (see commented code)
    top1 = AverageMeter()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss = []
    val_acc = []
    for epoch in range(args.epochs):
        # for phase in ['train']:
        for phase in ['train', 'val']:
            load_t0 = time.time()
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            # top1=AverageMeter()
            if phase == 'train':
                # Classification training: center loss + cross-entropy.
                for batch_dix, (inputs,
                                labels) in enumerate(data_loaders[phase]):
                    #---(batch,3,96,96)
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    optimizer.zero_grad()

                    with torch.set_grad_enabled(phase == 'train'):
                        prediction = model.forward_classifier(inputs)
                        _, preds = torch.max(F.softmax(prediction), 1)
                        # Center loss also updates the class centers in place.
                        center_loss, model.centers = model.get_center_loss(
                            labels, args.alpha)
                        cross_entropy_loss = criterion(prediction, labels)
                        loss = args.center_loss_weight * center_loss + cross_entropy_loss

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                            iteration += 1
                            # top1.update
                            # LR decay milestones; the actual adjustment call
                            # is commented out, so only `step` advances.
                            if epoch in [10, 18]:
                                step += 1
                                # lr=adjust_learning_rate(optimizer, args.gamma,step,args.lr
                            # prec=accuracy(F.softmax(prediction), labels, topk=(1,))
                            # top1.update(prec[0], inputs.size(0))
                    load_t1 = time.time()
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                    # Append a log line every 10 iterations.
                    if iteration % 10 == 0 and phase == 'train':
                        log_file.write('Epoch:' + repr(epoch) +
                                       ' || Totel iter ' + repr(iteration) +
                                       ' || L: {:.4f} A: {:.4f} || '.format(
                                           loss.item() * inputs.size(0),
                                           torch.sum(preds == labels.data)) +
                                       'Batch time: {:.4f}  || '.format(
                                           load_t1 - load_t0) +
                                       'LR: {:.8f}'.format(lr) + '\n')
                    train_loss.append(running_loss / dataset_sizes[phase])
                epoch_loss = running_loss / dataset_sizes[phase]

                epoch_acc = running_corrects.double() / dataset_sizes[phase]
                print('epoch {} : {} Loss: {:.4f} Acc: {:.4f}'.format(
                    epoch, phase, epoch_loss, epoch_acc))

            else:
                # Verification validation: distance between embedding pairs,
                # evaluated against same/different labels.
                targets, distances = [], []
                for batch_dix, (data_a, data_p,
                                labels) in enumerate(data_loaders[phase]):
                    data_a, data_p = data_a.to(device), data_p.to(device)
                    labels = labels.to(device)
                    out_a, out_p = model(data_a), model(data_p)
                    dists = l2_dist.forward(out_a, out_p)
                    distances.append(dists.data.cpu().numpy())
                    targets.append(labels.data.cpu().numpy())
                    # _, _, plot_acc, _,_,_ = evaluate(dists.data.cpu().numpy().tolist(),labels.data.cpu().tolist())

                # Flatten per-batch arrays into single 1-D arrays.
                targets = np.array(
                    [sublabel for label in targets for sublabel in label])
                distances = np.array(
                    [subdist for dist in distances for subdist in dist])
                tpr, fpr, acc, val, val_std, far = evaluate(distances, targets)
                acc_max = acc.max()
                val_acc.append(acc_max.max())
                # embed()
                print('epoch {} : {}  Acc: {:.4f} '.format(
                    epoch, phase, acc_max))
                log_file.write('epoch {} : {}  Acc: {:.4f} '.format(
                    epoch, phase, acc_max))
                plot_roc(
                    fpr,
                    tpr,
                    figure_name='./log/roc_valid_epoch_{}.png'.format(epoch))

                # Track the best-performing weights across epochs.
                if acc_max > best_acc:
                    best_acc = acc_max
                    best_model_wts = copy.deepcopy(model.state_dict())

    plot_acc(train_loss, val_acc, './loss.png')
    # NOTE(review): the final save uses the *current* weights, not
    # best_model_wts — confirm whether the best weights should be saved.
    torch.save(
        {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'centers': model.centers,
        }, './centers_{}.pth'.format(best_acc))