def train(args):
    adapt = args.adapt

    train_transforms = torchvision.transforms.Compose([
        transforms.Resize((255, 255)),
        transforms.ToTensor(),
    ])
    cell_dataset = Cell_Dataset(transform=train_transforms)

    train_loader = DataLoader(cell_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)

    net = EmbeddingNet()
    if adapt:
        print("Adapt the adapt module for new dataset")
        for name, param in net.named_parameters():
            if 'adapt' not in name:
                param.requires_grad = False
    else:
        print("Train with ILSV2015")

    # Optimizer
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=args.lr_steps)

    criterion = nn.MSELoss()
    net.cuda()
    for epoch in range(args.epochs):
        # Each epoch has a training and validation phase
        net.train()
        log_loss = []

        for i_batch, (input_im, patch_im,
                      input_label) in enumerate(train_loader):
            input_im, patch_im, input_label = (input_im.cuda(),
                                               patch_im.cuda(),
                                               input_label.cuda())

            output_heatmap = net(input_im, patch_im)
            #            print(output_heatmap.shape, input_label.shape)
            loss = criterion(output_heatmap, input_label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            log_loss += [loss.item()]

            if i_batch % 10 == 0:
                log = 'Epoch: %3d, Batch: %5d, ' % (epoch + 1, i_batch)
                log += 'Total Loss: %6.3f, ' % (np.mean(log_loss))
                print(log, datetime.datetime.now())

        scheduler.step()
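A note on the adapt-only path above: since everything except the adapt layers is frozen, the optimizer can also be given only the parameters that still require gradients. A minimal sketch (assuming the same EmbeddingNet, with the adapt layers identified by 'adapt' in their parameter names):

import torch.optim as optim

def build_adapt_optimizer(net, lr, weight_decay):
    # Freeze every parameter that does not belong to the adapt module.
    for name, param in net.named_parameters():
        param.requires_grad = 'adapt' in name

    # Pass only the still-trainable parameters so momentum and weight decay
    # never touch the frozen weights.
    trainable = [p for p in net.parameters() if p.requires_grad]
    return optim.SGD(trainable, lr=lr, momentum=0.9, weight_decay=weight_decay)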
Example #2
def main(args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    p = args.labels_per_batch
    k = args.samples_per_label
    batch_size = p * k

    model = EmbeddingNet()
    if args.resume:
        model.load_state_dict(torch.load(args.resume))

    model.to(device)

    criterion = TripletMarginLoss(margin=args.margin)
    optimizer = Adam(model.parameters(), lr=args.lr)

    transform = transforms.Compose([
        transforms.Lambda(lambda image: image.convert("RGB")),
        transforms.Resize((224, 224)),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ])

    # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can
    # be replaced with any classification dataset.
    train_dataset = FashionMNIST(args.dataset_dir,
                                 train=True,
                                 transform=transform,
                                 download=True)
    test_dataset = FashionMNIST(args.dataset_dir,
                                train=False,
                                transform=transform,
                                download=True)

    # targets is a list where the i_th element corresponds to the label of i_th dataset element.
    # This is required for PKSampler to randomly sample from exactly p classes. You will need to
    # construct targets while building your dataset. Some datasets (such as ImageFolder) have a
    # targets attribute with the same format.
    targets = train_dataset.targets.tolist()

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              sampler=PKSampler(targets, p, k),
                              num_workers=args.workers)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    for epoch in range(1, args.epochs + 1):
        print("Training...")
        train_epoch(model, optimizer, criterion, train_loader, device, epoch,
                    args.print_freq)

        print("Evaluating...")
        evaluate(model, test_loader, device)

        print("Saving...")
        save(model, epoch, args.save_dir, "ckpt.pth")
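The targets list described in the comments above is the only extra structure PKSampler needs. For a dataset that does not expose a targets attribute, one way to build it is sketched below (assuming the dataset yields (image, label) pairs):

def build_targets(dataset):
    # i-th entry is the integer label of the i-th dataset element, matching
    # the format of FashionMNIST.targets or ImageFolder.targets.
    return [int(label) for _, label in dataset]

# e.g. sampler = PKSampler(build_targets(train_dataset), p, k)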
Example #3
  def __init__(self, checkpoint_path):

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    self.model = EmbeddingNet()
    self.model.load_state_dict(torch.load(checkpoint_path))

    self.model.to(self.device)
    self.model.eval()    
Example #4
def main(args):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        print("##########Running in CUDA GPU##########")
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        print("##########Running in CPU##########")
        kwargs = {}
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    train_data = myDataset("./data", [1, 2, 3, 4, 5, 6], args.nframes, "train",
                           None)
    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   **kwargs)

    # dev_data = myDataset("./data", None, args.nframes, "dev", "enrol")
    # dev_loader = Data.DataLoader(dataset=dev_data, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    test_data = myDataset("./data", None, args.nframes, "test", "enrol")
    test_loader = Data.DataLoader(dataset=test_data,
                                  batch_size=args.test_batch_size,
                                  shuffle=False,
                                  **kwargs)

    embeddingNet = EmbeddingNet(embedding_size=64,
                                num_classes=train_data.nspeakers)
    # embeddingNet = embeddingNet.double()
    if args.cuda:
        print("##########model is in cuda mode##########")
        gpu_ids = [0, 1, 2, 3, 4, 5, 6, 7]
        embeddingNet.cuda()
        # embeddingNet = nn.DataParallel(embeddingNet, device_ids=[0])

    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.SGD(embeddingNet.parameters(), lr=args.lr, momentum=args.momentum)
    optimizer = optim.Adam(embeddingNet.parameters(),
                           lr=args.lr,
                           weight_decay=0.001)
    scheduler = lr_scheduler.StepLR(optimizer, 5)
    start_epoch = 0
    if args.resume:
        embeddingNet, optimizer, start_epoch, loss = load_model(
            args.load_epoch, args.load_step, args.load_loss, embeddingNet,
            optimizer, "./weights/")

    # for epoch in range(start_epoch, args.epochs):
    # scheduler.step()
    # train(train_loader, embeddingNet, optimizer, criterion, epoch, args)
    # dev(dev_loader, embeddingNet, epoch)

    scores = test(test_loader, embeddingNet)
    write_scores(scores, "scores.npy")
Example #5
def create_embedder(embedding_model=''):
    embedder = EmbeddingNet()
    if embedding_model != '':
        embedder.load_state_dict(torch.load(embedding_model))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        embedder = torch.nn.DataParallel(embedder)

    embedder.cuda()
    return embedder
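Because create_embedder may wrap the model in DataParallel, a state_dict saved from the wrapped model carries a 'module.' prefix. A small helper (a sketch, not part of the original code) that loads either variant into a bare EmbeddingNet:

import torch

def load_embedding_weights(embedder, checkpoint_path):
    # Strip the 'module.' prefix added by torch.nn.DataParallel so the same
    # checkpoint loads into a wrapped or an unwrapped model.
    state = torch.load(checkpoint_path, map_location='cpu')
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    embedder.load_state_dict(state)
    return embedder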
Example #6
def main(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    p = args.labels_per_batch
    k = args.samples_per_label
    batch_size = p * k

    model = EmbeddingNet(backbone=None, pretrained=args.pretrained)
    print('pretrained: ', args.pretrained)
    if args.resume:
        model.load_state_dict(torch.load(args.resume))

    model.to(device)

    criterion = TripletMarginLoss(margin=args.margin)
    optimizer = Adam(model.parameters(), lr=args.lr)

    train_transform = transforms.Compose([
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.RandomHorizontalFlip(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.Lambda(lambda image: image.convert('RGB')),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Using an ImageFolder-style dataset to demonstrate embedding learning using
    # triplet loss. This dataset can be replaced with any classification dataset.
    train_dataset = ImageFolder(os.path.join(args.dataset_dir, 'train'), transform=train_transform)
    test_dataset = ImageFolder(os.path.join(args.dataset_dir, 'test'), transform=test_transform)

    # targets is a list where the i_th element corresponds to the label of i_th dataset element.
    # This is required for PKSampler to randomly sample from exactly p classes. You will need to
    # construct targets while building your dataset. Some datasets (such as ImageFolder) have a
    # targets attribute with the same format.
    targets = train_dataset.targets

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              sampler=PKSampler(targets, p, k),
                              num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.workers)

    for epoch in range(1, args.epochs + 1):
        print('Training...')
        train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq)

        print('Evaluating...')
        evaluate(model, test_loader, device)

        print('Saving...')
        save(model, epoch, args.save_dir, 'ckpt.pth')
Example #7
  def __init__(self, checkpoint_path):

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    self.model = EmbeddingNet()
    self.model.load_state_dict(torch.load(checkpoint_path))

    self.model.to(self.device)
    self.model.eval()
    
    self.transform = transforms.Compose([transforms.Lambda(lambda image: image.convert('RGB')),
                                transforms.Resize((224, 224)),
                                transforms.ToTensor()])
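For inference, a wrapper like the one above typically adds a method that applies self.transform and runs the network without gradients. A hypothetical sketch of such a method (the embed name is an assumption, not from the original code):

  def embed(self, pil_image):
    # Preprocess a single PIL image and return its embedding as a CPU tensor.
    x = self.transform(pil_image).unsqueeze(0).to(self.device)
    with torch.no_grad():
      return self.model(x).squeeze(0).cpu()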
Example #8
def create_classifier(embedding_model='', model=''):
    embedder = EmbeddingNet()
    if embedding_model != '':
        embedder.load_state_dict(torch.load(embedding_model))
    classifier = FullNet(embedder)

    if model != '':
        classifier.load_state_dict(torch.load(model))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        classifier = torch.nn.DataParallel(classifier)

    classifier.cuda()
    return classifier
Example #9
def create_lcc(embedding_model='', model=''):
    embedder = EmbeddingNet()
    if embedding_model != '':
        embedder.load_state_dict(torch.load(embedding_model))

    lcc = LCCNet(embedder)

    if model != '':
        lcc.load_state_dict(torch.load(model))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        lcc = torch.nn.DataParallel(lcc)

    lcc.cuda()
    return lcc
Example #10
def create_distance_model(embedding_model='', model=''):
    embedder = EmbeddingNet()
    if embedding_model != '':
        embedder.load_state_dict(torch.load(embedding_model))

    distanceModel = DistanceNet(embedder)

    if model != '':
        distanceModel.load_state_dict(torch.load(model))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        distanceModel = torch.nn.DataParallel(distanceModel)

    distanceModel.cuda()
    return distanceModel
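Examples #8, #9 and #10 differ only in the wrapper class around the shared EmbeddingNet. A sketch of one generic factory covering FullNet, LCCNet and DistanceNet (same assumptions about those classes as in the examples above):

import torch

def create_wrapped_model(wrapper_cls, embedding_model='', model=''):
    # Build the shared embedder, optionally restore its weights, wrap it,
    # optionally restore the wrapper, then move everything onto the GPU(s).
    embedder = EmbeddingNet()
    if embedding_model != '':
        embedder.load_state_dict(torch.load(embedding_model))

    wrapped = wrapper_cls(embedder)
    if model != '':
        wrapped.load_state_dict(torch.load(model))

    if torch.cuda.device_count() > 1:
        wrapped = torch.nn.DataParallel(wrapped)

    wrapped.cuda()
    return wrapped

# e.g. create_wrapped_model(DistanceNet, embedding_model, model)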
Example #11
def get_model(args):
    preserved = None
    model = EmbeddingNet(network=args.model_name, pretrained=args.pretrained, embedding_len=args.embedding_size)
    if args.increment_phase == 0:
        pass
    elif args.increment_phase == -1:
        model = ClassificationNet(model)
    elif args.increment_phase == 1:
        try:
            pkl = torch.load(args.train_set_old+'/pkl/state_best.pth')
            model = pkl['model']
            preserved = {'fts_means': pkl['fts_means'],
                'preserved_embedding': pkl['embeddings']}
        except OSError as reason:
            print(args.train_set_old+'.........')
            print(reason)
    else:
        print(args.increment_phase)
        raise NotImplementedError

    return model, preserved
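For the increment_phase == 1 branch, get_model expects state_best.pth to be a dict with 'model', 'fts_means' and 'embeddings' entries. A sketch of code that would write such a checkpoint (the function name and variables here are hypothetical):

import os
import torch

def save_increment_checkpoint(train_set_old, model, fts_means, embeddings):
    # Layout matches what get_model above reads back for the incremental phase.
    os.makedirs(os.path.join(train_set_old, 'pkl'), exist_ok=True)
    torch.save({'model': model,
                'fts_means': fts_means,
                'embeddings': embeddings},
               os.path.join(train_set_old, 'pkl', 'state_best.pth'))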
Example #12
def main():
    sys_time = str(datetime.datetime.now())
    if not os.path.exists(args.check_path):
        os.mkdir(args.check_path)
    log_path = args.check_path + os.path.sep + sys_time + 'CE.txt'
    log_file = open(log_path, 'w+')

    print('Loading model......')

    # model = torch.load(log_path)
    embedding_net = EmbeddingNet()
    model = ClassificationNet(embedding_net, 10)
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    train_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=np.array([0.485, 0.456, 0.406]),
                             std=np.array([0.229, 0.224, 0.225])),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=np.array([0.485, 0.456, 0.406]),
                             std=np.array([0.229, 0.224, 0.225])),
    ])

    print('Loading data...')
    train_set = torchvision.datasets.ImageFolder(root=args.train_set,
                                                 transform=train_transform)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    test_set = torchvision.datasets.ImageFolder(root=args.test_set,
                                                transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.test_batch_size,
                                              shuffle=True)
    classweight = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 10.0, 10.0, 10.0]
    classweight = np.array(classweight)
    classweight = torch.from_numpy(classweight).float().cuda()
    criterion = nn.CrossEntropyLoss(weight=classweight)

    # ignored_params = list(map(id, model.fc.parameters()))
    # base_params = filter(lambda p: id(p) not in ignored_params,
    #                      model.parameters())
    #
    # optimizer = torch.optim.SGD([
    #     {'params': base_params},
    #     {'params': model.fc.parameters(), 'lr': 1e-2}
    # ], lr=1e-3, momentum=0.9)
    optimizer = optim.SGD([{
        'params': model.module.model.parameters()
    }, {
        'params': model.module.fc.parameters(),
        'lr': 1e-2
    }],
                          lr=1e-1,
                          momentum=0.9)

    writer = SummaryWriter()

    def train(epoch):
        model.train()
        total_loss = 0.0
        correct = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()

            loss.backward()
            optimizer.step()

            _, predicted = output.max(1)

            correct += predicted.eq(target).sum().item()

            writer.add_scalar('/Loss', loss.item(),
                              epoch * len(train_loader) + batch_idx)

            if batch_idx % args.log_interval == 0:
                context = 'Train Epoch: {} [{}/{} ({:.0f}%)], Average loss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    total_loss / (batch_idx + 1))
                print(context)
                log_file.write(context + '\r\n')

        context = 'Train set:  Accuracy: {}/{} ({:.3f}%)\n'.format(
            correct, len(train_loader.dataset),
            100. * float(correct) / len(train_loader.dataset))
        print(context)
        log_file.write(context + '\r\n')

    def test(epoch):
        model.eval()
        test_loss = 0
        correct = 0
        target_arr = []
        predict_arr = []

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):
                data, target = data.cuda(), target.cuda()
                data, target = Variable(data), Variable(target)

                output = model(data)

                test_loss += criterion(output, target)
                _, pred = output.data.max(1)
                batch_correct = pred.eq(target).sum().item()
                correct += batch_correct

                predict_arr.append(pred.cpu().numpy())
                target_arr.append(target.data.cpu().numpy())
                writer.add_scalar('/Acc',
                                  100 * float(batch_correct) / data.size(0),
                                  epoch * len(test_loader) + batch_idx)

            cm_path = './' + str(epoch) + '_confusematrix'
            cm = metrics.get_confuse_matrix(predict_arr, target_arr)
            np.save(cm_path, cm)
            test_loss /= len(test_loader)
            context = 'Test set: Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / float(len(test_loader.dataset)))
            print(context)
            log_file.write(context + '\r\n')

    def update_lr(optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    print('start training')
    for epoch in range(args.epochs):
        if epoch % 5 == 0:
            print("            model: {}".format(args.model_name))
            print("            num_triplet: {}".format(train_set.__len__()))
            print("            check_path: {}".format(args.check_path))
            print("            learing_rate: {}".format(args.lr))
            print("            batch_size: {}".format(args.batch_size))
            print("            is_pretrained: {}".format(args.is_pretrained))
            print("            optimizer: {}".format(optimizer))
            log_file.write("            model: {}".format(args.model_name) +
                           '\r\n')
            log_file.write(
                "            num_triplet: {}".format(train_set.__len__()) +
                '\r\n')
            log_file.write(
                "            check_path: {}".format(args.check_path) + '\r\n')
            log_file.write("            learing_rate: {}".format(args.lr) +
                           '\r\n')
            log_file.write(
                "            batch_size: {}".format(args.batch_size) + '\r\n')
            log_file.write("            optimizer: {}".format(optimizer) +
                           '\r\n')
        train(epoch)
        test(epoch)
        if epoch in (50, 100, 150):
            args.lr = args.lr / 5
            update_lr(optimizer, args.lr)
        log_file.write('\r\n')

    model_path = args.check_path + os.path.sep + sys_time + 'model.pkt'
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    log_file.close()
    torch.save(model, model_path)
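The manual learning-rate drops at epochs 50, 100 and 150 above can also be expressed with a scheduler. A rough sketch using torch.optim.lr_scheduler.MultiStepLR (gamma=0.2 corresponds to dividing by 5; the exact epoch at which each drop takes effect may differ by one from the manual version):

from torch.optim import lr_scheduler

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.2)
for epoch in range(args.epochs):
    train(epoch)
    test(epoch)
    scheduler.step()  # drops the learning rate by 5x at each milestone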
Example #13
import torch
from model import RelationNet, EmbeddingNet
from runner import Runner
import yaml

feature_encoder = EmbeddingNet()
relation_network = RelationNet()

feature_encoder_optim = torch.optim.SGD(feature_encoder.parameters(),
                                        lr=0.001,
                                        momentum=0.9)
relation_network_optim = torch.optim.SGD(relation_network.parameters(),
                                         lr=0.001,
                                         momentum=0.9)

feature_encoder_scheduler = torch.optim.lr_scheduler.StepLR(
    feature_encoder_optim, step_size=30, gamma=0.1)
relation_network_scheduler = torch.optim.lr_scheduler.StepLR(
    relation_network_optim, step_size=30, gamma=0.1)

loss = torch.nn.MSELoss()

with open("./default.yaml", 'r') as f:
    cfg = yaml.load(f, Loader=yaml.SafeLoader)

runner = Runner(feature_encoder, relation_network, feature_encoder_optim,
                relation_network_optim, feature_encoder_scheduler,
                relation_network_scheduler, loss, cfg)

runner.run()
Example #14
import tensorflow as tf
from dataset import TripletLossDataset
from train import train_step
from model import EmbeddingNet

num_steps = 1000
learning_rate = 0.001
display_step = 100

data = TripletLossDataset(num_clusters=4, num_examples=256, batch_size=16)
data.build_dataset()

model = EmbeddingNet()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
loss = None

for step in range(num_steps):
    for batch in data.dataset:
        x_batch = tf.cast(batch[0], dtype=tf.float32)
        y_batch = tf.cast(batch[1], dtype=tf.float32)
        loss = train_step(model, x_batch, y_batch, optimizer)
    if step % display_step == 0:
        print(f"Step: {step:04d}\tLoss: {loss.numpy():.4f}")
Example #15
def main():

    # 4. dataset
    mean, std = 0.1307, 0.3081

    transform = tfs.Compose([tfs.Normalize((mean, ), (std, ))])
    test_transform = tfs.Compose([tfs.ToTensor(),
                                  tfs.Normalize((mean,), (std,))])

    train_set = MNIST('./data/MNIST',
                      train=True,
                      download=True,
                      transform=None)

    train_set = SEMI_MNIST(train_set,
                           transform=transform,
                           num_samples=100)

    test_set = MNIST('./data/MNIST',
                     train=False,
                     download=True,
                     transform=test_transform)

    test_set = SEMI_MNIST(test_set,
                          transform=transform,
                          num_samples=100)

    # 5. data loader
    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=1,
                              num_workers=8,
                              pin_memory=True
                              )

    test_loader = DataLoader(dataset=test_set,
                             shuffle=False,
                             batch_size=1,
                             )

    # 6. model
    model = EmbeddingNet().cuda()
    model.load_state_dict(torch.load('./saves/state_dict.{}'.format(15)))

    # 7. criterion
    criterion = MetricCrossEntropy()

    data = []
    y = []
    is_known_ = []
    # for idx, (imgs, targets, samples, is_known) in enumerate(train_loader):
    #     model.train()
    #     batch_size = 1
    #     imgs = imgs.cuda()  # [N, 1, 28, 28]
    #     targets = targets.cuda()  # [N]
    #     samples = samples.cuda() # [N, 1, 32, 32]
    #     is_known = is_known.cuda()
    #
    #     output = model(imgs)
    #     y.append(targets.cpu().detach().numpy())
    #     is_known_.append(is_known.cpu().detach().numpy())
    #
    #     if idx % 100 == 0:
    #         print(idx)
    #         print(output.size())
    #
    #     data.append(output.cpu().detach().numpy())
    #
    # data_numpy = np.array(data)
    # y_numpy = np.array(y)
    # is_known_numpy = np.array(is_known_)
    #
    # np.save('data', data_numpy)
    # np.save('known', is_known_numpy)
    # np.save('y', y)

    data_numpy = np.load('data.npy')
    y_numpy = np.load('y.npy')
    is_known_numpy = np.load('known.npy')

    print(data_numpy.shape)
    print(y_numpy.shape)

    data_numpy = np.squeeze(data_numpy)
    y_numpy = np.squeeze(y_numpy)
    is_known_numpy = np.squeeze(is_known_numpy)

    print(data_numpy.shape)
    print(y_numpy.shape)

    from sklearn.manifold import TSNE
    import matplotlib.pyplot as plt

    colors = ['#476A2A', '#7851B8', '#BD3430', '#4A2D4E', '#875525',
              '#A83683', '#4E655E', '#853541', '#3A3120', '#535D8E']

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
              '#bcbd22', '#17becf', '#ada699']

    # Create and fit the t-SNE model (the cached result is loaded below)
    # tsne = TSNE(random_state=0)
    # digits_tsne = tsne.fit_transform(data_numpy)
    # np.save('tsne', digits_tsne)
    digits_tsne = np.load('tsne.npy')
    print('complete t-sne')

    # ------------------------------ 1 ------------------------------
    plt.figure(figsize=(10, 10))
    for i in range(11):
        inds = np.where(y_numpy == i)[0]
        known = is_known_numpy[inds]
        known_idx = np.where(known == 1)
        unknown_idx = np.where(known == 0)

        plt.scatter(digits_tsne[inds[unknown_idx], 0], digits_tsne[inds[unknown_idx], 1], alpha=0.5, color=colors[10])
        plt.scatter(digits_tsne[inds[known_idx], 0], digits_tsne[inds[known_idx], 1], alpha=0.5, color=colors[i])

        # plt.scatter(digits_tsne[inds, 0], digits_tsne[inds, 1], alpha=0.5, color=colors[i])

    plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'unknown'])
    plt.show()  # display the plot
Example #16
def main():
    # 1. argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--resume', type=int, default=0)
    opts = parser.parse_args()
    print(opts)

    # 2. device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 3. visdom
    vis = visdom.Visdom()

    # 4. dataset
    mean, std = 0.1307, 0.3081

    transform = tfs.Compose([tfs.Normalize((mean, ), (std, ))])
    test_transform = tfs.Compose(
        [tfs.ToTensor(), tfs.Normalize((mean, ), (std, ))])

    train_set = MNIST('./data/MNIST',
                      train=True,
                      download=True,
                      transform=None)

    train_set = SEMI_MNIST(train_set, transform=transform, num_samples=100)

    test_set = MNIST('./data/MNIST',
                     train=False,
                     download=True,
                     transform=test_transform)

    # 5. data loader
    train_loader = DataLoader(dataset=train_set,
                              shuffle=True,
                              batch_size=opts.batch_size,
                              num_workers=8,
                              pin_memory=True)

    test_loader = DataLoader(
        dataset=test_set,
        shuffle=False,
        batch_size=opts.batch_size,
    )

    # 6. model
    model = EmbeddingNet().to(device)

    # 7. criterion
    criterion = MetricCrossEntropy().to(device)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    # 9. scheduler
    scheduler = StepLR(optimizer=optimizer, step_size=50, gamma=1)
    # 10. resume
    if opts.resume:
        model.load_state_dict(
            torch.load('./saves/state_dict.{}'.format(opts.resume)))
        print("resume from {} epoch..".format(opts.resume - 1))
    else:
        print("no checkpoint to resume.. train from scratch.")

    # --
    for epoch in range(opts.resume, opts.epoch):

        # 11. train
        for idx, (imgs, targets, samples, is_known) in enumerate(train_loader):
            model.train()
            batch_size = opts.batch_size

            imgs = imgs.to(device)  # [N, 1, 28, 28]
            targets = targets.to(device)  # [N]
            samples = samples.to(device)  # [N, 1, 32, 32]
            is_known = is_known.to(device)

            samples = samples.view(batch_size * 10, 1, 28, 28)
            out_x = model(imgs)  # [N, 10]
            out_z = model(samples).view(batch_size, 10,
                                        out_x.size(-1))  # [N * 10 , 2]
            loss = criterion(out_x, targets, out_z, is_known, 10, 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for param_group in optimizer.param_groups:
                lr = param_group['lr']

            if idx % 100 == 0:
                print('Epoch : {}\t'
                      'step : [{}/{}]\t'
                      'loss : {}\t'
                      'lr   : {}\t'.format(epoch, idx, len(train_loader), loss,
                                           lr))

                vis.line(X=torch.ones(
                    (1, 1)) * idx + epoch * len(train_loader),
                         Y=torch.Tensor([loss]).unsqueeze(0),
                         update='append',
                         win='loss',
                         opts=dict(x_label='step',
                                   y_label='loss',
                                   title='loss',
                                   legend=['total_loss']))

        torch.save(model.state_dict(), './saves/state_dict.{}'.format(epoch))

        # 12. test
        correct = 0
        avg_loss = 0
        for idx, (img, target) in enumerate(test_loader):

            model.load_state_dict(
                torch.load('./saves/state_dict.{}'.format(epoch)))
            model.eval()
            img = img.to(device)  # [N, 1, 28, 28]
            target = target.to(device)  # [N]
            output = model(img)  # [N, 10]

            output = torch.softmax(output, -1)
            pred, idx_ = output.max(-1)
            print(idx_)
            correct += torch.eq(target, idx_).sum()
            #loss = criterion(output, target)
            #avg_loss += loss.item()

        print('Epoch {} test : '.format(epoch))
        accuracy = correct.item() / len(test_set)
        print("accuracy : {:.4f}%".format(accuracy * 100.))
        #avg_loss = avg_loss / len(test_loader)
        #print("avg_loss : {:.4f}".format(avg_loss))

        vis.line(X=torch.ones((1, 1)) * epoch,
                 Y=torch.Tensor([accuracy]).unsqueeze(0),
                 update='append',
                 win='test',
                 opts=dict(x_label='epoch',
                           y_label='test_',
                           title='test_loss',
                           legend=['accuracy']))
        scheduler.step()
Example #17
def test_triplet():
    parser = argparse.ArgumentParser(
        description='Face recognition using triplet loss.')
    parser.add_argument('--CVDs', type=str, default='1,2', metavar='CUDA_VISIBLE_DEVICES',
                        help='CUDA_VISIBLE_DEVICES')
    parser.add_argument('--server', type=int, default=82, metavar='T',
                        help='which server is being used')
    parser.add_argument('--train-set', type=str, default='/home/zili/memory/FaceRecognition-master/data/cifar100/train2', metavar='dir',
                        help='path of train set.')
    parser.add_argument('--test-set', type=str, default='/home/zili/memory/FaceRecognition-master/data/cifar100/test2', metavar='dir',
                        help='path of test set.')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--epochs', type=int, default=300, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--embedding-size', type=int, default=128, metavar='N',
                        help='embedding size of model (default: 256)')
    parser.add_argument('--num-classes', type=int, default=10, metavar='N',
                        help='classes number of dataset')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.8, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=2, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--model-name', type=str, default='resnet34', metavar='M',
                        help='model name (default: resnet50)')
    parser.add_argument('--dropout-p', type=float, default=0.2, metavar='D',
                        help='Dropout probability (default: 0.2)')
    parser.add_argument('--check-path', type=str, default='/home/zili/memory/FaceRecognition-master/checkpoints2', metavar='C',
                        help='Checkpoint path')
    parser.add_argument('--pretrained', type=bool, default=False,metavar='R',
                        help='whether model is pretrained.')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.CVDs
    if args.server == 31:
        args.train_set  = '/share/zili/code/triplet/data/mnist/train'
        args.test_set   = '/share/zili/code/triplet/data/test_class'
        args.check_path = '/share/zili/code/triplet/checkpoints2'
    if args.server == 16:
        args.train_set  = '/data0/zili/code/triplet/data/cifar100/train2'
        args.test_set   = '/data0/zili/code/triplet/data/cifar100/test2'
        args.check_path = '/data0/zili/code/triplet/checkpoints2'
    if args.server == 17:
        args.train_set  = '/data/jiaxin/zili/data/cifar100/train2'
        args.test_set   = '/data/jiaxin/zili/data/cifar100/test'
        args.check_path = '/data/jiaxin/zili/checkpoints2'
    if args.server == 15:
        args.train_set = '/home/zili/code/triplet/data/cifar100/train2'
        args.test_set = '/home/zili/code/triplet/data/cifar100/test2'
        args.train_set_csv = '/home/zili/code/triplet/data/cifar100/train.csv'
        args.check_path = '/home/zili/code/triplet/checkpoints'
    now_time = str(datetime.datetime.now())
    args.check_path = os.path.join(args.check_path, now_time)

    if not os.path.exists(args.check_path):
        os.mkdir(args.check_path)

    shutil.copy('crossentropyloss.py', args.check_path)
    # os.path.join(args.check_path)
    f = open(args.check_path + os.path.sep + now_time + 'CrossEntropy.txt', 'w+')

    print('Loading model...')
    # model = FaceModelForCls(model_name=args.model_name,
    #                         num_classes=args.num_classes,
    #                         pretrained=args.pretrained)
    # model = model_.ResNet34(False, num_classes=args.num_classes)
    embedding_net = EmbeddingNet()
    model = ClassificationNet(embedding_net, 10)
    f.write("     model: {}".format(model))
    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    transform = transforms.Compose([
        # transforms.Resize(32),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=np.array([0.4914, 0.4822, 0.4465]),
            std=np.array([0.2023, 0.1994, 0.2010])),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    print('Loading data...')
    train_set = torchvision.datasets.ImageFolder(root=args.train_set, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=20)

    test_set = torchvision.datasets.ImageFolder(root=args.test_set,transform=test_transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=1)

    weight = torch.FloatTensor([1.,1.,1.,1.,1.,1.,1,10.,10.,10.]).cuda()
    criterion = nn.CrossEntropyLoss(weight=weight)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
    # optimizer = optim.Adam(model.parameters(),lr=args.lr)

    writer = SummaryWriter()
    def train(epoch):
        model.train()
        total_loss, correct = 0.0, 0
        fea, l = torch.zeros(0), torch.zeros(0)
        for batch_idx, (data, target) in enumerate(train_loader):

            data, target = data.cuda(), target.cuda()
            data, target = Variable(data), Variable(target)

            optimizer.zero_grad()
            output = model.forward(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            _, predicted = output.max(1)
            fea = torch.cat((fea, output.data.cpu()))
            l = torch.cat((l, target.data.cpu().float()))

            correct += predicted.eq(target).sum().item()
            writer.add_scalar('/Loss', loss.item(), epoch * len(train_loader) + batch_idx)

            if (batch_idx+1) % args.log_interval == 0:
                context = 'Train Epoch: {} [{}/{} ({:.0f}%)], Average loss: {:.6f}'.format(
                          epoch, fea.size()[0], len(train_loader.dataset),
                          100.0 * batch_idx / len(train_loader), total_loss / (batch_idx+1))
                print(context)
                f.write(context + '\r\n')

        writer.add_embedding(mat=fea, metadata=l, global_step=epoch)
        context = 'Train Epoch: {} [{}/{} ({:.0f}%)], Average loss: {:.4f}'.format(
            epoch, len(train_loader.dataset), len(train_loader.dataset),
            100.0 * len(train_loader) / len(train_loader), total_loss / len(train_loader))
        print(context)
        f.write(context + '\r\n')

        context = 'Train set:  Accuracy: {}/{} ({:.3f}%)\n'.format(
             correct, len(train_loader.dataset),
            100. * float(correct) / len(train_loader.dataset))
        print(context)
        f.write(context+'\r\n')


    def test(epoch):
        model.eval()
        test_loss, correct = 0, 0
        target_arr, predict_arr = [], []
        data1 = [0] * args.num_classes
        data2 = [0] * args.num_classes
        data3 = [0] * args.num_classes
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):

                data, target = data.cuda(), target.cuda()
                data, target = Variable(data), Variable(target)

                output = model(data)
                test_loss += criterion(output, target)
                _, pred = output.data.max(1)

                for i in range(0, target.size(0)):
                    data1[target[i]] += 1
                    data3[pred[i]] += 1
                    if target[i] == pred[i]:
                        data2[target[i]] += 1

                batch_correct = pred.eq(target).sum().item()
                correct += batch_correct

                predict_arr.append(pred.cpu().numpy())
                target_arr.append(target.data.cpu().numpy())
                writer.add_scalar('/Acc', 100 * float(batch_correct) / data.size(0), epoch * len(test_loader) + batch_idx)

            cm_path = args.check_path + '/' + str(epoch) + '_confusematrix'
            cm = metrics.get_confuse_matrix(predict_arr, target_arr)
            np.save(cm_path, cm)

            for j in range(10):
                recall = 0 if data1[j] == 0 else data2[j] / data1[j]
                precision = 0 if data3[j] == 0 else data2[j] / data3[j]
                context = 'Class%1s: recall is %.2f%% (%d in %d), precision is %.2f%% (%d in %d)' % (
                    str(j), 100 * recall, data2[j], data1[j],
                    100 * precision, data2[j], data3[j])
                print(context)
                f.write(context + "\r\n")

            test_loss /= len(test_loader)
            context = 'Test set: Average loss: {:.6f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / float(len(test_loader.dataset)))
            print(context)
            f.write(context + '\r\n')



    def update_lr(optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    print('start training')
    for epoch in range(args.epochs):
        if epoch % 2 == 0:
            print(" server: {}".format(args.server))
            print(" train-size: {}".format(args.train_set))
            print(" embedding-size: {}".format(args.embedding_size))
            print(" model: {}".format(args.model_name))
            print(" dropout: {}".format(args.dropout_p))
            print(" num_train_set: {}".format(train_set.__len__()))
            print(" check_path: {}".format(args.check_path))
            print(" learing_rate: {}".format(args.lr))
            print(" batch_size: {}".format(args.batch_size))
            print(" pretrained: {}".format(args.pretrained))
            print(" optimizer: {}".format(optimizer))
            f.write(" server: {}".format(args.server) + '\r\n')
            f.write(" train-size: {}".format(args.train_set) + '\r\n')
            f.write(" embedding-size: {}".format(args.embedding_size) + '\r\n')
            f.write(" model: {}".format(args.model_name) + '\r\n')
            f.write(" dropout: {}".format(args.dropout_p) + '\r\n')
            f.write(" num_train_set: {}".format(train_set.__len__()) + '\r\n')
            f.write(" check_path: {}".format(args.check_path) + '\r\n')
            f.write(" learing_rate: {}".format(args.lr) + '\r\n')
            f.write(" batch_size: {}".format(args.batch_size) + '\r\n')
            f.write(" pretrained: {}".format(args.pretrained) + '\r\n')
            f.write(" optimizer: {}".format(optimizer) + '\r\n')
        train(epoch)
        test(epoch)
        if (epoch + 1) % 20 == 0:
            args.lr = args.lr / 3
            update_lr(optimizer, args.lr)
        if epoch > 30 and epoch < 80:
            torch.save(model, args.check_path+ os.path.sep + 'epoch' + str(epoch)+'.pth')
        f.write('\r\n')
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    f.close()
Example #18
def main():
    parser = argparse.ArgumentParser(
        description='Classifiar using triplet loss.')
    parser.add_argument('--CVDs',
                        type=str,
                        default='0',
                        metavar='CUDA_VISIBLE_DEVICES',
                        help='CUDA_VISIBLE_DEVICES')
    parser.add_argument('--server',
                        type=int,
                        default=82,
                        metavar='T',
                        help='which server is being used')
    parser.add_argument(
        '--train-set',
        type=str,
        default='/home/zili/memory/FaceRecognition-master/data/cifar100/train2',
        metavar='dir',
        help='path of train set.')
    parser.add_argument(
        '--test-set',
        type=str,
        default='/home/zili/memory/FaceRecognition-master/data/cifar100/test2',
        metavar='dir',
        help='path of test set.')
    parser.add_argument(
        '--train-set-csv',
        type=str,
        default=
        '/home/zili/memory/FaceRecognition-master/data/cifar100/train.csv',
        metavar='file',
        help='path of train set.csv.')
    parser.add_argument('--num-triplet',
                        type=int,
                        default=1000,
                        metavar='number',
                        help='number of triplet in dataset (default: 32)')
    parser.add_argument('--train-batch-size',
                        type=int,
                        default=96,
                        metavar='number',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=192,
                        metavar='number',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='number',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--embedding-size',
                        type=int,
                        default=128,
                        metavar='number',
                        help='embedding size of model (default: 256)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.1,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--margin',
                        type=float,
                        default=1.,
                        metavar='margin',
                        help='loss margin (default: 1.0)')
    parser.add_argument('--num-classes',
                        type=int,
                        default=10,
                        metavar='number',
                        help='classes number of dataset')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='number',
        help='how many batches to wait before logging training status')
    parser.add_argument('--model-name',
                        type=str,
                        default='resnet34',
                        metavar='M',
                        help='model name (default: resnet34)')
    parser.add_argument('--dropout-p',
                        type=float,
                        default=0.2,
                        metavar='D',
                        help='Dropout probability (default: 0.2)')
    parser.add_argument(
        '--check-path',
        type=str,
        default='/home/zili/memory/FaceRecognition-master/checkpoints',
        metavar='folder',
        help='Checkpoint path')
    parser.add_argument(
        '--is-semihard',
        type=bool,
        default=True,
        metavar='R',
        help='whether the dataset is selected in semi-hard way.')
    parser.add_argument('--is-pretrained',
                        type=bool,
                        default=False,
                        metavar='R',
                        help='whether model is pretrained.')

    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.CVDs
    if args.server == 31:
        args.train_set = '/share/zili/code/triplet/data/cifar100/train2'
        args.test_set = '/share/zili/code/triplet/data/cifar100/test2'
        args.train_set_csv = '/share/zili/code/triplet/data/cifar100/train.csv'
        args.check_path = '/share/zili/code/triplet/checkpoints'
    if args.server == 16:
        args.train_set = '/data0/zili/code/triplet/data/cifar100/train2'
        args.test_set = '/data0/zili/code/triplet/data/cifar100/test2'
        args.train_set_csv = '/data0/zili/code/triplet/data/cifar100/train.csv'
        args.check_path = '/data0/zili/code/triplet/checkpoints'
    if args.server == 17:
        args.train_set = '/data/jiaxin/zili/data/cifar100/train2'
        args.test_set = '/data/jiaxin/zili/data/cifar100/test'
        args.train_set_csv = '/data/jiaxin/zili/data/cifar100/train.csv'
        args.check_path = '/data/jiaxin/zili/checkpoints'
    if args.server == 15:
        args.train_set = '/home/zili/code/triplet/data/cifar100/train2'
        args.test_set = '/home/zili/code/triplet/data/cifar100/test2'
        args.train_set_csv = '/home/zili/code/triplet/data/cifar100/train.csv'
        args.check_path = '/home/zili/code/triplet/checkpoints'
    now_time = str(datetime.datetime.now())
    if not os.path.exists(args.check_path):
        os.mkdir(args.check_path)
    args.check_path = os.path.join(args.check_path, now_time)
    if not os.path.exists(args.check_path):
        os.mkdir(args.check_path)
    shutil.copy('tripletloss.py', args.check_path)

    output1 = 'main_' + now_time
    f = open(args.check_path + os.path.sep + output1 + '.txt', 'w+')
    writer = SummaryWriter()

    print('Loading model...')

    # model = FaceModel(model_name     = args.model_name,
    #                   embedding_size = args.embedding_size,
    #                   pretrained     = args.is_pretrained)
    # model = model_.ResNet34(True, args.embedding_size)
    model = EmbeddingNet(embedding_len=args.embedding_size)
    if torch.cuda.is_available():

        model = torch.nn.DataParallel(model).cuda()
        cudnn.benchmark = True
    f.write("     model: {}".format(model.module) + '\r\n')

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=1e-5)
    print('start training...')

    features, labels, clf, destination = select_three_sample(
        model, args, -1, writer)
    for epoch in range(args.epochs):
        file_operation(f, args, optimizer)

        if (epoch + 1) % 10 == 0:
            args.lr = args.lr / 3
            update_lr(optimizer, args.lr)

        train(epoch, model, optimizer, args, f, features, destination, writer)
        features, labels, clf, destination = select_three_sample(
            model, args, epoch, writer)
        validate(epoch, model, clf, args, f, writer, False)
        validate(epoch, model, clf, args, f, writer, True)

        f.write('\r\n')
        if epoch < 80 and epoch > 10:
            torch.save(
                model,
                args.check_path + os.path.sep + 'epoch' + str(epoch) + '.pth')
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Example #19
def main():
    args = get_args()
    logdir = 'log/{}-emb{}-{}layers-{}resblk-lr{}-wd{}-maxlen{}-alpha10-margin{}'\
             '{}class-{}sample-{}selector'\
             .format(args.name, 
                     args.embedding_size,
                     args.layers,
                     args.resblk,
                     args.lr, 
                     args.wd, 
                     args.maxlen,
                     args.margin,
                     args.n_classes,
                     args.n_samples,
                     args.selection)
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    resblock = []
    for i in range(args.layers):
        resblock.append(args.resblk)

    if args.train:
        logger = Logger(logdir)
        if not os.path.exists(args.trainfeature):
            os.mkdir(args.trainfeature)
            extractFeature(args.training_dataset, args.trainfeature)
        trainset = DeepSpkDataset(args.trainfeature, args.maxlen)
        pre_loader = DataLoader(trainset, batch_size = 128, shuffle = True, num_workers = 8)
        train_batch_sampler = BalancedBatchSampler(trainset.train_labels, 
                                                   n_classes = args.n_classes, 
                                                   n_samples = args.n_samples)
        kwargs = {'num_workers' : 1, 'pin_memory' : True}
        online_train_loader = torch.utils.data.DataLoader(trainset, 
                                                          batch_sampler=train_batch_sampler,
                                                          **kwargs) 
        margin = args.margin
        
        embedding_net = EmbeddingNet(resblock,  
                                     embedding_size = args.embedding_size,
                                     layers = args.layers)
        model = DeepSpeaker(embedding_net, trainset.get_num_class())
        device = torch.device('cuda:0')
        model.to(device)  # move the model to the GPU before creating the optimizer so the optimizer state is created on the GPU as well
        optimizer = optim.SGD(model.embedding_net.parameters(), 
                              lr = args.lr, 
                              momentum = 0.99,
                              weight_decay = args.wd)
        start_epoch = 0
        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.embedding_net.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                print('=> no checkpoint found at {}'.format(args.resume))

        pretrain_epoch = args.pretrain_epoch
        
        if args.selection == 'randomhard':
            selector = RandomNegativeTripletSelector(margin)
        if args.selection == 'hardest':
            selector = HardestNegativeTripletSelector(margin)
        if args.selection == 'semihard':
            selector = SemihardNegativeTripletSelector(margin)
        if args.selection == 'all':
            print('warning : select all triplet may take very long time')
            selector = AllTripletSelector()

        loss_fn = OnlineTripletLoss(margin, selector)   
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size = args.lr_adjust_step,
                                        gamma = args.lr_decay,
                                        last_epoch = -1) 
        n_epochs = args.n_epochs
        log_interval = 50
        fit(online_train_loader,
            pre_loader,
            model,
            loss_fn,
            optimizer,
            scheduler,
            pretrain_epoch,
            n_epochs,
            True,
            device,
            log_interval,
            log_dir = logdir,
            eval_path = args.evalfeature,
            logger = logger,
            metrics = [AverageNonzeroTripletsMetric()],
            evaluatee = args.eval,
            start_epoch = start_epoch)
    else:
        if not os.path.exists(args.testfeature):
            os.mkdir(args.testfeature)
            extractFeature(args.test_dataset, args.testfeature)
        model = EmbeddingNet(resblock,  
                             embedding_size = args.embedding_size,
                             layers = args.layers)
        model.cpu()
        if args.model:
            if os.path.isfile(args.model):
                print('=> loading checkpoint {}'.format(args.model))
                checkpoint = torch.load(args.model)
                model.load_state_dict(checkpoint['state_dict'])
            else:
                print('=> no checkpoint found at {}'.format(args.model))
        thres = np.loadtxt(logdir + '/thres.txt')
        acc = np.loadtxt(logdir + '/acc.txt')
        idx = np.argmax(acc)
        best_thres = thres[idx]
        predict(model, args.testfeature, best_thres)