예제 #1
0
def get_instance_segmentation_model(bone='resnet50', attention=False):
    """Build a two-class Mask R-CNN on top of the selected backbone.

    Args:
        bone: Name of the backbone network. Supported values are
            'mobilenet_v2', 'googlenet', 'densenet121', 'resnet50',
            'shufflenet_v2_x1_0', 'inception_v3' and 'squeezenet1_0'.
        attention: When falsy the backbone is requested with
            ``pretrained=True``; when truthy it is built with
            ``pretrained=False``. The flag is also forwarded as ``att`` to
            the mobilenet_v2, densenet121 and resnet50 constructors.

    Returns:
        A ``MaskRCNN`` with ``num_classes=2`` that uses a single-level
        anchor generator and a ``MultiScaleRoIAlign`` box pooler with
        ``output_size=7``.

    Raises:
        ValueError: If ``bone`` is not one of the supported backbone names.
    """
    # Pretrained weights are requested only for the non-attention variants
    # (inception_v3 below is the exception: it never receives the flag).
    pretrained = not attention

    if bone == 'mobilenet_v2':
        backbone = models.mobilenet_v2(pretrained=pretrained,
                                       att=attention).features
        backbone.out_channels = 1280
    elif bone == 'googlenet':
        backbone = models.googlenet(pretrained=pretrained)
        backbone.out_channels = 1024
    elif bone == 'densenet121':
        backbone = models.densenet121(pretrained=pretrained,
                                      att=attention).features
        backbone.out_channels = 1024
    elif bone == 'resnet50':
        backbone = models.resnet50(pretrained=pretrained, att=attention)
        backbone.out_channels = 2048
    elif bone == 'shufflenet_v2_x1_0':
        backbone = models.shufflenet_v2_x1_0(pretrained=pretrained)
        backbone.out_channels = 1024
    elif bone == 'inception_v3':
        # Built with the constructor defaults regardless of `attention`.
        # Known issue noted by the original author for this backbone:
        # 'InceptionOutputs' object has no attribute 'values'
        backbone = models.inception_v3()
        backbone.out_channels = 2048
    elif bone == 'squeezenet1_0':
        backbone = models.squeezenet1_0(pretrained=pretrained).features
        backbone.out_channels = 512
    else:
        # Fail with a clear message instead of an UnboundLocalError when
        # `backbone` would otherwise never be assigned.
        raise ValueError('Unsupported backbone: {!r}'.format(bone))

    # MaskRCNN reads `backbone.out_channels`, which is why it is set above.
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512), ),
                                       aspect_ratios=((0.5, 1.0, 2.0), ))
    # NOTE(review): featmap_names=[0] is kept as in the original; newer
    # torchvision releases may expect string names such as '0' - confirm
    # against the installed version.
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                    output_size=7,
                                                    sampling_ratio=2)
    model = MaskRCNN(backbone,
                     num_classes=2,
                     rpn_anchor_generator=anchor_generator,
                     box_roi_pool=roi_pooler)
    return model
예제 #2
0
def get_imagenet_models(model_name):
    """Return the model registered under ``model_name``, built with
    ``pretrained=True``.

    Recognised names are 'model_vgg16bn', 'model_resnet18_imgnet' and
    'model_inception'.

    Raises:
        ValueError: If ``model_name`` is not one of the recognised names.
    """
    # The imports stay inside the branches so that only the requested
    # constructor is pulled in from the project's ``models`` package.
    if model_name == 'model_vgg16bn':
        from models import vgg16_bn as build_model
    elif model_name == 'model_resnet18_imgnet':
        from models import resnet18 as build_model
    elif model_name == 'model_inception':
        from models import inception_v3 as build_model
    else:
        raise ValueError(f'Buggya no model named {model_name}')
    return build_model(pretrained=True)
예제 #3
0
def get_model_for_attack(model_name):
    """Instantiate the model identified by ``model_name``.

    Supported names:
        * 'model1' .. 'model6': fixed architectures whose weights are read
          from files under ``models/weights``.
        * 'model_vgg16bn', 'model_resnet18_imgnet', 'model_inception',
          'model_vitb': constructors called with ``pretrained=True``.
        * 'model_hub:<repo>:<entrypoint>': loaded through ``torch.hub.load``
          and wrapped in ``Cifar10Renormalize``.
        * 'model_mnist:<key>': entry ``<key>`` of the object stored in
          'mnist.pt'.
        * 'model_ex:<path>': whatever object ``torch.load(<path>)`` returns.

    Raises:
        ValueError: If ``model_name`` matches none of the forms above, or a
            'model_hub:' name lacks the ``<repo>:<entrypoint>`` part.
    """
    if model_name == 'model1':
        model = ResNet34()
        load_w(model, "./models/weights/resnet34.pt")
    elif model_name == 'model2':
        model = ResNet18()
        load_w(model, "./models/weights/resnet18_AT.pt")
    elif model_name == 'model3':
        model = SmallResNet()
        load_w(model, "./models/weights/res_small.pth")
    elif model_name == 'model4':
        model = WideResNet34()
        # Load the checkpoint onto whatever device the fresh model lives on.
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load("./models/weights/trades_wide_resnet.pt",
                           map_location=pref.device)))
    elif model_name == 'model5':
        model = WideResNet()
        load_w(model, "./models/weights/wideres34-10-pgdHE.pt")
    elif model_name == 'model6':
        model = WideResNet28()
        pref = next(model.parameters())
        model.load_state_dict(
            filter_state_dict(
                torch.load('models/weights/RST-AWP_cifar10_linf_wrn28-10.pt',
                           map_location=pref.device)))
    elif model_name == 'model_vgg16bn':
        model = vgg16_bn(pretrained=True)
    elif model_name == 'model_resnet18_imgnet':
        model = resnet18(pretrained=True)
    elif model_name == 'model_inception':
        model = inception_v3(pretrained=True)
    elif model_name == 'model_vitb':
        from mnist_vit import ViT, MegaSizer
        model = MegaSizer(
            ImageNetRenormalize(ViT('B_16_imagenet1k', pretrained=True)))
    elif model_name.startswith('model_hub:'):
        # The repo part may itself contain ':' (e.g. "owner/repo:tag"), so
        # only the first and the last separator are treated as structural.
        spec = model_name.split(":", 1)[1]
        a, b = spec.rsplit(":", 1)
        model = torch.hub.load(a, b, pretrained=True)
        model = Cifar10Renormalize(model)
    elif model_name.startswith('model_mnist:'):
        a = model_name.split(":", 1)[1]
        model = torch.load('mnist.pt')[a]
    elif model_name.startswith('model_ex:'):
        # Split once so that paths containing ':' stay intact.
        a = model_name.split(":", 1)[1]
        # NOTE: torch.load unpickles the file; only pass trusted paths.
        model = torch.load(a)
    else:
        # Without this branch `model` would be unbound and the return below
        # would raise a confusing UnboundLocalError.
        raise ValueError(f'Model {model_name} does not exist.')
    return model
