示例#1
0
def get_instance_segmentation_model(bone='resnet50', attention=False):
    """Build a two-class Mask R-CNN on top of the selected backbone.

    Pretrained backbone weights are requested only when ``attention`` is
    falsy; when it is truthy the backbone is created with
    ``pretrained=False``. ``inception_v3`` is the exception: it is always
    created with the factory defaults.

    Args:
        bone: Backbone name. One of ``'mobilenet_v2'``, ``'googlenet'``,
            ``'densenet121'``, ``'resnet50'``, ``'shufflenet_v2_x1_0'``,
            ``'inception_v3'`` or ``'squeezenet1_0'``.
        attention: Flag forwarded as ``att=`` to the mobilenet_v2,
            densenet121 and resnet50 factories; it also switches the
            pretrained weights off.

    Returns:
        A ``MaskRCNN`` instance built with ``num_classes=2``.

    Raises:
        ValueError: If ``bone`` is not one of the supported names.
    """
    # Pretrained weights and the attention variant are mutually exclusive.
    pretrained = not attention

    if bone == 'mobilenet_v2':
        backbone = models.mobilenet_v2(pretrained=pretrained,
                                       att=attention).features
        out_channels = 1280
    elif bone == 'googlenet':
        backbone = models.googlenet(pretrained=pretrained)
        out_channels = 1024
    elif bone == 'densenet121':
        backbone = models.densenet121(pretrained=pretrained,
                                      att=attention).features
        out_channels = 1024
    elif bone == 'resnet50':
        backbone = models.resnet50(pretrained=pretrained, att=attention)
        out_channels = 2048
    elif bone == 'shufflenet_v2_x1_0':
        backbone = models.shufflenet_v2_x1_0(pretrained=pretrained)
        out_channels = 1024
    elif bone == 'inception_v3':
        # Author's note kept from the original code:
        # 'InceptionOutputs' object has no attribute 'values'
        backbone = models.inception_v3()
        out_channels = 2048
    elif bone == 'squeezenet1_0':
        backbone = models.squeezenet1_0(pretrained=pretrained).features
        out_channels = 512
    else:
        # Fail early instead of hitting an unbound ``backbone`` below.
        raise ValueError("Unsupported backbone: {!r}".format(bone))

    # The backbone must expose how many channels it outputs.
    backbone.out_channels = out_channels

    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512), ),
                                       aspect_ratios=((0.5, 1.0, 2.0), ))
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                    output_size=7,
                                                    sampling_ratio=2)
    model = MaskRCNN(backbone,
                     num_classes=2,
                     rpn_anchor_generator=anchor_generator,
                     box_roi_pool=roi_pooler)
    return model
def Mobilenetv2(num_classes, test=False):
    """Return a freshly constructed MobileNetV2 with ``num_classes`` outputs.

    ``test`` is accepted but not used: loading pretrained weights and
    replacing the classifier are disabled in this variant, so the network
    is returned exactly as ``mobilenet_v2`` builds it.
    """
    return mobilenet_v2(num_classes=num_classes)
示例#3
0
def Mobilenetv2(num_classes, test=False):
    """Build a MobileNetV2 and replace its classifier with a new linear head.

    Args:
        num_classes: Output size of the new ``nn.Linear`` classifier.
        test: When true, no weights are loaded into the network. Otherwise
            weights come from ``LOCAL_PRETRAINED['moblienetv2']`` if that
            entry is set, else they are downloaded from
            ``model_urls['moblienetv2']``.

    Returns:
        The model, with ``model.classifier`` replaced by a single linear
        layer.
    """
    model = mobilenet_v2()
    if not test:
        local_weights = LOCAL_PRETRAINED['moblienetv2']
        if local_weights is None:
            state_dict = load_state_dict_from_url(model_urls['moblienetv2'],
                                                  progress=True)
        else:
            state_dict = torch.load(local_weights)
        model.load_state_dict(state_dict)
    print(model.state_dict().keys())
    # Read the input width of the stock head before replacing the classifier.
    fc_features = model.classifier[1].in_features
    model.classifier = nn.Linear(fc_features, num_classes)
    return model
示例#4
0
def main():
    """Evaluate a saved best-model checkpoint on the test dataset.

    Creates the output directories named in ``config``, builds the model
    selected by ``config.model_name``, loads the checkpoint's
    ``"state_dict"`` and hands a test ``DataLoader`` to ``test``.

    Raises:
        ValueError: If ``config.model_name`` is not a known model name.
    """
    fold = config.fold

    # mkdirs; exist_ok avoids the check-then-create race
    fold_suffix = config.model_name + os.sep + str(fold) + os.sep
    for directory in (config.submit,
                      config.weights,
                      config.best_models,
                      config.logs,
                      config.weights + fold_suffix,
                      config.best_models + fold_suffix):
        os.makedirs(directory, exist_ok=True)

    # 4.2 get model
    # Strings are compared with ``==``; ``is`` tests object identity and
    # only matches by accident of string interning.
    model_name = config.model_name
    if model_name == "lenet":
        model = get_lenet(config.img_channels)
    elif model_name == "mobilenet":
        model = get_mobilenet(config.img_channels)
    elif model_name == "mobilenet_v2":
        model = mobilenet_v2(config.img_channels)
    elif model_name == "resnet":
        model = resnet18()
    elif model_name == "densenet":
        model = densenet121()
    elif model_name == "inception_v3":
        model = inception_v3()
    else:
        raise ValueError("Unknown model name: %r" % (model_name,))

    print(model, "\n")
    # model = torch.nn.DataParallel(model)
    model.cuda()

    # load dataset
    with h5py.File(config.test_data, 'r') as db:
        num_test = db.attrs['size']
        print('test dataset size:', num_test)

    # The hard-coded best-model file is evaluated; the latest checkpoint
    # would be "./checkpoints/%s/%s/_checkpoint.pt" with the same arguments.
    checkpoint_path = "./checkpoints/best_model/%s/%s/model_best_88.75.pt" % (config.model_name, config.fold)
    print("Test modle:", checkpoint_path)
    best_model = torch.load(checkpoint_path)
    model.load_state_dict(best_model["state_dict"])

    test_dataloader = DataLoader(H5Dataset(config.test_data, 0, num_test), batch_size=20, shuffle=False)
    test(test_dataloader, model, fold)
