Example #1
def main(args):
    # append the .csv extension to the filename if it is missing
    if not args.submission_name.endswith('.csv'):
        args.submission_name += '.csv'

    # load the model name and trained weights from the checkpoint
    check = torch.load(args.checkpoint)
    model_name = check['model']
    state_dict = check['state_dict']

    # enable cuda if available and desired
    args.use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.use_cuda else "cpu")

    # torch dataset and dataloader
    mnist_transform = transforms.Normalize((0.1307, ), (0.3081, ))
    dataset = KaggleMNIST(args.data_folder,
                          train=False,
                          transform=mnist_transform)
    loader = DataLoader(dataset, batch_size=len(dataset))

    # construct checkpoint's corresponding model type
    if model_name == 'linear':
        model = Softmax().to(device)
    elif model_name == 'neuralnet':
        model = TwoLayer().to(device)
    else:
        model = ConvNet().to(device)

    # load trained weights into model
    model.load_state_dict(state_dict)

    # make predictions with trained model on test data
    # loader yields a single batch containing the whole test set, so no loop is needed;
    # this won't work for large datasets, so loop over smaller batches in those cases
    # (see the batched sketch after this example)
    imgs, _ = next(iter(loader))
    imgs = imgs.to(device)
    logits = model(imgs)
    # torch.max with dim returns (max values, argmax indices);
    # here the indices are the predicted classes
    _, preds = torch.max(logits, dim=1)

    # construct numpy array with two columns: ids, digit predictions
    # we'll use that array to create our text file using the np.savetxt function
    ids = (np.arange(len(preds)) + 1).reshape(-1, 1)
    preds = preds.view(-1, 1).cpu().numpy()  # move to cpu before converting to numpy
    submission = np.concatenate((ids, preds), axis=1)

    # write the submission array to a text file with proper formatting
    np.savetxt(args.data_folder + args.submission_name,
               submission,
               fmt='%1.1i',
               delimiter=',',
               header='ImageId,Label',
               comments='')
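
The single-batch shortcut above only works when the whole test set fits in memory at once; for larger datasets the comments suggest looping over smaller batches. A minimal sketch of that batched variant, reusing the model, device, and dataset names from this example (the batch size here is an arbitrary assumption):

all_preds = []
loader = DataLoader(dataset, batch_size=256)  # assumed batch size
with torch.no_grad():  # inference only, no gradients needed
    for imgs, _ in loader:
        imgs = imgs.to(device)
        logits = model(imgs)
        all_preds.append(logits.argmax(dim=1).cpu())  # predicted class per image
preds = torch.cat(all_preds).view(-1, 1).numpy()
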
Example #2
def make_basic_cnn(nb_filters=64, nb_classes=10,
                   input_shape=(None, 28, 28, 1)):
  layers = [Conv2D(nb_filters, (8, 8), (2, 2), "SAME"),
            ReLU(),
            Conv2D(nb_filters * 2, (6, 6), (2, 2), "VALID"),
            ReLU(),
            Conv2D(nb_filters * 2, (5, 5), (1, 1), "VALID"),
            ReLU(),
            Flatten(),
            Linear(nb_classes),
            Softmax()]

  model = MLP(nb_classes, layers, input_shape)
  return model
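
As a quick shape check (not part of the example), the spatial sizes this stack produces on 28x28 MNIST inputs can be worked out with the usual "SAME"/"VALID" output-size rules; the helper below is a small standalone sketch:

def conv_out(size, kernel, stride, padding):
  # "SAME" gives ceil(size / stride); "VALID" gives floor((size - kernel) / stride) + 1
  if padding == "SAME":
    return -(-size // stride)
  return (size - kernel) // stride + 1

h = conv_out(28, 8, 2, "SAME")   # 14
h = conv_out(h, 6, 2, "VALID")   # 5
h = conv_out(h, 5, 1, "VALID")   # 1
print(h * h * 2 * 64)            # 128 features are flattened into the final Linear layer
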
Example #3
def make_madry_ngpu(nb_classes=10, input_shape=(None, 28, 28, 1), **kwargs):
  """
  Create a multi-GPU model similar to Madry et al. (arXiv:1706.06083).
  """
  layers = [Conv2DnGPU(32, (5, 5), (1, 1), "SAME"),
            ReLU(),
            MaxPool((2, 2), (2, 2), "SAME"),
            Conv2DnGPU(64, (5, 5), (1, 1), "SAME"),
            ReLU(),
            MaxPool((2, 2), (2, 2), "SAME"),
            Flatten(),
            LinearnGPU(1024),
            ReLU(),
            LinearnGPU(nb_classes),
            Softmax()]

  model = MLPnGPU(nb_classes, layers, input_shape)
  return model
Example #4
def main(args):
    # reproducibility
    # need to seed numpy/torch random number generators
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    # need directory with checkpoint files to recover previously trained models
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)
    checkpoint_file = args.checkpoint + args.model + str(datetime.now())[:-10]

    # decide which device to use; assumes at most one GPU is available
    args.use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.use_cuda else "cpu")

    # decide if we're using a validation set;
    # if not, don't evaluate at end of epochs
    evaluate = args.train_split < 1.

    # prep data loaders
    if args.train_split == 1:
        train_loader, _, test_loader = prepare_data(args)
    else:
        train_loader, val_loader, test_loader = prepare_data(args)

    # build model
    if args.model == 'linear':
        model = Softmax().to(device)
    elif args.model == 'neuralnet':
        model = TwoLayer().to(device)
    else:
        model = ConvNet().to(device)

    # build optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 amsgrad=args.amsgrad)

    # set up the validation metrics used to track the best model over the training run
    best_val_loss = float('inf')
    best_val_acc = 0

    # set up tensorboard logger
    logger = LoggerX('test_mnist', 'mnist_data', 25)

    # loop over epochs
    for epoch in range(args.epochs):
        print('\n================== TRAINING ==================')
        model.train()  # set model to training mode
        # set up training metrics we want to track
        correct = 0
        train_num = len(train_loader.sampler)

        # metrics from logger
        model_metrics = CalculateMetrics(batch_size=args.batch_size,
                                         batches_per_epoch=len(train_loader))

        # iterate over training batches
        for ix, (img, label) in enumerate(train_loader):
            # get data, send to gpu if needed
            img, label = img.to(device), label.to(device)

            # clear parameter gradients from previous training update
            optimizer.zero_grad()
            output = model(img)  # forward pass
            loss = F.cross_entropy(output, label)  # calculate network loss
            loss.backward()  # backward pass
            # take an optimization step to update the model's parameters
            optimizer.step()

            # get the index of the max logit
            pred = output.max(1, keepdim=True)[1]
            # correct += pred.eq(label.view_as(pred)).sum().item() # add to running total of hits

            # convert this data to binary for the sake of testing the metrics functionality
            label[label < 5] = 0
            label[label > 0] = 1

            pred[pred < 5] = 0
            pred[pred > 0] = 1
            ######

            scores_dict = model_metrics.update_scores(label, pred)

            if ix % args.log_interval == 0:
                # log metrics to tensorboardX and track the best model by current weighted average accuracy
                logger.log(model,
                           optimizer,
                           loss.item(),
                           track_score=scores_dict['weighted_acc'] /
                           model_metrics.bn,
                           scores_dict=scores_dict,
                           epoch=epoch,
                           bn=model_metrics.bn,
                           batches_per_epoch=model_metrics.batches_per_epoch)
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, model_metrics.bn, model_metrics.batches_per_epoch,
                    (model_metrics.bn / model_metrics.batches_per_epoch) * 100,
                    loss.item()))

        # print whole epoch's training accuracy; useful for monitoring overfitting
        print('Train Accuracy: ({:.0f}%)'.format(model_metrics.w_accuracy *
                                                 100))

        if evaluate:
            print('\n================== VALIDATION ==================')
            model.eval()  # set model to evaluate mode

            # set up validation metrics we want to track
            val_loss = 0.
            val_correct = 0
            val_num = len(val_loader.sampler)

            # disable autograd here (replaces volatile flag from v0.3.1 and earlier)
            with torch.no_grad():
                # loop over validation batches
                for img, label in val_loader:
                    # get data, send to gpu if needed
                    img, label = img.to(device), label.to(device)
                    output = model(img)  # forward pass

                    # sum up batch loss
                    val_loss += F.cross_entropy(output, label,
                                                reduction='sum').item()

                    # monitor for accuracy
                    pred = output.max(1, keepdim=True)[1]  # index of the max logit
                    # add to the running total of hits
                    val_correct += pred.eq(label.view_as(pred)).sum().item()

            # update current evaluation metrics
            val_loss /= val_num
            val_acc = 100. * val_correct / val_num
            print(
                '\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
                .format(val_loss, val_correct, val_num, val_acc))

            # check if best model according to accuracy;
            # if so, replace best metrics
            is_best = val_acc > best_val_acc
            if is_best:
                best_val_acc = val_acc
                best_val_loss = val_loss  # note this is val_loss of best model w.r.t. accuracy,
                # not the best val_loss throughout training

            # create checkpoint dictionary and save it;
            # if is_best, copy the file over to the file containing best model for this run
            state = {
                'epoch': epoch,
                'model': args.model,
                'state_dict': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'val_loss': val_loss,
                'best_val_loss': best_val_loss,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc
            }
            save_checkpoint(state, is_best, checkpoint_file)

    print('\n================== TESTING ==================')
    # load best model from training run (according to validation accuracy)
    check = torch.load(logger.best_path)
    model.load_state_dict(check['state_dict'])
    model.eval()  # set model to evaluate mode

    # set up evaluation metrics we want to track
    test_loss = 0.
    test_correct = 0
    test_num = len(test_loader.sampler)

    test_metrics = CalculateMetrics(batch_size=args.batch_size,
                                    batches_per_epoch=test_num)
    # disable autograd here (replaces volatile flag from v0.3.1 and earlier)
    with torch.no_grad():
        for img, label in test_loader:
            img, label = img.to(device), label.to(device)
            output = model(img)
            # sum up batch loss
            test_loss += F.cross_entropy(output, label,
                                         reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]  # index of the max logit
            test_scores = test_metrics.update_scores(label, pred)
            logger.log(model,
                       optimizer,
                       test_loss,
                       test_scores['weighted_acc'],
                       test_scores,
                       phase='test')

    test_loss /= test_num
    print('Test set: Average loss: {:.4f}, Accuracy: ({:.0f}%)\n'.format(
        test_loss, test_metrics.w_accuracy * 100))

    print('Final model stored at "{}".'.format(checkpoint_file +
                                               '-best.pth.tar'))
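
Example #4 calls a save_checkpoint helper that is not shown here. A minimal sketch of what such a helper typically does, with the same call signature as above (the exact filenames are an assumption based on the final print statement):

import shutil

import torch


def save_checkpoint(state, is_best, checkpoint_file):
    # persist the latest state, and copy it to a separate '-best' file
    # whenever this epoch produced the best validation accuracy so far
    path = checkpoint_file + '.pth.tar'
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, checkpoint_file + '-best.pth.tar')
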
Example #6
class face_learner(object):
    def __init__(self, conf, inference=False, embedding_size=512):
        conf.embedding_size = embedding_size
        print(conf)

        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).cuda()
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).cuda()
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        parameter_num_cal(self.model)

        self.milestones = conf.milestones
        self.loader, self.class_num = get_train_loader(conf)
        self.step = 0
        (self.agedb_30, self.cfp_fp, self.lfw, self.calfw, self.cplfw,
         self.vgg2_fp, self.agedb_30_issame, self.cfp_fp_issame,
         self.lfw_issame, self.calfw_issame, self.cplfw_issame,
         self.vgg2_fp_issame) = get_val_data(self.loader.dataset.root.parent)
        self.writer = SummaryWriter(conf.log_path)

        if not inference:
            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)

            self.writer = SummaryWriter(conf.log_path)
            self.step = 0

            if conf.multi_sphere:
                if conf.arcface_loss:
                    self.head = ArcfaceMultiSphere(
                        embedding_size=conf.embedding_size,
                        classnum=self.class_num,
                        num_shpere=conf.num_sphere,
                        m=conf.m).to(conf.device)
                elif conf.am_softmax:
                    self.head = MultiAm_softmax(
                        embedding_size=conf.embedding_size,
                        classnum=self.class_num,
                        num_sphere=conf.num_sphere,
                        m=conf.m).to(conf.device)
                else:
                    self.head = MultiSphereSoftmax(
                        embedding_size=conf.embedding_size,
                        classnum=self.class_num,
                        num_sphere=conf.num_sphere).to(conf.device)

            else:
                if conf.arcface_loss:
                    self.head = Arcface(embedding_size=conf.embedding_size,
                                        classnum=self.class_num).to(
                                            conf.device)
                elif conf.am_softmax:
                    self.head = Am_softmax(embedding_size=conf.embedding_size,
                                           classnum=self.class_num).to(
                                               conf.device)
                else:
                    self.head = Softmax(embedding_size=conf.embedding_size,
                                        classnum=self.class_num).to(
                                            conf.device)

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)

            if conf.use_mobilfacenet:
                if conf.multi_sphere:
                    self.optimizer = optim.SGD([{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + self.head.kernel_list,
                        'weight_decay':
                        4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                                               lr=conf.lr,
                                               momentum=conf.momentum)
                else:
                    self.optimizer = optim.SGD(
                        [{
                            'params': paras_wo_bn[:-1],
                            'weight_decay': 4e-5
                        }, {
                            'params': [paras_wo_bn[-1]] + [self.head.kernel],
                            'weight_decay': 4e-4
                        }, {
                            'params': paras_only_bn
                        }],
                        lr=conf.lr,
                        momentum=conf.momentum)
            else:
                if conf.multi_sphere:
                    self.optimizer = optim.SGD(
                        [{
                            'params': paras_wo_bn + self.head.kernel_list,
                            'weight_decay': 5e-4
                        }, {
                            'params': paras_only_bn
                        }],
                        lr=conf.lr,
                        momentum=conf.momentum)
                else:
                    self.optimizer = optim.SGD(
                        [{
                            'params': paras_wo_bn + [self.head.kernel],
                            'weight_decay': 5e-4
                        }, {
                            'params': paras_only_bn
                        }],
                        lr=conf.lr,
                        momentum=conf.momentum)

            print(self.optimizer)

            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, patience=40, verbose=True)

            print('optimizers generated')
            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 10
            self.save_every = len(self.loader) // 5
            (self.agedb_30, self.cfp_fp, self.lfw, self.calfw, self.cplfw,
             self.vgg2_fp, self.agedb_30_issame, self.cfp_fp_issame,
             self.lfw_issame, self.calfw_issame, self.cplfw_issame,
             self.vgg2_fp_issame) = get_val_data(self.loader.dataset.root.parent)
        else:
            self.threshold = conf.threshold

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        torch.save(
            self.model.state_dict(),
            save_path / ('model_{}_accuracy:{}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                save_path / ('head_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                save_path / ('optimizer_{}_accuracy:{}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def get_new_state(self, path):
        state_dict = torch.load(path)

        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if 'module' not in k:
                k = 'module.' + k
            else:
                k = k.replace('features.module.', 'module.features.')
            new_state_dict[k] = v
        return new_state_dict

    def load_state(self, save_path, fixed_str, model_only=False):
        self.model.load_state_dict(
            torch.load(save_path / 'model_{}'.format(fixed_str)))

        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            self.optimizer.load_state_dict(
                torch.load(save_path / 'optimizer_{}'.format(fixed_str)))
            print(self.optimizer)

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  angle_info):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)
        self.writer.add_scalar('{}_same_pair_angle_mean'.format(db_name),
                               angle_info['same_pair_angle_mean'], self.step)
        self.writer.add_scalar('{}_same_pair_angle_var'.format(db_name),
                               angle_info['same_pair_angle_var'], self.step)
        self.writer.add_scalar('{}_diff_pair_angle_mean'.format(db_name),
                               angle_info['diff_pair_angle_mean'], self.step)
        self.writer.add_scalar('{}_diff_pair_angle_var'.format(db_name),
                               angle_info['diff_pair_angle_var'], self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=False, n=1):
        self.model.eval()
        idx = 0
        i = 0
        # each batch fills the columns [lo:hi] of the embedding matrix
        lo = i * conf.embedding_size // n
        hi = (i + 1) * conf.embedding_size // n
        embeddings = np.zeros([len(carray), conf.embedding_size // n])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = (self.model(batch.to(conf.device)) +
                                 self.model(flipped.to(conf.device)))
                    embeddings[idx:idx + conf.batch_size] = (
                        l2_norm(emb_batch).cpu()[:, lo:hi])
                else:
                    embeddings[idx:idx + conf.batch_size] = (
                        self.model(batch.to(conf.device)).cpu()[:, lo:hi])
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    flipped = hflip_batch(batch)
                    emb_batch = (self.model(batch.to(conf.device)) +
                                 self.model(flipped.to(conf.device)))
                    embeddings[idx:] = l2_norm(emb_batch).cpu()[:, lo:hi]
                else:
                    embeddings[idx:] = (
                        self.model(batch.to(conf.device)).cpu()[:, lo:hi])
        tpr, fpr, accuracy, best_thresholds, angle_info = evaluate(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return (accuracy.mean(), best_thresholds.mean(), roc_curve_tensor,
                angle_info)

    def find_lr(self,
                conf,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                bloding_scale=3.,
                num=None):
        if not num:
            num = len(self.loader)
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        for params in self.optimizer.param_groups:
            params['lr'] = lr
        self.model.train()
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for i, (imgs, labels) in tqdm(enumerate(self.loader), total=num):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            batch_num += 1

            self.optimizer.zero_grad()

            embeddings = self.model(imgs)
            thetas = self.head(embeddings, labels)
            if conf.multi_sphere:
                loss = conf.ce_loss(thetas[0], labels)
                for theta in thetas[1:]:
                    loss = loss + conf.ce_loss(theta, labels)
            else:
                loss = conf.ce_loss(thetas, labels)

            # Compute the smoothed (bias-corrected) loss
            avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            self.writer.add_scalar('avg_loss', avg_loss, batch_num)
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            self.writer.add_scalar('smoothed_loss', smoothed_loss, batch_num)
            # Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > bloding_scale * best_loss:
                print('exited with best_loss at {}'.format(best_loss))
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
            # Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            # Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            self.writer.add_scalar('log_lr', math.log10(lr), batch_num)
            # Do the SGD step
            # Update the lr for the next step
            loss.backward()
            self.optimizer.step()

            lr *= mult
            for params in self.optimizer.param_groups:
                params['lr'] = lr
            if batch_num > num:
                plt.plot(log_lrs[10:-5], losses[10:-5])
                return log_lrs, losses
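
    # A hypothetical usage sketch for find_lr (the names `learner` and `conf`
    # are assumed): run the range test, plot smoothed loss against log10(lr),
    # and pick a learning rate a bit below where the loss starts to climb:
    #
    #     log_lrs, losses = learner.find_lr(conf)
    #     plt.plot(log_lrs[10:-5], losses[10:-5])
    #     plt.xlabel('log10(lr)')
    #     plt.ylabel('smoothed loss')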

    def model_evaluation(self, conf):
        self.model.load_state_dict(torch.load(conf.pretrained_model_path))
        accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
            conf, self.agedb_30, self.agedb_30_issame, tta=True)
        print('age_db_acc:', accuracy)

        accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
            conf, self.lfw, self.lfw_issame, tta=True)
        print('lfw_acc:', accuracy)

        accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
            conf, self.cfp_fp, self.cfp_fp_issame, tta=True)
        print('cfp_acc:', accuracy)

        accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
            conf, self.calfw, self.calfw_issame, tta=True)
        print('calfw_acc:', accuracy)

        accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
            conf, self.cplfw, self.cplfw_issame, tta=True)
        print('cplfw_acc:', accuracy)

        accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
            conf, self.vgg2_fp, self.vgg2_fp_issame, tta=True)
        print('vgg2_acc:', accuracy)

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.

        if conf.pretrain:
            self.model_evaluation(conf)
            sys.exit(0)

        logging.basicConfig(
            filename=conf.log_path / 'log.txt',
            level=logging.INFO,
            format="%(asctime)s %(name)s %(levelname)s %(message)s",
            datefmt='%Y-%m-%d  %H:%M:%S %a')
        logging.info(
            '\n******\nnum of sphere is: {},\nnet is: {},\ndepth is: {},\nlr is: {},\nbatch size is: {}\n******'
            .format(conf.num_sphere, conf.net_mode, conf.net_depth, conf.lr,
                    conf.batch_size))
        for e in range(epochs):
            print('epoch {} of {} started'.format(e, epochs))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()

            for imgs, labels in tqdm(iter(self.loader)):
                self.model.train()

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)

                if conf.multi_sphere:
                    loss = conf.ce_loss(thetas[0], labels)
                    for theta in thetas[1:]:
                        loss = loss + conf.ce_loss(theta, labels)
                else:
                    loss = conf.ce_loss(thetas, labels)

                running_loss += loss.item()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
                        conf, self.agedb_30, self.agedb_30_issame)
                    print('age_db_acc:', accuracy)
                    self.board_val('agedb_30', accuracy, best_threshold,
                                   roc_curve_tensor, angle_info)
                    logging.info('agedb_30 acc: {}'.format(accuracy))

                    accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    print('lfw_acc:', accuracy)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, angle_info)
                    logging.info('lfw acc: {}'.format(accuracy))

                    accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
                        conf, self.cfp_fp, self.cfp_fp_issame)
                    print('cfp_acc:', accuracy)
                    self.board_val('cfp', accuracy, best_threshold,
                                   roc_curve_tensor, angle_info)
                    logging.info('cfp acc: {}'.format(accuracy))

                    accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
                        conf, self.calfw, self.calfw_issame)
                    print('calfw_acc:', accuracy)
                    self.board_val('calfw', accuracy, best_threshold,
                                   roc_curve_tensor, angle_info)
                    logging.info('calfw acc: {}'.format(accuracy))

                    accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
                        conf, self.cplfw, self.cplfw_issame)
                    print('cplfw_acc:', accuracy)
                    self.board_val('cplfw', accuracy, best_threshold,
                                   roc_curve_tensor, angle_info)
                    logging.info('cplfw acc: {}'.format(accuracy))

                    accuracy, best_threshold, roc_curve_tensor, angle_info = self.evaluate(
                        conf, self.vgg2_fp, self.vgg2_fp_issame)
                    print('vgg2_acc:', accuracy)
                    self.board_val('vgg2', accuracy, best_threshold,
                                   roc_curve_tensor, angle_info)
                    logging.info('vgg2_fp acc: {}'.format(accuracy))

                    self.model.train()
                self.step += 1

    def schedule_lr(self):
        # self.optimizer_corr is never created in this class, so guard the
        # access to avoid an AttributeError when a milestone is reached
        if hasattr(self, 'optimizer_corr'):
            for params in self.optimizer_corr.param_groups:
                params['lr'] /= 10
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in the facebank
        tta : test-time augmentation (horizontal flip only)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = trans.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

        # broadcast to [num_faces, 512, num_targets], then reduce to squared L2 distances
        diff = (source_embs.unsqueeze(-1) -
                target_embs.transpose(1, 0).unsqueeze(0))
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum
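
For reference, the broadcasting that infer uses to compare faces against the facebank can be checked in isolation; this standalone sketch uses random tensors purely to show the shapes involved:

import torch

source_embs = torch.randn(3, 512)  # embeddings of the detected faces
target_embs = torch.randn(7, 512)  # facebank embeddings
diff = source_embs.unsqueeze(-1) - target_embs.transpose(1, 0).unsqueeze(0)
print(diff.shape)  # torch.Size([3, 512, 7])
dist = torch.sum(torch.pow(diff, 2), dim=1)  # squared L2 distances, shape [3, 7]
minimum, min_idx = torch.min(dist, dim=1)  # closest facebank entry for each face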