예제 #4
0
def load_model_quantized(model_name, device, dataset, num_labels):
    """Build the requested model, load its weights and move it to ``device``.

    Args:
        model_name: One of "mobilenet", "resnet50", "resnet50_ptcv",
            "inceptionv3", "googlenet", "shufflenetv2", "dlrm" or "bert".
        device: Target passed to ``model.to``.
        dataset: Dataset name. "imagenet" requests pretrained weights from
            the vision constructors; "cifar10" additionally loads a state
            dict from ``data/<model_name>.pt``.
        num_labels: Number of labels; only used for the "bert" config.

    Returns:
        The constructed model, placed on ``device``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    pretrained = (dataset == "imagenet")
    if model_name == "mobilenet":
        model = models.mobilenet_v2(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "resnet50":
        model = torchvision.models.quantization.resnet50(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "resnet50_ptcv":
        model = ptcv.qresnet50_ptcv(pretrained=pretrained)
    elif model_name == "inceptionv3":
        model = models.inception_v3(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "googlenet":
        model = models.googlenet(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == "shufflenetv2":
        model = models.shufflenet_v2_x1_0(pretrained=pretrained, progress=True, quantize=False)
    elif model_name == 'dlrm':
        # These arguments are hardcoded to the defaults from DLRM (matching the pretrained model).
        model = DLRM_Net(16,
                         np.array([1460, 583, 10131227, 2202608, 305, 24, 12517, 633, 3, 93145, 5683,
                                   8351593, 3194, 27, 14992, 5461306, 10, 5652, 2173, 4, 7046547,
                                   18, 15, 286181, 105, 142572], dtype=np.int32),
                         np.array([13, 512, 256,  64,  16]),
                         np.array([367, 512, 256,   1]),
                         'dot', False, -1, 2, True, 0., 1, False, 'mult', 4, 200, False, 200)
        # Deserialize onto the CPU so a checkpoint saved from a GPU can be
        # read on any host; the model is moved to `device` at the end.
        ld_model = torch.load('data/dlrm.pt', map_location="cpu")
        model.load_state_dict(ld_model["state_dict"])
    elif model_name == 'bert':
        config = AutoConfig.from_pretrained(
            'bert-base-cased',
            num_labels=num_labels,
            finetuning_task='mnli',
        )
        model = BertForSequenceClassification.from_pretrained('data/bert.bin', from_tf=False, config=config)
    else:
        raise ValueError("Unsupported model type")

    if dataset == "cifar10":
        # Same reasoning as above: load on the CPU, move the model afterwards.
        ld_model = torch.load(f"data/{model_name}.pt", map_location="cpu")
        model.load_state_dict(ld_model)

    model = model.to(device)
    return model
예제 #5
0
def get_model_for_attack(model_name):
    """Instantiate the model identified by ``model_name``.

    Supported names:
        * 'model_vgg16bn', 'model_resnet18', 'model_inceptionv3',
          'model_vitb': constructors called with ``pretrained=True``.
        * 'model_hub:<repo>:<entrypoint>': loaded through ``torch.hub.load``
          and wrapped in ``Cifar10Renormalize``.
        * 'model_mnist:<key>': entry ``<key>`` of the object stored in
          'mnist.pt'.
        * 'model_ex:<path>': whatever object ``torch.load(<path>)`` returns.

    Raises:
        ValueError: If ``model_name`` matches none of the forms above, or a
            'model_hub:' name lacks the ``<repo>:<entrypoint>`` part.
    """
    if model_name == 'model_vgg16bn':
        model = vgg16_bn(pretrained=True)
    elif model_name == 'model_resnet18':
        model = resnet18(pretrained=True)
    elif model_name == 'model_inceptionv3':
        model = inception_v3(pretrained=True)
    elif model_name == 'model_vitb':
        from mnist_vit import ViT, MegaSizer
        model = MegaSizer(
            ImageNetRenormalize(ViT('B_16_imagenet1k', pretrained=True)))
    elif model_name.startswith('model_hub:'):
        # The repo part may itself contain ':' (e.g. "owner/repo:tag"), so
        # only the first and the last separator are treated as structural.
        spec = model_name.split(":", 1)[1]
        a, b = spec.rsplit(":", 1)
        model = torch.hub.load(a, b, pretrained=True)
        model = Cifar10Renormalize(model)
    elif model_name.startswith('model_mnist:'):
        a = model_name.split(":", 1)[1]
        model = torch.load('mnist.pt')[a]
    elif model_name.startswith('model_ex:'):
        # Split once so that paths containing ':' stay intact.
        a = model_name.split(":", 1)[1]
        # NOTE: torch.load unpickles the file; only pass trusted paths.
        model = torch.load(a)
    else:
        # The message previously contained a stray literal "f" in front of
        # the interpolated name.
        raise ValueError(f'Model {model_name} does not exist.')
    return model
예제 #6
0
def main():
    """Train a WSDAN model using settings taken from the command line.

    Steps: parse options, create the save directory, configure file logging,
    build the train/validation/test data loaders, create the network
    (optionally restoring weights, logs and feature centers from a
    checkpoint), then run the train/validate loop, stepping the LR scheduler
    and invoking the checkpoint callback after every epoch.

    Relies on module-level names defined elsewhere in the file, notably
    ``device``, ``raw_metric``, ``train`` and ``validate``.
    """
    parser = OptionParser()
    parser.add_option('-j', '--workers', dest='workers', default=16, type='int',
                      help='number of data loading workers (default: 16)')
    parser.add_option('-e', '--epochs', dest='epochs', default=80, type='int',
                      help='number of epochs (default: 80)')
    parser.add_option('-b', '--batch-size', dest='batch_size', default=16, type='int',
                      help='batch size (default: 16)')
    parser.add_option('-c', '--ckpt', dest='ckpt', default=False,
                      help='load checkpoint model (default: False)')
    parser.add_option('-v', '--verbose', dest='verbose', default=100, type='int',
                      help='show information for each <verbose> iterations (default: 100)')
    parser.add_option('--lr', '--learning-rate', dest='lr', default=1e-3, type='float',
                      help='learning rate (default: 1e-3)')
    parser.add_option('--sf', '--save-freq', dest='save_freq', default=1, type='int',
                      help='saving frequency of .ckpt models (default: 1)')
    parser.add_option('--sd', '--save-dir', dest='save_dir', default='./models/wsdan/',
                      help='saving directory of .ckpt models (default: ./models/wsdan)')
    parser.add_option('--ln', '--log-name', dest='log_name', default='train.log',
                      help='log name  (default: train.log)')
    parser.add_option('--mn', '--model-name', dest='model_name', default='model.ckpt',
                      help='model name  (default:model.ckpt)')
    parser.add_option('--init', '--initial-training', dest='initial_training', default=1, type='int',
                      help='train from 1-beginning or 0-resume training (default: 1)')

    (options, args) = parser.parse_args()

    ##################################
    # Initialize saving directory
    ##################################
    if not os.path.exists(options.save_dir):
        os.makedirs(options.save_dir)

    ##################################
    # Logging setting
    ##################################
    logging.basicConfig(
        filename=os.path.join(options.save_dir, options.log_name),
        filemode='w',
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    ##################################
    # Load dataset
    ##################################
    image_size = (256, 256)
    num_classes = 4
    transform = transforms.Compose([transforms.Resize(size=image_size),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                         std=[0.229, 0.224, 0.225])])
    train_dataset = CustomDataset(data_root='/mnt/HDD/RFW/train/data/', csv_file='data/RFW_Train40k_Images_Metada.csv', transform=transform)
    val_dataset = CustomDataset(data_root='/mnt/HDD/RFW/train/data/', csv_file='data/RFW_Val4k_Images_Metadata.csv', transform=transform)
    test_dataset = CustomDataset(data_root='/mnt/HDD/RFW/test/data/', csv_file='data/RFW_Test_Images_Metadata.csv', transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=options.batch_size, shuffle=True, num_workers=options.workers, pin_memory=True)
    validate_loader = DataLoader(val_dataset, batch_size=options.batch_size * 4, shuffle=False, num_workers=options.workers, pin_memory=True)
    # NOTE(review): test_loader is built but not consumed in this function.
    test_loader = DataLoader(test_dataset, batch_size=options.batch_size * 4, shuffle=False, num_workers=options.workers, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    logs = {}
    start_epoch = 0
    num_attentions = 32
    net = WSDAN(num_classes=num_classes, M=num_attentions, net='inception_mixed_6e', pretrained=True)

    # feature_center: size of (#classes, #attention_maps * #channel_features)
    feature_center = torch.zeros(num_classes, num_attentions * net.num_features).to(device)

    if options.ckpt:
        # Read the checkpoint from disk once; it carries the weights, the
        # training logs and, optionally, the feature centers.
        checkpoint = torch.load(options.ckpt)

        # Get epoch and some logs
        logs = checkpoint['logs']
        start_epoch = int(logs['epoch'])
        if options.initial_training == 0:
            # When resuming, the epoch number encoded in the checkpoint file
            # name (the part before the first '.') overrides the logged one.
            start_epoch = int((options.ckpt.split('/')[-1]).split('.')[0])

        # Load weights
        net.load_state_dict(checkpoint['state_dict'])
        logging.info('Network loaded from {}'.format(options.ckpt))

        # Load feature center onto the same device as the rest of the run.
        if 'feature_center' in checkpoint:
            feature_center = checkpoint['feature_center'].to(device)
            logging.info('feature_center loaded from {}'.format(options.ckpt))

    logging.info('Network weights save to {}'.format(options.save_dir))

    ##################################
    # Use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Optimizer, LR Scheduler
    ##################################
    # A learning rate stored in restored logs takes precedence over --lr.
    learning_rate = logs['lr'] if 'lr' in logs else options.lr
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=2)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    ##################################
    # ModelCheckpoint
    ##################################
    callback_monitor = 'val_{}'.format(raw_metric.name)
    callback = ModelCheckpoint(savepath=os.path.join(options.save_dir, options.model_name),
                               monitor=callback_monitor,
                               mode='max')
    if callback_monitor in logs:
        callback.set_best_score(logs[callback_monitor])
    else:
        callback.reset()

    ##################################
    # TRAINING
    ##################################
    logging.info('')
    logging.info('Start training: Total epochs: {}, Batch size: {}, Training size: {}, Validation size: {}'.
                 format(options.epochs, options.batch_size, len(train_dataset), len(val_dataset)))

    for epoch in range(start_epoch, options.epochs):
        callback.on_epoch_begin()

        logs['epoch'] = epoch + 1
        logs['lr'] = optimizer.param_groups[0]['lr']

        logging.info('Epoch {:03d}, Learning Rate {:g}'.format(epoch + 1, optimizer.param_groups[0]['lr']))

        pbar = tqdm(total=len(train_loader), unit=' batches')
        pbar.set_description('Epoch {}/{}'.format(epoch + 1, options.epochs))

        train(logs=logs,
              data_loader=train_loader,
              net=net,
              feature_center=feature_center,
              optimizer=optimizer,
              pbar=pbar)
        validate(logs=logs,
                 data_loader=validate_loader,
                 net=net,
                 pbar=pbar)

        # ReduceLROnPlateau needs the monitored value; other schedulers
        # step unconditionally.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(logs['val_loss'])
        else:
            scheduler.step()

        callback.on_epoch_end(logs, net, feature_center=feature_center)
        pbar.close()
예제 #7
0
"""Main script for Youtube-8M feature extractor."""

import misc.config as cfg
from misc.utils import concat_feat, get_dataloader, make_cuda, make_variable
from models import PCAWrapper, inception_v3

if __name__ == '__main__':
    # Script entry point: build the Inception v3 feature extractor and the
    # PCA wrapper, then run every batch of frames through the network and
    # accumulate the resulting features.

    # init models and data loader
    model = make_cuda(
        inception_v3(pretrained=True, transform_input=True, extract_feat=True))
    pca = PCAWrapper(n_components=cfg.n_components)
    # switch the network to evaluation mode before extracting features
    model.eval()

    # data loader for frames in single video
    # data_loader = get_dataloader(dataset="VideoFrame",
    #                              path=cfg.video_file,
    #                              num_frames=cfg.num_frames,
    #                              batch_size=cfg.batch_size)
    # data loader for frames decoded from several videos
    data_loader = get_dataloader(dataset="FrameImage",
                                 path=cfg.frame_root,
                                 batch_size=cfg.batch_size)

    # extract features by inception_v3
    # `feats` starts as None and is grown batch by batch through
    # concat_feat, which receives the CPU copy of each batch's output.
    feats = None
    for step, frames in enumerate(data_loader):
        print("extracting feature [{}/{}]".format(step + 1, len(data_loader)))
        feat = model(make_variable(frames))
        feats = concat_feat(feats, feat.data.cpu())

    # reduce dimensions by PCA
예제 #8
0
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, datasets
from torchsummary import summary

from dataloader import load_cifar
from models import inception_v3, Inception3, InceptionA, InceptionB, InceptionC, InceptionD, InceptionAux, BasicConv2d
from eval import plot_epoch

# Obtain the train / validation / test loaders from the project-local
# load_cifar helper.
train_loader, val_loader, test_loader = load_cifar()

# Instantiate Inception v3 with the constructor's default arguments.
model = inception_v3()

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Print a torchsummary report for an input of shape (3, 32, 32).
summary(model, (3, 32, 32))

# Hyper-parameters for the SGD optimizer below.
LEARNING_RATE = 0.001
MOMENTUM = 0.9

# Loss function (cross entropy), placed on the same device as the model.
cast = torch.nn.CrossEntropyLoss().to(device)
# Optimizer: SGD with momentum over all model parameters.
optimizer = torch.optim.SGD(model.parameters(),
                            lr=LEARNING_RATE,
                            momentum=MOMENTUM)

# Batch size comes from the parsed command-line arguments.
BATCHSIZE = args.batchsize

# Learning-rate schedule: maps an epoch index to the rate used from then on.
# NOTE(review): how the keys are consumed is not visible here - confirm.
LR_SCHEDULE = {
    0: 0.0001,
    10: 0.00001,
    20: 0.000001
}

"""
Set up all theano functions
"""
# Symbolic inputs: a 4-D tensor for the images and an int vector for labels.
X = T.tensor4('X')
Y = T.ivector('y')

# set up theano functions to generate output by feeding data through network, any test outputs should be deterministic
net = inception_v3(X)
# load network weights
# Pickles must be read in binary mode, and the context manager closes the
# file handle as soon as the weights are loaded.
with open('data/pre_trained_weights/inception_v3.pkl', 'rb') as weights_file:
    d = pickle.load(weights_file)
helper.set_all_param_values(net['softmax'], d['param values'])

# stack our own softmax onto the final layer
output_layer = DenseLayer(net['pool3'], num_units=10, W=lasagne.init.HeNormal(), nonlinearity=softmax)

# standard output functions
output_train = lasagne.layers.get_output(output_layer)
output_test = lasagne.layers.get_output(output_layer, deterministic=True)

# set up the loss that we aim to minimize, when using cat cross entropy our Y should be ints not one-hot
loss = lasagne.objectives.categorical_crossentropy(output_train, Y)
loss = loss.mean()
예제 #10
0
        QUANTIZE_MAX_VAL)
    # - convert to 8-bit in range [0.0, 255.0]
    quantized_embeddings = (
        (clipped_embeddings - QUANTIZE_MIN_VAL) *
        (255.0 /
         (QUANTIZE_MAX_VAL - QUANTIZE_MIN_VAL)))
    # - cast 8-bit float to uint8
    quantized_embeddings = quantized_embeddings.astype(np.uint8)

    return quantized_embeddings


if __name__ == '__main__':
    # init Inception v3 model
    model = make_cuda(inception_v3(pretrained=True,
                                   model_path=cfg.inception_v3_model,
                                   transform_input=True,
                                   extract_feat=True))
    model.eval()

    # init PCA model
    pca = PCAWrapper(n_components=cfg.n_components,
                     batch_size=cfg.pca_batch_size)
    pca.load_params(filepath=cfg.pca_model)

    subfolders = list_folders(cfg.dataset_path)
    for subfolder in subfolders:
        print("current folder: {}".format(subfolder))
        # data loader for frames in single video
        data_loader = get_dataloader(dataset="FrameImage",
                                     path=os.path.join(subfolder, 'frames'),
                                     frame_num =cfg.frame_num,
예제 #11
0
def main():
    """Run a spatial network over a list of clips and save its outputs as PNGs.

    Reads its settings from the module-level ``parser``: the data directory,
    the list file, two batch sizes, the slice of list entries to process,
    the checkpoint path, the architecture and the dataset name.

    Each line of the list file is expected to hold
    ``<clip path> <frame count> <label>`` separated by single spaces. For
    every clip, ``VideoSpatialPrediction`` is called and each array it
    returns is min-max normalised, scaled by 255 and written to
    ``<architecture>_<dataset>_VR<i>/<clip path>/vr_<NN>.png``.
    """
    args = parser.parse_args()

    data_dir = args.data_dir
    val_file = args.list_files
    ext_batch_sz = int(args.ext_batch_sz)
    int_batch_sz = int(args.int_batch_sz)
    start_instance = int(args.start_instance)
    end_instance = int(args.end_instance)
    checkpoint = args.checkpoint_path

    model_start_time = time.time()
    # Pretrained weights are requested only when no checkpoint path is given.
    if args.architecture == "inception_v3":
        new_size = 299
        num_categories = 3528, 3468, 2048
        spatial_net = models.inception_v3(pretrained=(checkpoint == ""),
                                          num_outputs=len(num_categories))
    else:  #resnet
        new_size = 224
        num_categories = 8192, 4096, 2048
        spatial_net = models.resnet152(pretrained=(checkpoint == ""),
                                       num_outputs=len(num_categories))

    if os.path.isfile(checkpoint):
        print('loading checkpoint {} ...'.format(checkpoint))
        params = torch.load(checkpoint)
        model_dict = spatial_net.state_dict()

        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in params['state_dict'].items() if k in model_dict
        }

        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        spatial_net.load_state_dict(model_dict)
        print('loaded')
    else:
        print(checkpoint)
        print('ERROR: No checkpoint found')

    spatial_net.cuda()
    spatial_net.eval()
    model_end_time = time.time()
    model_time = model_end_time - model_start_time
    print("Action recognition model is loaded in %4.4f seconds." %
          (model_time))

    # Close the list file as soon as the requested slice has been read.
    with open(val_file, "r") as f_val:
        val_list = f_val.readlines()[start_instance:end_instance]
    print("we got %d test videos" % len(val_list))

    for line_id, line in enumerate(val_list):
        print("sample %d/%d" % (line_id + 1, len(val_list)))
        line_info = line.split(" ")
        clip_path = os.path.join(data_dir, line_info[0])
        num_frames = int(line_info[1])
        # The label is not used below; it is still parsed so that a
        # malformed line fails here rather than being silently accepted.
        input_video_label = int(line_info[2])

        spatial_prediction = VideoSpatialPrediction(clip_path, spatial_net,
                                                    num_categories, num_frames,
                                                    ext_batch_sz, int_batch_sz,
                                                    new_size)

        for ii in range(len(spatial_prediction)):
            # Output location depends only on the outer index, so build the
            # strings once per group instead of once per image.
            folder_name = args.architecture + "_" + args.dataset + "_VR" + str(
                ii)
            clip_dir = folder_name + '/' + line_info[0]
            for vr_ind, vr in enumerate(spatial_prediction[ii]):
                if not os.path.isdir(clip_dir):
                    print("creating folder: " + clip_dir)
                    os.makedirs(clip_dir)
                vr_name = clip_dir + '/vr_{0:02d}.png'.format(vr_ind)
                vr_gray = normalize_maxmin(vr.transpose()).transpose() * 255.
                cv2.imwrite(vr_name, vr_gray)
예제 #12
0
"""Extract inception_v3_feats from images for Youtube-8M feature extractor."""

import os

import torch

import init_path
import misc.config as cfg
from misc.utils import (concat_feat_var, get_dataloader, make_cuda,
                        make_variable)
from models import PCAWrapper, inception_v3

if __name__ == '__main__':
    # Script entry point: build the Inception v3 feature extractor, list the
    # videos under cfg.video_root and walk over them, skipping any video
    # whose feature file already exists.

    # init models and data loader
    model = make_cuda(inception_v3(pretrained=True,
                                   transform_input=True,
                                   extract_feat=True))
    # switch the network to evaluation mode before extracting features
    model.eval()

    # get vid list
    # Keep only entries whose extension is listed in cfg.video_ext and store
    # them without the extension.
    video_list = os.listdir(cfg.video_root)
    video_list = [os.path.splitext(v)[0] for v in video_list
                  if os.path.splitext(v)[1] in cfg.video_ext]

    # extract features by inception_v3
    for idx, vid in enumerate(video_list):
        # cfg.inception_v3_feats_path is a format string taking the video
        # name; an existing file means this video was already processed.
        if os.path.exists(cfg.inception_v3_feats_path.format(vid)):
            print("skip {}".format(vid))
        else:
            print("extract feature from {} [{}/{}]".format(vid, idx + 1,
                                                           len(video_list)))
예제 #13
0
def train_multiclass(train_file, test_file, stat_file,
                     model='mobilenet_v2',
                     classes=('artist_name', 'genre', 'style', 'technique', 'century'),
                     label_file='_user_labels.pkl',
                     im_path='/export/home/kschwarz/Documents/Data/Wikiart_artist49_images',
                     chkpt=None, weight_file=None,
                     triplet_selector='semihard', margin=0.2,
                     labels_per_class=4, samples_per_label=4,
                     use_gpu=True, device=0,
                     epochs=100, batch_size=32, lr=1e-4, momentum=0.9,
                     log_interval=10, log_dir='runs',
                     exp_name=None, seed=123):
    """Train a ``MetricNet`` on several label classes with a triplet loss.

    A body network chosen by ``model`` is wrapped in ``MetricNet``, which is
    expected to return one output per entry of ``classes``.  Each output is
    L2-normalised and fed to a ``TripletLoss``; the per-class losses are
    summed.  Samples whose label is ``-1`` for a class are excluded from that
    class' loss.  After every epoch the validation set is evaluated, the
    ``ReduceLROnPlateau`` scheduler is stepped on the validation loss and a
    checkpoint is written through ``save_checkpoint``.  Training ends after
    ``epochs`` epochs or as soon as the learning rate drops below ``1e-5``.

    Args:
        train_file: Passed to ``create_trainset`` to build the training set.
        test_file: Passed to ``create_valset`` to build the validation set.
        stat_file: Path of a pickle holding ``'mean'`` and ``'std'`` used for
            input normalisation.
        model: Body network name; one of ``squeezenet``, ``mobilenet_v1``,
            ``mobilenet_v2``, ``vgg16_bn``, ``inception_v3``, ``alexnet``
            (case-insensitive).
        classes: Names of the label classes; also used as column names of the
            datasets' ``df`` attribute when building the batch samplers.
        label_file: Passed to ``create_trainset``; additionally read as a
            pickle whose ``'labels'`` entry is summarised in the run config.
        im_path: Image directory handed to the dataset constructors.
        chkpt: Optional checkpoint path to resume from.
        weight_file: Optional file with a ``'state_dict'`` for the body
            network.  When ``None`` the body is created with
            ``pretrained=True``.
        triplet_selector: ``random``, ``semihard``, ``hardest``, ``khardest``
            or ``mixed`` (semihard first, switched to hardest at epoch 26).
        margin: Margin for the triplet loss and the negative selectors.
        labels_per_class: ``n_label`` of the balanced batch samplers.
        samples_per_label: ``n_per_label`` of the balanced batch samplers.
        use_gpu: Use CUDA when it is available.
        device: CUDA device index.
        epochs: Last epoch number to train (inclusive).
        batch_size: Not used here (batches come from the balanced samplers);
            it is only recorded in the run config.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_interval: Print training progress every this many batches.
        log_dir: Directory for checkpoints, config and tensorboard logs.
        exp_name: Optional suffix appended to the experiment name.
        seed: Seed for the torch (and CUDA) random number generators.

    Raises:
        AssertionError: If ``model`` or ``triplet_selector`` is unknown.
    """
    argvars = locals().copy()       # snapshot of the call arguments for write_config
    torch.manual_seed(seed)
    model_name = model.lower()
    selector_name = triplet_selector.lower()

    # LOAD DATASET
    # Pickles are binary data: 'rb' is required on Python 3 and harmless on Python 2.
    # NOTE(review): pickle.load executes arbitrary code; only use trusted stat files.
    with open(stat_file, 'rb') as f:
        data = pickle.load(f)
    mean = [float(m) for m in data['mean']]
    std = [float(s) for s in data['std']]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if model_name == 'inception_v3':            # change input size to 299
        train_transform.transforms[0].size = (299, 299)
        val_transform.transforms[0].size = (299, 299)
    trainset = create_trainset(train_file, label_file, im_path, train_transform, classes)
    for c in classes:
        if len(trainset.labels_to_ints[c]) < labels_per_class:
            print('less labels in class {} than labels_per_class, use all available labels ({})'
                  .format(c, len(trainset.labels_to_ints[c])))
    valset = create_valset(test_file, im_path, val_transform, trainset.labels_to_ints)

    # PARAMETERS
    use_cuda = use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed_all(seed)

    # Raise explicitly instead of `assert False` so the check survives `python -O`.
    if model_name not in ['squeezenet', 'mobilenet_v1', 'mobilenet_v2', 'vgg16_bn', 'inception_v3', 'alexnet']:
        raise AssertionError('Unknown model {}\n\t+ Choose from: '
                             '[squeezenet, mobilenet_v1, mobilenet_v2, vgg16_bn, inception_v3, alexnet].'
                             .format(model))
    elif model_name == 'mobilenet_v1':
        bodynet = mobilenet_v1(pretrained=weight_file is None)
    elif model_name == 'mobilenet_v2':
        bodynet = mobilenet_v2(pretrained=weight_file is None)
    elif model_name == 'vgg16_bn':
        bodynet = vgg16_bn(pretrained=weight_file is None)
    elif model_name == 'inception_v3':
        bodynet = inception_v3(pretrained=weight_file is None)
    elif model_name == 'alexnet':
        bodynet = alexnet(pretrained=weight_file is None)
    else:       # squeezenet
        bodynet = squeezenet(pretrained=weight_file is None)

    # Load weights for the body network
    if weight_file is not None:
        print("=> loading weights from '{}'".format(weight_file))
        loaded_dict = torch.load(weight_file, map_location=lambda storage, loc: storage)['state_dict']
        state_dict = bodynet.state_dict()
        # in case of multilabel weight file the body weights carry a 'bodynet.' prefix
        candidate_dict = {k.replace('bodynet.', ''): v for k, v in loaded_dict.items()}

        # Decide key by key which weights are transferred and report the rest.
        # (Comparing the dicts with `==` would evaluate tensors for truth, and
        # `keys() + keys()` is not valid on Python 3, so work on key sets.)
        pretrained_dict = {}
        for k in sorted(set(state_dict) | set(candidate_dict)):
            if k not in candidate_dict:
                print('\tWeights for "{}" were not found in weight file.'.format(k))
            elif k not in state_dict:
                print('\tWeights for "{}" were are not part of the used model.'.format(k))
            elif state_dict[k].shape != candidate_dict[k].shape:    # number of classes might have changed
                print('\tShapes of "{}" are different in model ({}) and weight file ({}).'.
                      format(k, state_dict[k].shape, candidate_dict[k].shape))
            else:  # everything is good
                pretrained_dict[k] = candidate_dict[k]

        state_dict.update(pretrained_dict)
        bodynet.load_state_dict(state_dict)

    net = MetricNet(bodynet, len(classes))

    n_parameters = sum([p.data.nelement() for p in net.parameters() if p.requires_grad])
    if use_cuda:
        net = net.cuda()
    print('Using {}\n\t+ Number of params: {}'.format(str(net).split('(', 1)[0], n_parameters))

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    expname = timestamp + '_' + str(net).split('(', 1)[0]
    if exp_name is not None:
        expname = expname + '_' + exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    # summarise how often each label value occurs per column of the user labels
    with open(label_file, 'rb') as f:
        labels = pickle.load(f)['labels']
        n_labeled = '\t'.join([str(Counter(l).items()) for l in labels.transpose()])

    write_config(argvars, os.path.join(log_dir, expname), extras={'n_labeled': n_labeled})

    # INITIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, threshold=1e-1, verbose=True)

    if selector_name not in ['random', 'semihard', 'hardest', 'mixed', 'khardest']:
        raise AssertionError('Unknown option {} for triplet selector. Choose from "random", "semihard", '
                             '"hardest", "khardest" or "mixed".'.format(triplet_selector))
    elif selector_name == 'random':
        criterion = TripletLoss(margin=margin,
                                triplet_selector=RandomNegativeTripletSelector(margin, cpu=not use_cuda))
    elif selector_name == 'semihard' or selector_name == 'mixed':
        criterion = TripletLoss(margin=margin,
                                triplet_selector=SemihardNegativeTripletSelector(margin, cpu=not use_cuda))
    elif selector_name == 'khardest':
        criterion = TripletLoss(margin=margin,
                                triplet_selector=KHardestNegativeTripletSelector(margin, k=3, cpu=not use_cuda))
    else:
        criterion = TripletLoss(margin=margin,
                                triplet_selector=HardestNegativeTripletSelector(margin, cpu=not use_cuda))
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 4} if use_cuda else {}
    multilabel_train = np.stack([trainset.df[c].values for c in classes]).transpose()
    train_batch_sampler = BalancedBatchSamplerMulticlass(multilabel_train, n_label=labels_per_class,
                                                         n_per_label=samples_per_label, ignore_label=None)
    trainloader = DataLoader(trainset, batch_sampler=train_batch_sampler, **kwargs)
    multilabel_val = np.stack([valset.df[c].values for c in classes]).transpose()
    val_batch_sampler = BalancedBatchSamplerMulticlass(multilabel_val, n_label=labels_per_class,
                                                       n_per_label=samples_per_label, ignore_label=None)
    valloader = DataLoader(valset, batch_sampler=val_batch_sampler, **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    best_acc = None             # stays None until restored from a checkpoint or measured as baseline
    if chkpt is not None:
        if os.path.isfile(chkpt):
            print("=> loading checkpoint '{}'".format(chkpt))
            checkpoint = torch.load(chkpt, map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            # Checkpoints written below store the running best under 'best_acc';
            # fall back to the 'acc' key that was read previously.
            best_acc = checkpoint['best_acc'] if 'best_acc' in checkpoint else checkpoint['acc']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(chkpt))

    def train(epoch):
        """Run one training epoch and log the averaged metrics."""
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                # Out-of-place add: on CPU `type_as` returns the leaf tensor itself,
                # and in-place ops on a leaf that requires grad are rejected by autograd.
                loss = loss + l
                n_triplets += nt

            non_zero_triplets.update(n_triplets, target[0].size(0))
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)           # do not compute ap pairs for concealed classes
            gtes.update(gte.data, target[0].size(0))
            distances_ap.update(dist_ap.data, target[0].size(0))
            distances_an.update(dist_an.data, target[0].size(0))
            losses.update(loss.data[0], target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'GTE: {:.2f}% ({:.2f}%)\t'
                      'Non-zero Triplets: {:d} ({:d})'.format(
                    epoch, batch_idx * len(target[0]), len(trainloader) * len(target[0]),
                    float(losses.val), float(losses.avg),
                    float(gtes.val) * 100., float(gtes.avg) * 100.,
                    int(non_zero_triplets.val), int(non_zero_triplets.avg)))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('gte', float(gtes.avg), epoch, test=False)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=False)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=False)
        log.write('dist_an', float(distances_an.avg), epoch, test=False)

    def test(epoch):
        """Evaluate on the validation set; return ``(avg loss, 1 - avg GTE)``."""
        losses = AverageMeter()
        gtes = AverageMeter()
        non_zero_triplets = AverageMeter()
        distances_ap = AverageMeter()
        distances_an = AverageMeter()

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            target = torch.stack(target)
            if use_cuda:
                data, target = Variable(data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]
            # compute output
            outputs = net(data)

            # normalize features
            for i in range(len(classes)):
                outputs[i] = torch.nn.functional.normalize(outputs[i], p=2, dim=1)

            loss = Variable(torch.Tensor([0]), requires_grad=True).type_as(data[0])
            n_triplets = 0
            for op, tgt in zip(outputs, target):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (tgt != -1).nonzero().view(-1)
                op, tgt = op[labeled], tgt[labeled]

                l, nt = criterion(op, tgt)
                loss = loss + l         # out-of-place, see train()
                n_triplets += nt

            non_zero_triplets.update(n_triplets, target[0].size(0))
            # measure GTE and record loss
            gte, dist_ap, dist_an = GTEMulticlass(outputs, target)
            gtes.update(gte.data.cpu(), target[0].size(0))
            distances_ap.update(dist_ap.data.cpu(), target[0].size(0))
            distances_an.update(dist_an.data.cpu(), target[0].size(0))
            losses.update(loss.data[0].cpu(), target[0].size(0))

        print('\nVal set: Average loss: {:.4f} Average GTE {:.2f}%, '
              'Average non-zero triplets: {:d} LR: {:.6f}'.format(float(losses.avg), float(gtes.avg) * 100.,
                                                                  int(non_zero_triplets.avg),
                                                                  optimizer.param_groups[-1]['lr']))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('gte', float(gtes.avg), epoch, test=True)
        log.write('non-zero trplts', int(non_zero_triplets.avg), epoch, test=True)
        log.write('dist_ap', float(distances_ap.avg), epoch, test=True)
        log.write('dist_an', float(distances_an.avg), epoch, test=True)
        return losses.avg, 1 - gtes.avg

    if best_acc is None:         # nothing restored from a checkpoint: compute baseline
        _, best_acc = test(epoch=0)

    for epoch in range(start_epoch, epochs + 1):
        if selector_name == 'mixed' and epoch == 26:
            criterion.triplet_selector = HardestNegativeTripletSelector(margin, cpu=not use_cuda)
            print('Changed negative selection from semihard to hardest.')
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'best_acc': best_acc,
        }, is_best, expname, directory=log_dir)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values
    best_path = os.path.join(log_dir, expname + '_model_best.pth.tar')
    if os.path.isfile(best_path):
        best = torch.load(best_path, map_location=lambda storage, loc: storage)
        # the checkpoint dict saved above has no 'acc' entry; the value lives under 'best_acc'
        print('Finished training after epoch {}:\n\tbest acc score: {}'
              .format(best['epoch'], best['best_acc']))
    else:
        print("=> no best-model checkpoint found at '{}'".format(best_path))
    print('Best model mean accuracy: {}'.format(best_acc))
예제 #14
0
def train_multiclass(
        train_file,
        test_file,
        stat_file,
        model='mobilenet_v2',
        classes=('artist_name', 'genre', 'style', 'technique', 'century'),
        im_path='/export/home/kschwarz/Documents/Data/Wikiart_artist49_images',
        label_file='_user_labels.pkl',
        chkpt=None,
        weight_file=None,
        use_gpu=True,
        device=0,
        epochs=100,
        batch_size=32,
        lr=1e-4,
        momentum=0.9,
        log_interval=10,
        log_dir='runs',
        exp_name=None,
        seed=123):
    """Train an ``OctopusNet`` classifier on several label classes at once.

    A body network chosen by ``model`` is wrapped in ``OctopusNet``, which is
    expected to return one output per entry of ``classes``.  The training
    loss is the sum of one ``CrossEntropyLoss`` per class; samples labelled
    ``-1`` for a class are excluded from that class' loss and accuracy.
    After every epoch the validation set is evaluated, the
    ``ReduceLROnPlateau`` scheduler is stepped on the validation loss and a
    checkpoint is written through ``save_checkpoint``.  The best model is
    chosen by the score ``mean_acc - std(class_acc) / mean_acc``; the
    checkpoint with the best plain mean accuracy is copied separately.
    Training ends after ``epochs`` epochs or as soon as the learning rate
    drops below ``1e-5``.

    Args:
        train_file: Passed to ``create_trainset`` to build the training set.
        test_file: Passed to ``create_valset`` to build the validation set.
        stat_file: Path of a pickle holding ``'mean'`` and ``'std'`` used for
            input normalisation.
        model: Body network name; one of ``squeezenet``, ``mobilenet_v1``,
            ``mobilenet_v2``, ``vgg16_bn``, ``inception_v3``, ``alexnet``
            (case-insensitive).
        classes: Names of the label classes to predict.
        im_path: Image directory handed to the dataset constructors.
        label_file: Passed to ``create_trainset``; additionally read as a
            pickle whose ``'labels'`` entry is summarised in the run config.
        chkpt: Optional checkpoint path to resume from.
        weight_file: Optional file with a ``'state_dict'`` for the body
            network.  When ``None`` the body is created with
            ``pretrained=True``.
        use_gpu: Use CUDA when it is available.
        device: CUDA device index.
        epochs: Last epoch number to train (inclusive).
        batch_size: Batch size of the training and validation loaders.
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_interval: Print training progress every this many batches.
        log_dir: Directory for checkpoints, config and tensorboard logs.
        exp_name: Optional suffix appended to the experiment name.
        seed: Seed for the torch (and CUDA) random number generators.

    Raises:
        AssertionError: If ``model`` is unknown.
    """
    argvars = locals().copy()  # snapshot of the call arguments for write_config
    torch.manual_seed(seed)
    model_name = model.lower()

    # LOAD DATASET
    # Pickles are binary data: 'rb' is required on Python 3 and harmless on
    # Python 2.
    # NOTE(review): pickle.load executes arbitrary code; only use trusted
    # stat files.
    with open(stat_file, 'rb') as f:
        data = pickle.load(f)
    mean = [float(m) for m in data['mean']]
    std = [float(s) for s in data['std']]
    normalize = transforms.Normalize(mean=mean, std=std)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(90),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if model_name == 'inception_v3':  # change input size to 299
        train_transform.transforms[0].size = (299, 299)
        val_transform.transforms[0].size = (299, 299)
    trainset = create_trainset(train_file, label_file, im_path,
                               train_transform, classes)
    valset = create_valset(test_file, im_path, val_transform,
                           trainset.labels_to_ints)
    # number of distinct labels per class, used to size the network outputs
    num_labels = [len(trainset.labels_to_ints[c]) for c in classes]

    # PARAMETERS
    use_cuda = use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(device)
        torch.cuda.manual_seed_all(seed)

    # Raise explicitly instead of `assert False` so the check survives
    # `python -O`.
    if model_name not in [
            'squeezenet', 'mobilenet_v1', 'mobilenet_v2', 'vgg16_bn',
            'inception_v3', 'alexnet'
    ]:
        raise AssertionError(
            'Unknown model {}\n\t+ Choose from: '
            '[squeezenet, mobilenet_v1, mobilenet_v2, vgg16_bn, inception_v3, alexnet].'
            .format(model))
    elif model_name == 'mobilenet_v1':
        bodynet = mobilenet_v1(pretrained=weight_file is None)
    elif model_name == 'mobilenet_v2':
        bodynet = mobilenet_v2(pretrained=weight_file is None)
    elif model_name == 'vgg16_bn':
        bodynet = vgg16_bn(pretrained=weight_file is None)
    elif model_name == 'inception_v3':
        bodynet = inception_v3(pretrained=weight_file is None)
    elif model_name == 'alexnet':
        bodynet = alexnet(pretrained=weight_file is None)
    else:  # squeezenet
        bodynet = squeezenet(pretrained=weight_file is None)

    # Load weights for the body network
    if weight_file is not None:
        print("=> loading weights from '{}'".format(weight_file))
        loaded_dict = torch.load(
            weight_file,
            map_location=lambda storage, loc: storage)['state_dict']
        state_dict = bodynet.state_dict()
        # in case of multilabel weight file the body weights carry a
        # 'bodynet.' prefix
        candidate_dict = {
            k.replace('bodynet.', ''): v
            for k, v in loaded_dict.items()
        }

        # Decide key by key which weights are transferred and report the
        # rest.  (Comparing the dicts with `==` would evaluate tensors for
        # truth, and `keys() + keys()` is not valid on Python 3, so work on
        # key sets.)
        pretrained_dict = {}
        for k in sorted(set(state_dict) | set(candidate_dict)):
            if k not in candidate_dict:
                print('\tWeights for "{}" were not found in weight file.'.
                      format(k))
            elif k not in state_dict:
                print(
                    '\tWeights for "{}" were are not part of the used model.'
                    .format(k))
            elif state_dict[k].shape != candidate_dict[k].shape:
                # number of classes might have changed
                print(
                    '\tShapes of "{}" are different in model ({}) and weight file ({}).'
                    .format(k, state_dict[k].shape,
                            candidate_dict[k].shape))
            else:  # everything is good
                pretrained_dict[k] = candidate_dict[k]

        state_dict.update(pretrained_dict)
        bodynet.load_state_dict(state_dict)

    net = OctopusNet(bodynet, n_labels=num_labels)

    n_parameters = sum(
        [p.data.nelement() for p in net.parameters() if p.requires_grad])
    if use_cuda:
        net = net.cuda()
    print('Using {}\n\t+ Number of params: {}'.format(
        str(net).split('(', 1)[0], n_parameters))

    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    # tensorboard summary writer
    timestamp = time.strftime('%m-%d-%H-%M')
    expname = timestamp + '_' + str(net).split('(', 1)[0]
    if exp_name is not None:
        expname = expname + '_' + exp_name
    log = TBPlotter(os.path.join(log_dir, 'tensorboard', expname))
    log.print_logdir()

    # allow auto-tuner to find best algorithm for the hardware
    cudnn.benchmark = True

    # summarise how often each label value occurs per column of the user
    # labels
    with open(label_file, 'rb') as f:
        labels = pickle.load(f)['labels']
        n_labeled = '\t'.join(
            [str(Counter(l).items()) for l in labels.transpose()])

    write_config(argvars,
                 os.path.join(log_dir, expname),
                 extras={'n_labeled': n_labeled})

    # INITIALIZE TRAINING
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=10,
                                                     threshold=1e-1,
                                                     verbose=True)
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion = criterion.cuda()

    kwargs = {'num_workers': 4} if use_cuda else {}
    trainloader = DataLoader(trainset,
                             batch_size=batch_size,
                             shuffle=True,
                             **kwargs)
    valloader = DataLoader(valset,
                           batch_size=batch_size,
                           shuffle=True,
                           **kwargs)

    # optionally resume from a checkpoint
    start_epoch = 1
    # both stay None until restored from a checkpoint or measured as baseline
    best_acc_score = None
    best_acc = None
    if chkpt is not None:
        if os.path.isfile(chkpt):
            print("=> loading checkpoint '{}'".format(chkpt))
            checkpoint = torch.load(chkpt,
                                    map_location=lambda storage, loc: storage)
            start_epoch = checkpoint['epoch']
            best_acc_score = checkpoint['best_acc_score']
            best_acc = checkpoint['acc']
            net.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                chkpt, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(chkpt))

    def train(epoch):
        """Run one training epoch and log the averaged metrics."""
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to train mode
        net.train()
        for batch_idx, (data, target) in enumerate(trainloader):
            if use_cuda:
                data, target = Variable(
                    data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0]),
                            requires_grad=True).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                # filter unlabeled samples if there are any (have label -1)
                # NOTE(review): if a batch holds no labelled sample for a
                # class, t.size(0) is 0 below - confirm this cannot happen.
                labeled = (t != -1).nonzero().view(-1)
                o, t, p = o[labeled], t[labeled], p[labeled]
                # Out-of-place add: on CPU `type_as` returns the leaf tensor
                # itself, and in-place ops on a leaf that requires grad are
                # rejected by autograd.
                loss = loss + criterion(o, t)
                # measure class accuracy and record loss
                class_acc[i].update(
                    (torch.sum(p == t).type(torch.FloatTensor) /
                     t.size(0)).data)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                target[0].size(0))
            losses.update(loss.data, target[0].size(0))

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                # `target` is a list with one tensor per class, so the number
                # of processed samples has to be derived from `data`.
                print('Train Epoch: {} [{}/{}]\t'
                      'Loss: {:.4f} ({:.4f})\t'
                      'Acc: {:.2f}% ({:.2f}%)'.format(epoch,
                                                      batch_idx * len(data),
                                                      len(trainloader.dataset),
                                                      float(losses.val),
                                                      float(losses.avg),
                                                      float(accs.val) * 100.,
                                                      float(accs.avg) * 100.))
                print('\t' + '\n\t'.join([
                    '{}: {:.2f}%'.format(classes[i],
                                         float(class_acc[i].val) * 100.)
                    for i in range(len(classes))
                ]))

        # log avg values to somewhere
        log.write('loss', float(losses.avg), epoch, test=False)
        log.write('acc', float(accs.avg), epoch, test=False)
        # NOTE(review): every class is written under the same 'class_acc'
        # tag - verify that TBPlotter keeps the classes apart.
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=False)

    def test(epoch):
        """Evaluate on the validation set.

        Returns ``(avg loss, acc score, mean acc, per-class accs)``.
        """
        losses = AverageMeter()
        accs = AverageMeter()
        class_acc = [AverageMeter() for i in range(len(classes))]

        # switch to evaluation mode
        net.eval()
        for batch_idx, (data, target) in enumerate(valloader):
            if use_cuda:
                data, target = Variable(
                    data.cuda()), [Variable(t.cuda()) for t in target]
            else:
                data, target = Variable(data), [Variable(t) for t in target]

            # compute output
            outputs = net(data)
            preds = [torch.max(outputs[i], 1)[1] for i in range(len(classes))]

            loss = Variable(torch.Tensor([0]),
                            requires_grad=True).type_as(data[0])
            for i, o, t, p in zip(range(len(classes)), outputs, target, preds):
                # filter unlabeled samples if there are any (have label -1)
                labeled = (t != -1).nonzero().view(-1)
                # out-of-place add, see train()
                loss = loss + criterion(o[labeled], t[labeled])
                # measure class accuracy and record loss
                class_acc[i].update((torch.sum(p[labeled] == t[labeled]).type(
                    torch.FloatTensor) / t[labeled].size(0)).data)
            accs.update(
                torch.mean(
                    torch.stack(
                        [class_acc[i].val for i in range(len(classes))])),
                target[0].size(0))
            losses.update(loss.data, target[0].size(0))

        score = accs.avg - torch.std(
            torch.stack([class_acc[i].avg for i in range(len(classes))])
        ) / accs.avg  # compute mean - std/mean as measure for accuracy
        print(
            '\nVal set: Average loss: {:.4f} Average acc {:.2f}% Acc score {:.2f} LR: {:.6f}'
            .format(float(losses.avg),
                    float(accs.avg) * 100., float(score),
                    optimizer.param_groups[-1]['lr']))
        print('\t' + '\n\t'.join([
            '{}: {:.2f}%'.format(classes[i],
                                 float(class_acc[i].avg) * 100.)
            for i in range(len(classes))
        ]))
        log.write('loss', float(losses.avg), epoch, test=True)
        log.write('acc', float(accs.avg), epoch, test=True)
        for i in range(len(classes)):
            log.write('class_acc', float(class_acc[i].avg), epoch, test=True)
        return losses.avg.cpu().numpy(), float(score), float(
            accs.avg), [float(class_acc[i].avg) for i in range(len(classes))]

    if best_acc_score is None:  # nothing restored: compute baseline
        _, best_acc_score, best_acc, _ = test(epoch=0)

    for epoch in range(start_epoch, epochs + 1):
        # train for one epoch
        train(epoch)
        # evaluate on validation set
        val_loss, val_acc_score, val_acc, val_class_accs = test(epoch)
        scheduler.step(val_loss)

        # remember best acc and save checkpoint
        is_best = val_acc_score > best_acc_score
        best_acc_score = max(val_acc_score, best_acc_score)
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'best_acc_score': best_acc_score,
                'acc': val_acc,
                'class_acc': {c: a
                              for c, a in zip(classes, val_class_accs)}
            },
            is_best,
            expname,
            directory=log_dir)

        # keep a separate copy of the checkpoint with the best plain mean
        # accuracy
        if val_acc > best_acc:
            shutil.copyfile(
                os.path.join(log_dir, expname + '_checkpoint.pth.tar'),
                os.path.join(log_dir,
                             expname + '_model_best_mean_acc.pth.tar'))
        best_acc = max(val_acc, best_acc)

        if optimizer.param_groups[-1]['lr'] < 1e-5:
            print('Learning rate reached minimum threshold. End training.')
            break

    # report best values
    best_path = os.path.join(log_dir, expname + '_model_best.pth.tar')
    if os.path.isfile(best_path):
        best = torch.load(best_path,
                          map_location=lambda storage, loc: storage)
        print(
            'Finished training after epoch {}:\n\tbest acc score: {}\n\tacc: {}\n\t class acc: {}'
            .format(best['epoch'], best['best_acc_score'], best['acc'],
                    best['class_acc']))
    else:
        print("=> no best-model checkpoint found at '{}'".format(best_path))
    print('Best model mean accuracy: {}'.format(best_acc))