示例#5
0
def create_model(config):
    """Instantiate the model named by ``config["model_type"]``.

    If ``config`` has no entry keyed by the model type, a default
    hyper-parameter dict is stored there first. After construction the
    total and trainable parameter counts are written back to ``config``
    under ``"num_parameters"`` and ``"num_trainable_parameters"``.

    Args:
        config: Mutable mapping holding ``"model_type"`` and, optionally, a
            per-model hyper-parameter dict keyed by the model type.

    Returns:
        The constructed model.

    Raises:
        KeyError: If ``config`` has no ``"model_type"`` entry.
        SystemExit: With status 1 if the model type is unknown. The original
            called the site helper ``exit()``, which reported success
            (status 0) on this error path.
    """
    model_type = config["model_type"]

    # Hyper-parameters used when the config does not provide its own.
    default_params = {
        "SimpleConvNet": {
            "conv1_size": 32,
            "conv2_size": 64,
            "fc_size": 128,
        },
        "MiniVGG": {
            "conv1_size": 128,
            "conv2_size": 256,
            "classifier_size": 1024,
        },
        "WideResNet": {
            "depth": 34,
            "num_classes": 10,
            "widen_factor": 10,
        },
        "MobileNetv2": {"pretrained": False},
    }

    if model_type not in default_params:
        print(f"Error: MODEL_TYPE {model_type} unknown.")
        raise SystemExit(1)

    if model_type not in config:
        config[model_type] = default_params[model_type]
    params = config[model_type]

    if model_type == "SimpleConvNet":
        model = SimpleConvNet(**params)
    elif model_type == "MiniVGG":
        model = MiniVGG(**params)
    elif model_type == "WideResNet":
        model = WideResNet(**params)
    else:  # "MobileNetv2"
        model = mobilenet_v2(num_classes=10,
                             pretrained=params["pretrained"])

    config["num_parameters"] = sum(p.numel() for p in model.parameters())
    config["num_trainable_parameters"] = sum(p.numel()
                                             for p in model.parameters()
                                             if p.requires_grad)
    return model
示例#6
0
def load_model_quantized(model_name, device, dataset, num_labels):
    """Construct the requested model and move it to ``device``.

    Args:
        model_name: One of ``"mobilenet"``, ``"resnet50"``,
            ``"resnet50_ptcv"``, ``"inceptionv3"``, ``"googlenet"``,
            ``"shufflenetv2"``, ``"dlrm"`` or ``"bert"``.
        device: Target device passed to ``model.to``.
        dataset: Pretrained weights are requested from the model factories
            only when this is ``"imagenet"``; when it is ``"cifar10"`` the
            weights are loaded from ``data/<model_name>.pt`` afterwards.
        num_labels: Number of labels, used only for the BERT config.

    Returns:
        The model, located on ``device``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    pretrained = (dataset == "imagenet")
    if model_name == "mobilenet":
        model = models.mobilenet_v2(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "resnet50":
        model = torchvision.models.quantization.resnet50(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "resnet50_ptcv":
        model = ptcv.qresnet50_ptcv(pretrained=pretrained)
    elif model_name == "inceptionv3":
        model = models.inception_v3(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "googlenet":
        model = models.googlenet(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "shufflenetv2":
        model = models.shufflenet_v2_x1_0(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == 'dlrm':
        # These arguments are hardcoded to the defaults from DLRM (matching the pretrained model).
        model = DLRM_Net(16,
                         np.array([1460, 583, 10131227, 2202608, 305, 24, 12517, 633, 3, 93145, 5683,
                                   8351593, 3194, 27, 14992, 5461306, 10, 5652, 2173, 4, 7046547,
                                   18, 15, 286181, 105, 142572], dtype=np.int32),
                         np.array([13, 512, 256,  64,  16]),
                         np.array([367, 512, 256,   1]),
                         'dot', False, -1, 2, True, 0., 1, False, 'mult', 4, 200, False, 200)
        # Load onto the CPU so the checkpoint opens regardless of the device
        # it was saved from; the model is moved to ``device`` at the end.
        ld_model = torch.load('data/dlrm.pt', map_location='cpu')
        model.load_state_dict(ld_model["state_dict"])
    elif model_name == 'bert':
        config = AutoConfig.from_pretrained(
            'bert-base-cased',
            num_labels=num_labels,
            finetuning_task='mnli',
        )
        model = BertForSequenceClassification.from_pretrained('data/bert.bin', from_tf=False, config=config)
    else:
        raise ValueError("Unsupported model type")

    if dataset == "cifar10":
        # Same CPU-first loading as above.
        ld_model = torch.load(f"data/{model_name}.pt", map_location='cpu')
        model.load_state_dict(ld_model)

    model = model.to(device)
    return model
示例#7
0
def main():
    """Create the output folders, build the configured model and train it.

    NOTE(review): the function stops right after printing the model (see
    the ``SystemExit`` below), so none of the training code after that
    point runs at the moment. The early exit is kept as found; remove it to
    re-enable training.

    Raises:
        ValueError: If ``config.model_name`` is not a known model name.
        SystemExit: Always, once the model has been printed.
    """
    fold = config.fold
    # ``fold`` is joined into paths, so use its string form there.
    fold_name = str(fold)

    # mkdirs; exist_ok avoids the check-then-create race
    weights_dir = (config.weights + config.model_name + os.sep + fold_name +
                   os.sep)
    best_models_dir = (config.best_models + config.model_name + os.sep +
                       fold_name + os.sep)
    for directory in (config.submit, config.weights, config.best_models,
                      config.logs, weights_dir, best_models_dir):
        os.makedirs(directory, exist_ok=True)

    # get model and optimizer
    # Strings are compared with ``==``; ``is`` tests object identity and
    # only matches by accident of string interning.
    model_name = config.model_name
    if model_name == "lenet":
        model = get_lenet(config.img_channels)
    elif model_name == "mobilenet":
        model = get_mobilenet(config.img_channels)
    elif model_name == "mobilenet_v2":
        model = mobilenet_v2(config.img_channels)
    elif model_name == "resnet":
        model = resnet18()
    elif model_name == "densenet":
        model = densenet121()
    elif model_name == "inception_v3":
        model = inception_v3()
    else:
        raise ValueError("Unknown model name: %r" % (model_name,))

    print(model, "\n")
    # model = torch.nn.DataParallel(model)

    # Early stop left in by the author (it replaces the site helper
    # ``exit()`` with the same exit status). Everything below is unreachable
    # until this line is removed.
    raise SystemExit
    model.cuda()

    # optimizer = optim.SGD(model.parameters(),lr = config.lr,momentum=0.9,weight_decay=config.weight_decay)
    optimizer = optim.Adam(model.parameters(),
                           lr=config.lr,
                           amsgrad=True,
                           weight_decay=config.weight_decay)
    criterion = nn.CrossEntropyLoss().cuda()
    # some parameters for restart model
    start_epoch = 0
    best_precision1 = 0
    best_precision_save = 0
    resume = False

    # restart the training process
    if resume:
        resume_file = (config.best_models + config.model_name + "/" +
                       fold_name + "/model_best_89.32.pt")
        print("resume" + resume_file)
        checkpoint = torch.load(resume_file)
        start_epoch = checkpoint["epoch"]
        fold = checkpoint["fold"]
        best_precision1 = checkpoint["best_precision1"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

    # load dataset
    with h5py.File(config.train_data, 'r') as db:
        num_train = db.attrs['size']
        print('train dataset size:', num_train)
    train_dataloader = DataLoader(H5Dataset(config.train_data, 0, num_train),
                                  batch_size=config.batch_size,
                                  shuffle=True)
    with h5py.File(config.val_data, 'r') as db:
        num_val = db.attrs['size']
        print('val   dataset size:', num_val)

    val_dataloader = DataLoader(H5Dataset(config.val_data, 0, num_val),
                                batch_size=config.batch_size * 2,
                                shuffle=True)

    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", verbose=1, patience=3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # define metrics
    train_losses = AverageMeter()
    train_top1 = AverageMeter()
    valid_loss = [np.inf, 0, 0]
    model.train()

    # train
    start = timer()
    for epoch in range(start_epoch, config.epochs):
        scheduler.step(epoch)
        lr = get_learning_rate(optimizer)

        train_progressor = ProgressBar(mode="Train",
                                       epoch=epoch,
                                       total_epoch=config.epochs,
                                       model_name=config.model_name,
                                       total=len(train_dataloader),
                                       resume=resume)
        steps_per_epoch = len(train_dataloader)
        for step, (images, target) in enumerate(train_dataloader):
            # switch to continue train process
            train_progressor.current = step
            model.train()
            images = Variable(images).cuda()
            target = Variable(target).cuda()
            forward_start = time.time()
            output = model(images)
            forward_end = time.time()
            print("forward time:", (forward_end - forward_start) * 1000)

            loss = criterion(output, target)

            precision1_train, precision2_train = accuracy(output,
                                                          target,
                                                          topk=(1, 2))
            train_losses.update(loss.item(), images.size(0))
            train_top1.update(precision1_train[0], images.size(0))
            train_progressor.current_lr = lr
            train_progressor.current_loss = train_losses.avg
            train_progressor.current_top1 = train_top1.avg
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_progressor()

            global_step = step + epoch * steps_per_epoch
            logger.add_scalar('train_loss', train_losses.avg, global_step)
            logger.add_scalar('train_acc', train_top1.avg, global_step)

        train_progressor.done()

        # evaluate every epoch
        valid_loss = evaluate(val_dataloader, model, criterion, epoch, lr,
                              resume)
        is_best = valid_loss[1] > best_precision1
        best_precision1 = max(valid_loss[1], best_precision1)
        try:
            best_precision_save = best_precision1.cpu().data.numpy()
        except AttributeError:
            # best_precision1 is still a plain number (no ``.cpu()``).
            pass
        save_checkpoint(
            {
                "epoch": epoch + 1,
                "model_name": config.model_name,
                "state_dict": model.state_dict(),
                "best_precision1": best_precision1,
                "optimizer": optimizer.state_dict(),
                "fold": fold,
                "valid_loss": valid_loss,
            }, is_best, fold)
示例#8
0
def train():
    """Train the network selected by ``args`` and save its weights each epoch.

    Reads its settings from the module-level ``args`` and
    ``imagenet_config``. After every epoch the weights are written to
    ``weights/mobilenet_imagenet_epoch_<epoch>.pth``.

    Raises:
        ValueError: If ``args.dataset`` or ``args.model`` is not supported.
            Previously these cases failed later with an unbound-variable
            error.
    """
    # Define the dataset.
    dataset_root = args.dataset_root
    if args.dataset != 'imagenet':
        raise ValueError("Unsupported dataset: {!r}".format(args.dataset))
    dataset = ImageNet(
        dataset_root,
        transforms.Compose([
            transforms.Resize([300, 300]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=imagenet_config['mean'],
                                 std=imagenet_config['std']),
        ]))

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=imagenet_config['batch_size'],
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    # Define the network structure.
    if args.model != 'mobilenet_bn':
        raise ValueError("Unsupported model: {!r}".format(args.model))
    net = models.mobilenet_v2(num_classes=imagenet_config['num_classes'])
    net_gpus = net

    # Decide whether to use multiple GPUs.
    if args.cuda and args.multi_cuda:
        net_gpus = torch.nn.DataParallel(net)
        cudnn.benchmark = True
        print("GPU name is: {}".format(
            torch.cuda.get_device_name(torch.cuda.current_device())))

    # Initialise the weights on ``net`` itself, not on the DataParallel
    # wrapper.
    if args.resume:
        net.load_state_dict(torch.load(args.resume))
    else:
        print("Initing weights...")
        net.init_weights()

    # Move the network to the GPU.
    if args.cuda:
        print("Converting model to GPU...")
        net_gpus.cuda()

    net_gpus.train()

    # Define the optimizer.
    optimizer = optim.SGD(net.parameters(),
                          lr=imagenet_config['lr_init'],
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    criterion = classification_loss()
    if args.cuda:
        criterion = criterion.cuda()

    for epoch in range(imagenet_config['max_epoch']):
        # Adjust the learning rate once per epoch.
        step_index = epoch // imagenet_config['lr_interval']
        adjust_learning_rate(optimizer, imagenet_config['lr_init'],
                             imagenet_config['gamma'], step_index)

        for iteration, (images, targets) in enumerate(dataloader):
            if args.cuda:
                images = Variable(images.cuda())
                targets = Variable(targets.cuda())
            else:
                images = Variable(images)
                targets = Variable(targets)

            begin_time = time.time()
            predictions = net_gpus(images)

            optimizer.zero_grad()
            targets = targets.squeeze(1)
            model_loss = criterion(predictions, targets)
            model_loss.backward()
            optimizer.step()

            end_time = time.time()
            if iteration % 10 == 0:
                print("---[Epoch:%d]" % epoch + "---Iter: %d " % iteration +
                      "---Loss: %.4f---" % (model_loss.item()) +
                      "%.4f seconds/iter.---" % (end_time - begin_time))

        print("---Saving model.---")
        torch.save(net.state_dict(),
                   'weights/mobilenet_imagenet_epoch_' + repr(epoch) + '.pth')
    # NOTE(review): the assignment below is a dead local here; it looks like
    # residue of a separate configuration snippet - confirm before removing.
    # 3.bool parameters
    use_cuda = False


config = DefaultConfigs()

device = torch.device("cuda:0" if config.use_cuda else "cpu")

# An instance of your model.
# Strings are compared with ``==``; ``is`` tests object identity and only
# matches by accident of string interning.
if config.model_name == "lenet":
    model = get_lenet(config.img_channels)
elif config.model_name == "mobilenet":
    model = get_mobilenet(config.img_channels)
elif config.model_name == "mobilenet_v2":
    model = mobilenet_v2(config.img_channels)
elif config.model_name == "resnet":
    model = resnet18()
elif config.model_name == "densenet":
    model = densenet121()
elif config.model_name == "inception_v3":
    model = inception_v3()
else:
    # Fail here rather than with a NameError on ``model`` below.
    raise ValueError("Unknown model name: %r" % (config.model_name,))

print(model)
# map_location lets a checkpoint saved on another device load onto the
# device selected above.
model.load_state_dict(
    torch.load(config.pre_trained, map_location=device)["state_dict"])

if config.use_cuda:
    model.to(device)

# An example input you would normally provide to your model's forward() method.
示例#10
0
def train_multiclass(train_file, test_file, stat_file,
                     model='mobilenet_v2',
                     classes=('artist_name', 'genre', 'style', 'technique', 'century'),
                     label_file='_user_labels.pkl',
                     im_path='/export/home/kschwarz/Documents/Data/Wikiart_artist49_images',
                     chkpt=None, weight_file=None,
                     triplet_selector='semihard', margin=0.2,
                     labels_per_class=4, samples_per_label=4,
                     use_gpu=True, device=0,
                     epochs=100, batch_size=32, lr=1e-4, momentum=0.9,
                     log_interval=10, log_dir='runs',
                     exp_name=None, seed=123):
    """Train a multi-head metric network with a triplet loss.

    A body network chosen by ``model`` is wrapped in ``MetricNet`` with one
    output per entry of ``classes``. Each output is L2-normalised and
    contributes a triplet loss; samples labelled ``-1`` are skipped for that
    class. After every epoch the validation set is evaluated, the scheduler
    is stepped on the validation loss and a checkpoint is saved. Training
    stops early once the learning rate drops below ``1e-5``.

    Args:
        train_file, test_file: Passed to ``create_trainset`` /
            ``create_valset``.
        stat_file: Pickle file holding ``'mean'`` and ``'std'`` used for
            input normalisation.
        model: Body network name (case-insensitive): ``squeezenet``,
            ``mobilenet_v1``, ``mobilenet_v2``, ``vgg16_bn``,
            ``inception_v3`` or ``alexnet``.
        classes: Label columns; one embedding head per class.
        label_file: Pickle file whose ``'labels'`` entry is summarised in
            the written config.
        im_path: Image directory passed to the dataset constructors.
        chkpt: Optional checkpoint path to resume from.
        weight_file: Optional weights for the body network; when given, the
            body network is created without pretrained weights.
        triplet_selector: ``random``, ``semihard``, ``hardest``, ``mixed``
            or ``khardest``. ``mixed`` starts semihard and switches to
            hardest at epoch 26.
        margin: Triplet margin.
        labels_per_class, samples_per_label: Batch sampler settings.
        use_gpu, device: CUDA is used when requested and available.
        epochs, lr, momentum: Optimisation settings.
        batch_size: Recorded in the written config only; batches are built
            by the balanced batch sampler.
        log_interval: Number of batches between progress prints.
        log_dir, exp_name: Where and under which name logs are written.
        seed: Random seed.

    Raises:
        AssertionError: If ``model`` or ``triplet_selector`` is unknown.
    """
    argvars = locals().copy()
    torch.manual_seed(seed)

    # LOAD DATASET
    # Pickle data is binary, so the file is opened in binary mode (as is
    # done for ``label_file`` further down).
    with open(stat_file, 'rb') as f:
        data = pickle.load(f)
        mean, std = data['mean'], data['std']
        mean = [float(m) for m in mean]
        std = [float(s) for s in std]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    # NOTE(review): validation also uses a random crop, so validation
    # metrics are not deterministic - confirm this is intended.
    val_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    model_name = model.lower()
    if model_name == 'inception_v3':            # change input size to 299
        train_transform.transforms[0].size = (299, 299)
        val_transform.transforms[0].size = (299, 299)
    trainset = create_trainset(train_file, label_file, im_path, train_transform, classes)
    for c in classes:
        if len(trainset.labels_to_ints[c]) < labels_per_class:
            print('less labels in class {} than labels_per_class, use all available labels ({})'
                  .format(c, len(trainset.labels_to_ints[c])))
    valset = create_valset(test_file, im_path, val_transform, trainset.labels_to_ints)

    # PARAMETERS
    use_cuda = use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed_all(seed)

    # With a weight file the body network starts without pretrained weights.
    use_pretrained = weight_file is None
    if model_name == 'mobilenet_v1':
        bodynet = mobilenet_v1(pretrained=use_pretrained)
    elif model_name == 'mobilenet_v2':
        bodynet = mobilenet_v2(pretrained=use_pretrained)
    elif model_name == 'vgg16_bn':
        bodynet = vgg16_bn(pretrained=use_pretrained)
    elif model_name == 'inception_v3':
        bodynet = inception_v3(pretrained=use_pretrained)
    elif model_name == 'alexnet':
        bodynet = alexnet(pretrained=use_pretrained)
    elif model_name == 'squeezenet':
        bodynet = squeezenet(pretrained=use_pretrained)
    else:
        # Raised explicitly so the check is not stripped under ``python -O``.
        raise AssertionError('Unknown model {}\n\t+ Choose from: '
                             '[sqeezenet, mobilenet_v1, mobilenet_v2, vgg16_bn, inception_v3, alexnet].'.format(model))

    # Load weights for the body network
    if weight_file is not None:
        print("=> loading weights from '{}'".format(weight_file))
        pretrained_dict = torch.load(weight_file, map_location=lambda storage, loc: storage)['state_dict']
        state_dict = bodynet.state_dict()
        # Strip the 'bodynet.' prefix (multilabel weight files) and keep
        # only tensors whose shape matches the current model (the number of
        # classes might have changed).
        pretrained_dict = {k.replace('bodynet.', ''): v for k, v in pretrained_dict.items()
                           if (k.replace('bodynet.', '') in state_dict.keys() and v.shape == state_dict[k.replace('bodynet.', '')].shape)}
        # Report which weights will not be transferred. The loop prints
        # nothing when both dicts agree, so no dict-of-tensors comparison is
        # needed up front.
        for k in set(state_dict.keys()) | set(pretrained_dict.keys()):
            if k in state_dict.keys() and k not in pretrained_dict.keys():
                print('\tWeights for "{}" were not found in weight file.'.format(k))
            elif k in pretrained_dict.keys() and k not in state_dict.keys():
                print('\tWeights for "{}" were are not part of the used model.'.format(k))
            elif state_dict[k].shape != pretrained_dict[k].shape:
                print('\tShapes of "{}" are different in model ({}) and weight file ({}).'.
                      format(k, state_dict[k].shape, pretrained_dict[k].shape))

        state_dict.update(pretrained_dict)
        bodynet.load_state_dict(state_dict)

    net = MetricNet(bodynet, len(classes))

    n_parameters = sum([p.data.nelement() for p in net.parameters() if p.requires_grad])
    if use_cuda:
        net = net.cuda()
    print('Using {}\n\t+ Number of params: {}'.format(str(net).split('(', 1)[0], n_parameters))

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    expname = timestamp + '_' + str(net).split('(', 1)[0]
    if exp_name is not None:
        expname = expname + '_' + exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    with open(label_file, 'rb') as f:
        labels = pickle.load(f)['labels']
        n_labeled = '\t'.join([str(Counter(l).items()) for l in labels.transpose()])

    write_config(argvars, os.path.join(log_dir, expname), extras={'n_labeled': n_labeled})

    # INITIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, threshold=1e-1, verbose=True)

    selector_name = triplet_selector.lower()
    if selector_name == 'random':
        selector = RandomNegativeTripletSelector(margin, cpu=not use_cuda)
    elif selector_name in ('semihard', 'mixed'):
        selector = SemihardNegativeTripletSelector(margin, cpu=not use_cuda)
    elif selector_name == 'khardest':
        selector = KHardestNegativeTripletSelector(margin, k=3, cpu=not use_cuda)
    elif selector_name == 'hardest':
        selector = HardestNegativeTripletSelector(margin, cpu=not use_cuda)
    else:
        # Raised explicitly so the check is not stripped under ``python -O``.
        raise AssertionError('Unknown option {} for triplet selector. Choose from "random", "semihard", "hardest" or "mixed"'
                             '.'.format(triplet_selector))
    criterion = TripletLoss(margin=margin, triplet_selector=selector)
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 4} if use_cuda else {}
    multilabel_train = np.stack([trainset.df[c].values for c in classes]).transpose()
    train_batch_sampler = BalancedBatchSamplerMulticlass(multilabel_train, n_label=labels_per_class,
                                                         n_per_label=samples_per_label, ignore_label=None)
    trainloader = DataLoader(trainset, batch_sampler=train_batch_sampler, **kwargs)
    multilabel_val = np.stack([valset.df[c].values for c in classes]).transpose()
    val_batch_sampler = BalancedBatchSamplerMulticlass(multilabel_val, n_label=labels_per_class,
                                                       n_per_label=samples_per_label, ignore_label=None)
    valloader = DataLoader(valset, batch_sampler=val_batch_sampler, **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    if chkpt is not None:
        if os.path.isfile(chkpt):
            print("=> loading checkpoint '{}'".format(chkpt))
            checkpoint = torch.load(chkpt, map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            # Checkpoints written below store the score under 'best_acc';
            # 'acc' is still honoured for checkpoints that provide it.
            if 'acc' in checkpoint:
                best_acc = checkpoint['acc']
            else:
                best_acc = checkpoint['best_acc']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(chkpt))

    def train(epoch):
        """Run one training epoch and log the averaged metrics."""
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss += l
                n_triplets += nt

            n_samples = target[0].size(0)
            non_zero_triplets.update(n_triplets, n_samples)
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)           # do not compute ap pairs for concealed classes
            gtes.update(gte.data, n_samples)
            distances_ap.update(dist_ap.data, n_samples)
            distances_an.update(dist_an.data, n_samples)
            losses.update(loss.data[0], n_samples)

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'GTE: {:.2f}% ({:.2f}%)\t'
                      'Non-zero Triplets: {:d} ({:d})'.format(
                    epoch, batch_idx * len(target[0]), len(trainloader) * len(target[0]),
                    float(losses.val), float(losses.avg),
                    float(gtes.val) * 100., float(gtes.avg) * 100.,
                    int(non_zero_triplets.val), int(non_zero_triplets.avg)))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('gte', float(gtes.avg), epoch, test=False)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=False)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=False)
        log.write('dist_an', float(distances_an.avg), epoch, test=False)

    def test(epoch):
        """Evaluate on the validation set; return (avg loss, 1 - avg GTE)."""
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]
            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss += l
                n_triplets += nt

            n_samples = target[0].size(0)
            non_zero_triplets.update(n_triplets, n_samples)
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)
            gtes.update(gte.data.cpu(), n_samples)
            distances_ap.update(dist_ap.data.cpu(), n_samples)
            distances_an.update(dist_an.data.cpu(), n_samples)
            losses.update(loss.data[0].cpu(), n_samples)

        print('\nVal set: Average loss: {:.4f} Average GTE {:.2f}%, '
              'Average non-zero triplets: {:d} LR: {:.6f}'.format(float(losses.avg), float(gtes.avg) * 100.,
                                                                  int(non_zero_triplets.avg),
                                                                  optimizer.param_groups[-1]['lr']))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('gte', float(gtes.avg), epoch, test=True)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=True)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=True)
        log.write('dist_an', float(distances_an.avg), epoch, test=True)
        return losses.avg, 1 - gtes.avg

    if start_epoch == 1:         # compute baseline:
        _, best_acc = test(epoch=0)
    # otherwise a checkpoint was loaded and ``best_acc`` comes from it

    for epoch in range(start_epoch, epochs + 1):
        if selector_name == 'mixed' and epoch == 26:
            criterion.triplet_selector = HardestNegativeTripletSelector(margin, cpu=not use_cuda)
            print('Changed negative selection from semihard to hardest.')
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'best_acc': best_acc,
        }, is_best, expname, directory=log_dir)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values
    best = torch.load(os.path.join(log_dir, expname + '_model_best.pth.tar'), map_location=lambda storage, loc: storage)
    # The checkpoint dict saved above uses 'best_acc'; fall back to it when
    # no 'acc' entry is present.
    best_score = best['acc'] if 'acc' in best else best['best_acc']
    print('Finished training after epoch {}:\n\tbest acc score: {}'
          .format(best['epoch'], best_score))
    print('Best model mean accuracy: {}'.format(best_acc))
示例#11
0
def _stack_layer_scores(model, loader, is_in_dist, sample_mean, precision,
                        num_output, magnitude):
    """Compute Mahalanobis scores for every feature layer and stack them column-wise.

    Each layer's scores are cast to float32, reshaped to (num_samples, -1) and
    concatenated along axis 1, so the result has one row per sample.
    """
    per_layer = []
    for layer_index in range(num_output):
        scores = lib_generation.get_Mahalanobis_score(
            model, loader, args.num_classes, args.outf, is_in_dist,
            args.net_type, sample_mean, precision, layer_index, magnitude)
        scores = np.asarray(scores, dtype=np.float32)
        per_layer.append(scores.reshape((scores.shape[0], -1)))
    return np.concatenate(per_layer, axis=1)


def main():
    """Extract Mahalanobis confidence scores for in- and out-of-distribution data.

    Loads the network selected by ``args.net_type`` from its checkpoint,
    estimates per-layer sample means and precision on the training data, and,
    for every noise magnitude and every out-of-distribution data set, saves the
    merged scores and labels to
    ``<args.outf>/Mahalanobis_<magnitude>_<dataset>_<out_dist>.npy``.

    Raises:
        ValueError: if ``args.net_type`` is not one of the supported networks.
    """
    # set the path to pre-trained model and output
    args.outf = args.outf + args.net_type + '_' + args.dataset + '/'
    # makedirs also creates missing parents and tolerates an existing directory
    os.makedirs(args.outf, exist_ok=True)
    torch.cuda.manual_seed(0)
    torch.cuda.set_device(args.gpu)

    out_dist_list = [
        'skin_cli', 'skin_derm', 'corrupted', 'corrupted_70', 'imgnet', 'nct',
        'final_test'
    ]

    # load networks: pick wrapper class, torchvision constructor and checkpoint
    # folder; names are only looked up for the branch that is actually taken
    if args.net_type == 'densenet_121':
        net_cls, backbone_fn, ckpt_dir = (densenet_121.Net,
                                          models.densenet121, 'densenet-121')
    elif args.net_type == 'mobilenet':
        net_cls, backbone_fn, ckpt_dir = (mobilenet.Net, models.mobilenet_v2,
                                          'mobilenet')
    elif args.net_type == 'resnet_50':
        net_cls, backbone_fn, ckpt_dir = (resnet_50.Net, models.resnet50,
                                          'resnet-50')
    elif args.net_type == 'vgg_16':
        net_cls, backbone_fn, ckpt_dir = (vgg_16.Net, models.vgg16_bn,
                                          'vgg-16')
    else:
        raise ValueError(
            f"There is no net_type={args.net_type} available.")

    # NOTE(review): the head size 8 is hard-coded here while args.num_classes
    # is used below — confirm they are meant to agree.
    model = net_cls(backbone_fn(pretrained=False), 8)
    ckpt = torch.load("../checkpoints/" + ckpt_dir + "/checkpoint.pth")
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    model.cuda()
    print("Done!")

    in_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    print('load model: ' + args.net_type)

    # load dataset
    print('load target data: ', args.dataset)
    train_loader, test_loader = data_loader.getTargetDataSet(
        args.dataset, args.batch_size, in_transform, args.dataroot)

    # set information about feature extraction: run a dummy batch through the
    # model and record size(1) of every returned feature tensor
    temp_x = Variable(torch.rand(2, 3, 224, 224).cuda())
    temp_list = model.feature_list(temp_x)[1]
    num_output = len(temp_list)
    feature_list = np.array([out.size(1) for out in temp_list], dtype=float)

    print('get sample mean and covariance')
    sample_mean, precision = lib_generation.sample_estimator(
        model, args.num_classes, feature_list, train_loader)

    print('get Mahalanobis scores')
    m_list = [0.0, 0.01, 0.005, 0.002, 0.0014, 0.001, 0.0005]

    for magnitude in m_list:
        print('Noise: ' + str(magnitude))
        # in-distribution scores are computed once per noise magnitude
        Mahalanobis_in = _stack_layer_scores(model, test_loader, True,
                                             sample_mean, precision,
                                             num_output, magnitude)

        for out_dist in out_dist_list:
            out_test_loader = data_loader.getNonTargetDataSet(
                out_dist, args.batch_size, in_transform, args.dataroot)
            print('Out-distribution: ' + out_dist)
            Mahalanobis_out = _stack_layer_scores(model, out_test_loader,
                                                  False, sample_mean,
                                                  precision, num_output,
                                                  magnitude)

            Mahalanobis_data, Mahalanobis_labels = lib_generation.merge_and_generate_labels(
                Mahalanobis_out, Mahalanobis_in)
            file_name = os.path.join(
                args.outf, 'Mahalanobis_%s_%s_%s.npy' %
                (str(magnitude), args.dataset, out_dist))
            # labels are appended as extra column(s) next to the scores
            Mahalanobis_data = np.concatenate(
                (Mahalanobis_data, Mahalanobis_labels), axis=1)
            np.save(file_name, Mahalanobis_data)
示例#12
0
def _load_body_weights(bodynet, weight_file):
    """Copy all compatible weights from ``weight_file`` into ``bodynet``.

    The file must contain a dict with a ``'state_dict'`` entry. A ``'bodynet.'``
    prefix (present in multi-label weight files) is stripped from the keys.
    Parameters that are missing from the file, unknown to the model, or of a
    different shape are reported and left untouched.
    """
    print("=> loading weights from '{}'".format(weight_file))
    loaded = torch.load(
        weight_file,
        map_location=lambda storage, loc: storage)['state_dict']
    loaded = {k.replace('bodynet.', ''): v for k, v in loaded.items()}
    state_dict = bodynet.state_dict()

    # Report against the unfiltered file contents so that unexpected keys and
    # shape mismatches are actually visible; only key names and shapes are
    # compared (comparing tensor dicts with == is ambiguous and raises).
    transferable = {}
    for k in sorted(set(state_dict) | set(loaded)):
        if k not in loaded:
            print('\tWeights for "{}" were not found in weight file.'.
                  format(k))
        elif k not in state_dict:
            print(
                '\tWeights for "{}" were are not part of the used model.'
                .format(k))
        elif state_dict[k].shape != loaded[k].shape:
            # number of classes might have changed
            print(
                '\tShapes of "{}" are different in model ({}) and weight file ({}).'
                .format(k, state_dict[k].shape, loaded[k].shape))
        else:  # everything is good
            transferable[k] = loaded[k]

    state_dict.update(transferable)
    bodynet.load_state_dict(state_dict)


def train_multiclass(
        train_file,
        test_file,
        stat_file,
        model='mobilenet_v2',
        classes=('artist_name', 'genre', 'style', 'technique', 'century'),
        im_path='/export/home/kschwarz/Documents/Data/Wikiart_artist49_images',
        label_file='_user_labels.pkl',
        chkpt=None,
        weight_file=None,
        use_gpu=True,
        device=0,
        epochs=100,
        batch_size=32,
        lr=1e-4,
        momentum=0.9,
        log_interval=10,
        log_dir='runs',
        exp_name=None,
        seed=123):
    """Train an OctopusNet with one classification output per entry in ``classes``.

    Args:
        train_file: passed to ``create_trainset`` to build the training set.
        test_file: passed to ``create_valset`` to build the validation set.
        stat_file: pickle file holding ``'mean'`` and ``'std'`` used for input
            normalization.
        model: body network name; one of squeezenet, mobilenet_v1,
            mobilenet_v2, vgg16_bn, inception_v3, alexnet (case-insensitive).
        classes: label categories to predict; samples labelled -1 for a
            category are ignored for that category.
        im_path: image location handed to the data set constructors.
        label_file: pickle file with a ``'labels'`` entry; also handed to
            ``create_trainset``.
        chkpt: optional checkpoint to resume from (expects the keys written by
            this function: epoch, state_dict, best_acc_score, acc).
        weight_file: optional weight file for the body network; when ``None``
            the body is constructed with ``pretrained=True``.
        use_gpu: use CUDA when it is available.
        device: CUDA device index.
        epochs: index of the last epoch to train.
        batch_size: batch size of both data loaders.
        lr: initial SGD learning rate.
        momentum: SGD momentum.
        log_interval: print training progress every ``log_interval`` batches.
        log_dir: directory for checkpoints, config and tensorboard logs.
        exp_name: optional suffix for the experiment name.
        seed: random seed for torch (and CUDA when used).

    Raises:
        AssertionError: if ``model`` is not a known architecture name.

    Training stops early once the learning rate of the last parameter group
    drops below 1e-5.
    """
    argvars = locals().copy()  # must stay first: snapshot of the call arguments
    torch.manual_seed(seed)

    # LOAD DATASET
    # NOTE(security): pickle can execute arbitrary code; only load trusted files.
    # Binary mode is required for pickle on Python 3.
    with open(stat_file, 'rb') as f:
        stats = pickle.load(f)
    mean = [float(m) for m in stats['mean']]
    std = [float(s) for s in stats['std']]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    # NOTE(review): validation also uses a random crop, so validation metrics
    # vary between runs — confirm this is intended.
    val_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    arch = model.lower()
    if arch == 'inception_v3':  # change input size to 299
        train_transform.transforms[0].size = (299, 299)
        val_transform.transforms[0].size = (299, 299)
    trainset = create_trainset(train_file, label_file, im_path,
                               train_transform, classes)
    valset = create_valset(test_file, im_path, val_transform,
                           trainset.labels_to_ints)
    num_labels = [len(trainset.labels_to_ints[c]) for c in classes]

    # PARAMETERS
    use_cuda = use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed_all(seed)

    # ImageNet-pretrained weights are only requested when no weight file is given
    use_pretrained = weight_file is None
    if arch == 'mobilenet_v1':
        bodynet = mobilenet_v1(pretrained=use_pretrained)
    elif arch == 'mobilenet_v2':
        bodynet = mobilenet_v2(pretrained=use_pretrained)
    elif arch == 'vgg16_bn':
        bodynet = vgg16_bn(pretrained=use_pretrained)
    elif arch == 'inception_v3':
        bodynet = inception_v3(pretrained=use_pretrained)
    elif arch == 'alexnet':
        bodynet = alexnet(pretrained=use_pretrained)
    elif arch == 'squeezenet':
        bodynet = squeezenet(pretrained=use_pretrained)
    else:
        # explicit raise instead of `assert False`, which is stripped under -O
        raise AssertionError(
            'Unknown model {}\n\t+ Choose from: '
            '[squeezenet, mobilenet_v1, mobilenet_v2, vgg16_bn, inception_v3, alexnet].'
            .format(model))

    # Load weights for the body network
    if weight_file is not None:
        _load_body_weights(bodynet, weight_file)

    net = OctopusNet(bodynet, n_labels=num_labels)

    n_parameters = sum(p.data.nelement() for p in net.parameters()
                       if p.requires_grad)
    if use_cuda:
        net = net.cuda()
    net_name = str(net).split('(', 1)[0]
    print('Using {}\n\t+ Number of params: {}'.format(net_name, n_parameters))

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    expname = timestamp + '_' + net_name
    if exp_name is not None:
        expname = expname + '_' + exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    # count how often every label value occurs per label column
    with open(label_file, 'rb') as f:
        labels = pickle.load(f)['labels']
    n_labeled = '\t'.join(
        [str(Counter(l).items()) for l in labels.transpose()])

    write_config(argvars,
                 os.path.join(log_dir, expname),
                 extras={'n_labeled': n_labeled})

    # INITIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=10,
                                                     threshold=1e-1,
                                                     verbose=True)
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 4} if use_cuda else {}
    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)
    valloader = DataLoader(valset,
                           batch_size=batch_size,
                           shuffle=True,
                           **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    if chkpt is not None:
        if os.path.isfile(chkpt):
            print("=> loading checkpoint '{}'".format(chkpt))
            checkpoint = torch.load(chkpt,
                                    map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            best_acc_score = checkpoint['best_acc_score']
            best_acc = checkpoint['acc']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(chkpt))

    def forward_batch(data, target, class_acc):
        """Run one batch through the net.

        Updates the per-class accuracy meters in ``class_acc`` and returns the
        loss summed over all classes together with the batch size.
        """
        if use_cuda:
            data, target = Variable(
                data.cuda()), [Variable(t.cuda()) for t in target]
        else:
            data, target = Variable(data), [Variable(t) for t in target]

        # compute output
        outputs = net(data)
        preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

        # Accumulate out of place: on CPU type_as() returns the very same
        # tensor, and an in-place add on a leaf that requires grad raises.
        loss = Variable(torch.Tensor([0])).type_as(data[0])
        for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
            # filter unlabeled samples if there are any (have label -1)
            # NOTE(review): a batch without any labeled sample for a class
            # yields an empty selection here — confirm this cannot happen.
            labeled = (t != -1).nonzero().view(-1)
            o, t, p = o[labeled], t[labeled], p[labeled]
            loss = loss + criterion(o, t)
            # measure class accuracy
            class_acc[i].update(
                (torch.sum(p == t).type(torch.FloatTensor) /
                 t.size(0)).data)
        return loss, target[0].size(0)

    def train(epoch):
        """Train for one epoch and log the averaged loss and accuracies."""
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for _ in range(len(classes))]

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            loss, n_samples = forward_batch(data, target, class_acc)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                n_samples)
            losses.update(loss.data, n_samples)

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                # progress is counted in samples (target is a list with one
                # tensor per class, so its length is not the batch size)
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'Acc: {:.2f}% ({:.2f}%)'.format(epoch,
                                                      batch_idx * n_samples,
                                                      len(trainloader.dataset),
                                                      float(losses.val),
                                                      float(losses.avg),
                                                      float(accs.val) * 100.,
                                                      float(accs.avg) * 100.))
                print('\t' + '\n\t'.join([
                    '{}: {:.2f}%'.format(classes[i],
                                         float(class_acc[i].val) * 100.)
                    for i in range(len(classes))
                ]))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('acc', float(accs.avg), epoch, test=False)
        # NOTE(review): every class is written under the same 'class_acc' tag —
        # confirm TBPlotter keeps them apart.
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=False)

    def test(epoch):
        """Evaluate on the validation set.

        Returns (average loss as numpy array, accuracy score, mean accuracy,
        list of per-class accuracies).
        """
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for _ in range(len(classes))]

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            loss, n_samples = forward_batch(data, target, class_acc)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                n_samples)
            losses.update(loss.data, n_samples)

        score = accs.avg - torch.std(
            torch.stack([class_acc[i].avg for i in range(len(classes))])
        ) / accs.avg  # compute mean - std/mean as measure for accuracy
        print(
            '\nVal set: Average loss: {:.4f} Average acc {:.2f}% Acc score {:.2f} LR: {:.6f}'
            .format(float(losses.avg),
                    float(accs.avg) * 100., float(score),
                    optimizer.param_groups[-1]['lr']))
        print('\t' + '\n\t'.join([
            '{}: {:.2f}%'.format(classes[i],
                                 float(class_acc[i].avg) * 100.)
            for i in range(len(classes))
        ]))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('acc', float(accs.avg), epoch, test=True)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=True)
        return losses.avg.cpu().numpy(), float(score), float(
            accs.avg), [float(class_acc[i].avg) for i in range(len(classes))]

    if start_epoch == 1:  # compute baseline:
        _, best_acc_score, best_acc, _ = test(epoch=0)
    # otherwise a checkpoint was loaded and supplied best_acc_score / best_acc

    for epoch in range(start_epoch, epochs + 1):
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc_score, val_acc, val_class_accs = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc_score > best_acc_score
        best_acc_score = max(val_acc_score, best_acc_score)
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_acc_score': best_acc_score,
                'acc': val_acc,
                'class_acc': {c: a
                              for c, a in zip(classes, val_class_accs)}
            },
            is_best,
            expname,
            directory=log_dir)

        # additionally keep the checkpoint with the best plain mean accuracy
        if val_acc > best_acc:
            shutil.copyfile(
                os.path.join(log_dir, expname + '_checkpoint.pth.tar'),
                os.path.join(log_dir,
                             expname + '_model_best_mean_acc.pth.tar'))
        best_acc = max(val_acc, best_acc)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values; the best-model file may be absent when no epoch was
    # ever flagged as best (or no epoch was run at all)
    best_path = os.path.join(log_dir, expname + '_model_best.pth.tar')
    if os.path.isfile(best_path):
        best = torch.load(best_path,
                          map_location=lambda storage, loc: storage)
        print(
            'Finished training after epoch {}:\n\tbest acc score: {}\n\tacc: {}\n\t class acc: {}'
            .format(best['epoch'], best['best_acc_score'], best['acc'],
                    best['class_acc']))
    else:
        print("=> no best-model checkpoint found at '{}'".format(best_path))
    print('Best model mean accuracy: {}'.format(best_acc))