示例#1
0
    def __init__(self, num_classes=10):
        """Assemble the ResNet-101 feature extractor and the head modules.

        Args:
            num_classes: number of output classes. It sets the output width of
                ``f1`` and both dimensions of the bias-free square layer ``l1``.
        """
        super().__init__()

        # Backbones tried earlier and no longer used:
        #   * a LeNet-style stack: Conv2d(1, 20, 5) -> ReLU -> MaxPool(2) ->
        #     Conv2d(20, 50, 5) -> ReLU -> MaxPool(2);
        #   * InceptionV4 initialized from
        #     http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth
        #     with a 1536-input linear head.
        self.features = ResNet101(pretrained=True)

        # 512 * 4 = 2048 input features. This presumably matches the backbone's
        # output width (bottleneck expansion 4) — confirm against ResNet101.
        feature_dim = 512 * 4
        self.f1 = nn.Linear(in_features=feature_dim,
                            out_features=num_classes,
                            bias=True)

        # Two softmax modules, a bias-free num_classes x num_classes linear map
        # and dropout. How forward() combines them is not visible here.
        self.theta_t1 = nn.Softmax(dim=1)
        self.l1 = nn.Linear(num_classes, num_classes, bias=False)
        self.theta_t2 = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(p=0.5)
示例#2
0
def main():
    """Evaluate a trained CUB-200 checkpoint on the CUB-200 test split.

    The run is selected by ``args.load``:
      * the training configuration comes from ``../../log/<load>/train_config.json``;
      * the weights come from ``../../checkpoints/<load>_best.pth.tar``.

    The accuracy returned by ``test`` is printed.
    """
    # Read back (and echo) the hyper-parameters the run was trained with.
    config_path = '../../log/' + args.load + '/train_config.json'
    with open(config_path) as config_fp:
        config = json.load(config_fp)
        pairs = ("\033[96m{}\033[0m: {},".format(key, value)
                 for key, value in config.items())
        print(" ".join(pairs))

    # Evaluation preprocessing:
    #   1. resize (an int size scales the shorter edge);
    #   2. center-crop to 448x448;
    #   3. normalize with the usual ImageNet channel statistics.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    test_transforms = transforms.Compose([
        transforms.Resize(size=448),
        transforms.CenterCrop(size=448),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

    test_data = CUB200(root='../../data/cub200',
                       train=False,
                       transform=test_transforms)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['workers'],
        pin_memory=False,
        drop_last=False,
    )

    # Pick the backbone class first, then build and wrap it once.
    arch = config['arch']
    if arch == 'resnet101':
        backbone_cls = ResNet101
    elif arch == 'resnet50':
        backbone_cls = ResNet50
    else:
        raise RuntimeError(
            "Only support resnet50 or resnet101 for architecture!")
    model = nn.DataParallel(
        backbone_cls(num_classes, num_parts=config['nparts'])).cuda()

    # Restore the best checkpoint and switch to inference mode.
    checkpoint_path = '../../checkpoints/' + args.load + '_best.pth.tar'
    print("=> loading checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model.eval()

    accuracy = test(test_loader, model)

    print('Testing finished...')
    print('Best accuracy on test set is: %.4f.' % accuracy)
示例#3
0
def get_model():
    """Instantiate the network named by the module-level ``args.model``.

    Supported names:
      * ResNet18, ResNet34, ResNet50, ResNet101, ResNet152;
      * VGG11, VGG13, VGG16, VGG19.

    Every model is constructed with ``p_dropout=args.dropout``.

    Returns:
        The newly constructed model instance.

    Raises:
        RuntimeError: if ``args.model`` is not one of the supported names.
    """
    # Lambdas keep construction lazy: only the selected class is looked up
    # and instantiated.
    resnet_builders = {
        'ResNet18': lambda: ResNet18(p_dropout=args.dropout),
        'ResNet34': lambda: ResNet34(p_dropout=args.dropout),
        'ResNet50': lambda: ResNet50(p_dropout=args.dropout),
        'ResNet101': lambda: ResNet101(p_dropout=args.dropout),
        'ResNet152': lambda: ResNet152(p_dropout=args.dropout),
    }
    vgg_names = ('VGG11', 'VGG13', 'VGG16', 'VGG19')

    name = args.model
    if name in resnet_builders:
        return resnet_builders[name]()
    if name in vgg_names:
        # VGG takes the configuration name as its first argument.
        return VGG(name, p_dropout=args.dropout)
    # Raising a bare string (as before) is a TypeError in Python 3; raise a
    # real exception carrying the offending name instead.
    raise RuntimeError('Model Not found: {}'.format(name))
def main():
    """Train a part-based ResNet attribute classifier on CelebA.

    Reads all settings from the module-level ``args`` dict. The run:
      * builds the backbone named by ``args['arch']``, initialized from
        torchvision's pretrained weights with ``strict=False``;
      * optionally resumes from ``args['resume']``;
      * trains for ``args['epochs']`` epochs with SGD (two learning-rate
        groups) and a MultiStep schedule, checkpointing after every epoch;
      * writes the per-attribute validation accuracy of the best epoch of
        this run to ``acc_per_attr.txt`` under ``log_dir``.

    Updates the module-level ``best_acc``.

    Raises:
        RuntimeError: for an unsupported ``args['arch']`` or ``args['split']``.
    """
    global best_acc

    # Create the model by architecture and load the pretrained weights.
    # strict=False: keys that exist on only one side are ignored.
    print("=> creating model...")

    if args['arch'] == 'resnet101':
        model = ResNet101(args['num_classes'], args['nparts'])
        model.load_state_dict(models.resnet101(pretrained=True).state_dict(), strict=False)
    elif args['arch'] == 'resnet50':
        model = ResNet50(args['num_classes'], args['nparts'])
        model.load_state_dict(models.resnet50(pretrained=True).state_dict(), strict=False)
    else:
        raise RuntimeError("Only support ResNet50 or ResNet101!")

    model = torch.nn.DataParallel(model).cuda()

    # Optionally resume from a checkpoint. ``checkpoint`` stays None when
    # nothing was loaded, so the optimizer/scheduler restore below follows
    # exactly the same decision.
    start_epoch = 0
    checkpoint = None
    if args['resume'] != '':
        if os.path.isfile(args['resume']):
            print("=> loading checkpoint '{}'".format(args['resume']))
            checkpoint = torch.load(args['resume'])
            start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args['resume'], checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args['resume']))

    # Data augmentation (training) and plain preprocessing (validation).
    train_transforms = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    val_transforms = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    # Wrap to datasets. The split mode selects alignment and subset percentage.
    if args['split'] == 'accuracy':
        train_data = CelebA(data_dir, split='train_full', align=True,
                            percentage=None, transform=train_transforms, resize=(256, 256))
        val_data = CelebA(data_dir, split='val', align=True,
                          percentage=None, transform=val_transforms, resize=(256, 256))
    elif args['split'] == 'interpretability':
        train_data = CelebA(data_dir, split='train', align=False,
                            percentage=0.3, transform=train_transforms, resize=(256, 256))
        val_data = CelebA(data_dir, split='val', align=False,
                          percentage=0.3, transform=val_transforms, resize=(256, 256))
    else:
        raise RuntimeError("Please choose either \'accuracy\' or \'interpretability\' for data split.")

    # Wrap to data loaders.
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args['batch_size'], shuffle=True,
        num_workers=args['workers'], pin_memory=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=args['batch_size'], shuffle=False,
        num_workers=args['workers'], pin_memory=True)

    # Loss function (criterion).
    criterion = torch.nn.BCEWithLogitsLoss().cuda()

    # Split parameters into three sets:
    #   * frozen (``args['fixed']``);
    #   * fine-tuned (``args['finetune']``);
    #   * trained from scratch (everything else).
    # Parameter names of the DataParallel wrapper start with "module.", so
    # index 1 is the top-level submodule name.
    fixed_layers = args['fixed']
    finetune_layers = args['finetune']
    finetune_parameters = []
    scratch_parameters = []
    for name, p in model.named_parameters():
        layer_name = name.split('.')[1]
        if layer_name in fixed_layers:
            p.requires_grad = False
        elif layer_name in finetune_layers:
            finetune_parameters.append(p)
        else:
            scratch_parameters.append(p)

    # Scratch layers get 20x the base learning rate.
    optimizer = torch.optim.SGD([{'params': scratch_parameters, 'lr': 20 * args['lr']},
                                 {'params': finetune_parameters, 'lr': args['lr']},
                                 ], weight_decay=args['weight_decay'], momentum=0.9)

    # MultiStep learning-rate scheduler: x0.1 after epoch 5.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5], gamma=0.1)

    # Restore optimizer/scheduler state only if a checkpoint was loaded above.
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # Per-attribute meters of the best epoch *of this run*. Stays None when no
    # epoch beats ``best_acc`` (e.g. after resuming from a stronger checkpoint,
    # or when start_epoch >= epochs). Previously this case raised
    # UnboundLocalError at the end of training.
    best_per_attr = None
    for epoch in range(start_epoch, args['epochs']):

        # training
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on val set
        acc_per_attr, acc = validate(val_loader, model, criterion, epoch)

        # LR scheduler
        scheduler.step()

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        if is_best:
            best_acc = acc
            best_per_attr = acc_per_attr
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_acc': best_acc,
            'scheduler': scheduler.state_dict(),
        }, is_best, os.path.join(check_dir, args['save']))

        # print current best acc
        print('Current best average accuracy is: %.4f' % best_acc)

    # Print the overall best acc and close the writer.
    print('Training finished...')
    if best_per_attr is None:
        print('No epoch in this run improved on the best accuracy; '
              'acc_per_attr.txt was not written')
    else:
        with open(os.path.join(log_dir, "acc_per_attr.txt"), 'w') as logfile:
            for k in range(args['num_classes']):
                logfile.write('%s: %.4f\n' % (celeba_attr[k], best_per_attr[k].avg))
        print('Per-attribute accuracy on val set has been written to acc_per_attr.txt under the log folder')
    print('Best average accuracy on val set is: %.4f.' % best_acc)

    writer.close()
def main():
    """Visualize part assignments and per-attribute attention for CelebA.

    Loads the run selected by ``args.load``:
      * config from ``../log/<load>/train_config.json``;
      * weights from ``../checkpoints/<load>_best.pth.tar``.

    Runs the model on ``fig_rows * fig_cols`` test images (the loader
    shuffles, so they are a random draw). For each image it writes, under
    ``../visualization/<load>/<id>/``:

    * ``input.png`` - the de-normalized input image,
    * ``prediction.txt`` - sigmoid>0.5 prediction vs. label per attribute,
    * ``attentions/<attr>.png`` - overlays produced by ``show_att_on_image``,
    * ``assignments/part_<i>.png`` - the soft assignment map of each part.

    ``plot_assignment`` is called on the hard (argmax) assignment, and
    ``assignment.png`` is then read back from the same folder. Presumably
    ``plot_assignment`` writes that file — confirm. The colorized assignments
    are tiled into one grid saved as ``../visualization/collected/<load>.png``.

    Raises:
        RuntimeError: for an unsupported ``split`` or ``arch`` in the config.
    """
    # load the config file
    config_file = '../log/' + args.load + '/train_config.json'
    with open(config_file) as fi:
        config = json.load(fi)
        print(" ".join("\033[96m{}\033[0m: {},".format(k, v) for k, v in config.items()))

    # define data transformation (no crop)
    test_transforms = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])

    # define test dataset and loader
    if config['split'] == 'accuracy':
        dataset = CelebA('../data/celeba', split='test', align=True,
                         percentage=None, transform=test_transforms, resize=(256, 256))
    elif config['split'] == 'interpretability':
        dataset = CelebA('../data/celeba', split='test', align=False,
                         percentage=0.3, transform=test_transforms, resize=(256, 256))
    else:
        raise RuntimeError("Please choose either \'accuracy\' or \'interpretability\' for data split.")
    test_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1, shuffle=True,
        num_workers=1, pin_memory=False)

    # create a dataloader iter instance
    test_loader_iter = iter(test_loader)

    # define the figure layout
    fig_rows = 5
    fig_cols = 5
    f_assign, axarr_assign = plt.subplots(fig_rows, fig_cols, figsize=(fig_cols * 2, fig_rows * 2))
    f_assign.subplots_adjust(wspace=0, hspace=0)

    # Build the model. With batch size = 1, only single-GPU visualization is
    # supported, so the model is not wrapped in DataParallel.
    if config['arch'] == 'resnet101':
        model = ResNet101(num_classes, num_parts=config['nparts']).cuda()
    elif config['arch'] == 'resnet50':
        model = ResNet50(num_classes, num_parts=config['nparts']).cuda()
    else:
        raise RuntimeError("Only support resnet50 or resnet101 for architecture!")

    # Load the model. The checkpoint was saved from a DataParallel wrapper,
    # whose keys carry a ``module.`` prefix. Strip it only where it is present
    # (the previous ``k[7:]`` chopped 7 characters from every key).
    resume = '../checkpoints/' + args.load + '_best.pth.tar'
    print("=> loading checkpoint '{}'".format(resume))
    checkpoint = torch.load(resume)
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in checkpoint['state_dict'].items():
        new_state_dict[k[len(prefix):] if k.startswith(prefix) else k] = v
    model.load_state_dict(new_state_dict, strict=True)
    model.eval()

    with torch.no_grad():
        current_id = 0
        for col_id in range(fig_cols):
            for j in range(fig_rows):

                # Inference. This already runs under the outer no_grad context.
                inputs, target, _ = next(test_loader_iter)
                inputs = inputs.cuda()
                target = target.cuda()
                current_id += 1
                print("Visualizing %dth image..." % current_id)
                output_list, att_list, assign = model(inputs)

                # define root for saving results and make directories correspondingly
                root = os.path.join('../visualization', args.load, str(current_id))
                os.makedirs(root, exist_ok=True)
                os.makedirs(os.path.join(root, 'attentions'), exist_ok=True)
                os.makedirs(os.path.join(root, 'assignments'), exist_ok=True)

                # Undo the normalization (x * std, then + mean) and save the input.
                save_input = transforms.Normalize(mean=(0, 0, 0), std=(1 / 0.229, 1 / 0.224, 1 / 0.225))(inputs.data[0].cpu())
                save_input = transforms.Normalize(mean=(-0.485, -0.456, -0.406), std=(1, 1, 1))(save_input)
                save_input = torch.nn.functional.interpolate(save_input.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
                img = torchvision.transforms.ToPILImage()(save_input)
                img.save(os.path.join(root, 'input.png'))

                # save the labels and pred as list
                label = list(target.data[0].cpu().numpy())
                prediction = []
                assert len(label) == num_classes
                for k in range(num_classes):
                    current_score = torch.sigmoid(output_list[k]).squeeze().data.item()
                    prediction.append(int(current_score > 0.5))

                # write the labels and pred
                with open(os.path.join(root, 'prediction.txt'), 'w') as pred_log:
                    for k in range(num_classes):
                        pred_log.write('%s pred: %d, label: %d\n' % (celeba_attr[k], prediction[k], label[k]))

                # upsample the assignment to the 256x256 image resolution
                assign_reshaped = torch.nn.functional.interpolate(assign.data.cpu(), size=(256, 256), mode='bilinear', align_corners=False)

                # The saved input image does not change inside the attribute
                # loop, so read it once. The float copy is still rebuilt for
                # every attribute, in case show_att_on_image modifies it.
                input_bgr = cv2.imread(os.path.join(root, 'input.png'))
                for k in range(num_classes):

                    # attention vector for kth attribute
                    att = att_list[k].view(1, config['nparts'], 1, 1).data.cpu()

                    # multiply the assignment with the attention vector and sum
                    # over parts to get the spatial attention map
                    assign_att = assign_reshaped * att
                    attmap_hw = torch.sum(assign_att, dim=1).squeeze(0).numpy()

                    # normalize the attention map and merge it onto the input
                    mask = attmap_hw / attmap_hw.max()
                    img_float = input_bgr.astype(float) / 255.
                    show_att_on_image(img_float, mask, os.path.join(root, 'attentions', celeba_attr[k] + '.png'))

                # generate the one-channel hard assignment via argmax
                _, hard_assign = torch.max(assign_reshaped, 1)

                # colorize and save the assignment
                plot_assignment(root, hard_assign.squeeze(0).numpy(), config['nparts'])

                # collect the assignment for the final image array
                color_assignment = mpimg.imread(os.path.join(root, 'assignment.png'))
                axarr_assign[j, col_id].imshow(color_assignment)
                axarr_assign[j, col_id].axis('off')

                # save the soft assignment map of each part
                for i in range(config['nparts']):
                    part_map = torch.nn.functional.interpolate(assign_reshaped.data[:, i].cpu().unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False)
                    part_img = torchvision.transforms.ToPILImage()(part_map.squeeze(0))
                    part_img.save(os.path.join(root, 'assignments', 'part_' + str(i) + '.png'))

        # save the array version and release the figure
        os.makedirs('../visualization/collected', exist_ok=True)
        f_assign.savefig(os.path.join('../visualization/collected', args.load + '.png'))
        plt.close(f_assign)

        print('Visualization finished!')
def main():
    """Report per-attribute and average test accuracy of a CelebA run.

    The run is chosen by ``args.load``:
      * its training config is read from ``../log/<load>/train_config.json``;
        ``split``, ``batch_size``, ``arch`` and ``nparts`` are used;
      * its weights are read from ``../checkpoints/<load>_best.pth.tar``.
    """
    config_path = '../log/' + args.load + '/train_config.json'
    with open(config_path) as cfg_file:
        config = json.load(cfg_file)
        rendered = ["\033[96m{}\033[0m: {},".format(key, val)
                    for key, val in config.items()]
        print(" ".join(rendered))

    # Resize straight to 256x256 (no crop), then normalize per channel.
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    test_transforms = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        normalize,
    ])

    # The training split mode decides alignment and subset percentage.
    split_mode = config['split']
    if split_mode == 'accuracy':
        align, percentage = True, None
    elif split_mode == 'interpretability':
        align, percentage = False, 0.3
    else:
        raise RuntimeError(
            "Please choose either \'accuracy\' or \'interpretability\' for data split."
        )
    test_data = CelebA('../data/celeba',
                       split='test',
                       align=align,
                       percentage=percentage,
                       transform=test_transforms,
                       resize=(256, 256))
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=config['batch_size'],
                                              shuffle=False,
                                              num_workers=6,
                                              pin_memory=False,
                                              drop_last=False)

    # Select the backbone class, then build the DataParallel model once.
    arch = config['arch']
    if arch == 'resnet101':
        backbone_cls = ResNet101
    elif arch == 'resnet50':
        backbone_cls = ResNet50
    else:
        raise RuntimeError(
            "Only support resnet50 or resnet101 for architecture!")
    model = nn.DataParallel(
        backbone_cls(num_classes, num_parts=config['nparts'])).cuda()

    checkpoint_path = '../checkpoints/' + args.load + '_best.pth.tar'
    print("=> loading checkpoint '{}'".format(checkpoint_path))
    saved = torch.load(checkpoint_path)
    model.load_state_dict(saved['state_dict'], strict=True)
    model.eval()

    acc_per_attr, acc = test(test_loader, model)

    rule = '==========================================================================='
    print('Testing finished...')
    print('Per-attribute accuracy:')
    print(rule)
    for idx in range(num_classes):
        print('\033[96m%s\033[0m: %.4f' % (celeba_attr[idx], acc_per_attr[idx].avg))
    print(rule)
    print('Best average accuracy on test set is: %.4f.' % acc)
def main():
    """Evaluate part-center -> facial-landmark regression on CelebA.

    Loads the run selected by ``args.load``:
      * config from ``../log/<load>/train_config.json``;
      * weights from ``../checkpoints/<load>_best.pth.tar``.

    Steps:
      1. Turn the model's part assignments into centers with
         ``create_centers`` on the 'fit' and 'eval' subsets.
      2. Divide centers and landmark annotations by the per-image eye
         distance.
      3. Fit a standardized, intercept-free linear regression from centers
         to landmarks on 'fit'.
      4. Print the ``L2_distance`` on 'eval' as a percentage.

    Raises:
        RuntimeError: if the run's config ``split`` is not 'interpretability',
            or ``arch`` is unsupported.
    """
    # load the config file
    config_file = '../log/' + args.load + '/train_config.json'
    with open(config_file) as fi:
        config = json.load(fi)
        print(" ".join("\033[96m{}\033[0m: {},".format(k, v) for k, v in config.items()))

    # test transform
    data_transforms = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])

    # This evaluation only accepts runs trained on the 'interpretability'
    # split. Use an explicit raise rather than ``assert`` so the check is not
    # stripped under ``python -O``.
    if config['split'] != 'interpretability':
        raise RuntimeError(
            "Landmark regression requires a run trained on the "
            "'interpretability' split, got '{}'.".format(config['split']))

    # define dataset and loader
    fit_data = CelebA('../data/celeba',
                      split='fit', align=False, percentage=0.3, transform=data_transforms, resize=(256, 256))
    eval_data = CelebA('../data/celeba',
                       split='eval', align=False, percentage=0.3, transform=data_transforms, resize=(256, 256))
    fit_loader = torch.utils.data.DataLoader(
        fit_data, batch_size=config['batch_size'], shuffle=False,
        num_workers=6, pin_memory=False, drop_last=False)
    eval_loader = torch.utils.data.DataLoader(
        eval_data, batch_size=config['batch_size'], shuffle=False,
        num_workers=6, pin_memory=False, drop_last=False)

    # load the model in eval mode
    if config['arch'] == 'resnet101':
        model = nn.DataParallel(ResNet101(num_classes, num_parts=config['nparts'])).cuda()
    elif config['arch'] == 'resnet50':
        model = nn.DataParallel(ResNet50(num_classes, num_parts=config['nparts'])).cuda()
    else:
        raise RuntimeError("Only support resnet50 or resnet101 for architecture!")

    resume = '../checkpoints/' + args.load + '_best.pth.tar'
    print("=> loading checkpoint '{}'".format(resume))
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model.eval()

    # convert the assignment to centers for both splits
    print('Evaluating the model for the whole data split...')
    fit_centers, fit_annos, fit_eyedists = create_centers(
        fit_loader, model, config['nparts'])
    eval_centers, eval_annos, eval_eyedists = create_centers(
        eval_loader, model, config['nparts'])
    eval_data_size = eval_centers.shape[0]

    # normalize the centers to make sure every face image has unit eye distance
    fit_centers, fit_annos = fit_centers / fit_eyedists, fit_annos / fit_eyedists
    eval_centers, eval_annos = eval_centers / eval_eyedists, eval_annos / eval_eyedists

    # Fit the linear regressor with sklearn:
    # normalized assignment center coordinates -> normalized landmark annotations.
    print('=> fitting and evaluating the regressor')

    # convert tensors to numpy (64 bit double)
    fit_centers_np = fit_centers.cpu().numpy().astype(np.float64)
    fit_annos_np = fit_annos.cpu().numpy().astype(np.float64)
    eval_centers_np = eval_centers.cpu().numpy().astype(np.float64)
    eval_annos_np = eval_annos.cpu().numpy().astype(np.float64)

    # Standardize inputs and targets with statistics from the fitting split only.
    scaler_centers = StandardScaler()
    scaler_landmarks = StandardScaler()
    scaler_centers.fit(fit_centers_np)
    scaler_landmarks.fit(fit_annos_np)
    fit_centers_std = scaler_centers.transform(fit_centers_np)
    fit_annos_std = scaler_landmarks.transform(fit_annos_np)

    # No intercept: both sides are zero-mean after standardization.
    regressor = LinearRegression(fit_intercept=False)
    regressor.fit(fit_centers_std, fit_annos_std)

    # Predict on the evaluation split and map back to landmark coordinates.
    eval_centers_std = scaler_centers.transform(eval_centers_np)
    eval_pred_std = regressor.predict(eval_centers_std)
    eval_pred = scaler_landmarks.inverse_transform(eval_pred_std)

    # Calculate the error. Rows are reshaped to (samples, landmarks, xy).
    eval_pred = eval_pred.reshape((eval_data_size, num_landmarks, 2))
    eval_annos_xy = eval_annos_np.reshape((eval_data_size, num_landmarks, 2))
    error = L2_distance(eval_pred, eval_annos_xy) * 100

    print('Mean L2 Distance on the test set is %.2f%%.' % error)
    print('Evaluation finished for model \'' + args.load + '\'.')
示例#8
0
                                               pin_memory=True)
else:
    MapRoot = args.salmap_root
    test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot,
                                                         transform=True),
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=4,
                                              pin_memory=True)

print('data already')
# ---- networks ----------------------------------------------------------------
# (The string-literal expression below is only a section marker; it has no effect.)
"""""" """"" ~~~nets~~~ """ """"""
# Counters for where training starts. Presumably they are overwritten when
# resuming from a snapshot later in the script (not visible here) — confirm.
start_epoch = 0
start_iteration = 0
# RGB backbone. bool(1 - args.param) is False when args.param is True/1 and
# True when it is False/0. Backbone pretrained weights are therefore requested
# only when the snapshot-loading branch below (``if args.param is True``) is
# not taken.
# NOTE(review): that branch tests ``is True``. If args.param is the int 1,
# neither pretrained weights nor snapshots would be loaded — confirm its type.
model_rgb = ResNet101(BatchNorm=nn.BatchNorm2d,
                      pretrained=bool(1 - args.param),
                      output_stride=16)
# Integration and TriAtt sub-networks (defined elsewhere). The variable name
# ``model_intergration`` is referenced by the snapshot-loading code below, so
# its spelling must stay as is.
model_intergration = Integration()
model_att = TriAtt()
if args.param is True:
    model_rgb.load_state_dict(
        torch.load(
            os.path.join(args.snapshot_root,
                         'snapshot_iter_' + parameters["snap_num"])))
    model_intergration.load_state_dict(
        torch.load(
            os.path.join(args.snapshot_root,
                         'integrate_snapshot_iter_' + parameters["snap_num"])))
    model_att.load_state_dict(
        torch.load(
示例#9
0
def train(args):
    """Train a ResNet-101 classifier and keep the best-training-accuracy weights.

    Args:
        args: namespace providing:
            * ``cuda`` (device index, concatenated as a string);
            * ``train_label_info_file_name``, ``batch_size``;
            * ``feature_extract``, ``num_classes``, ``use_pretrained``;
            * ``lr``, ``epochs``;
            * ``model_path``, ``info_path``.

    Side effects:
        * Saves the whole model object (``torch.save``), reloaded with the
          best epoch's weights, to ``args.model_path``.
        * Pickles the per-epoch training loss/accuracy histories to
          ``args.info_path``.
    """
    device_name = "cuda:" + args.cuda if torch.cuda.is_available() else "cpu"
    print("Use device: ", device_name)
    device = torch.device(device_name)

    # Resize to 224x224 and normalize per channel.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = Image_Dateset(args.train_label_info_file_name,
                                  transform=preprocess)
    data_loader = DataLoader(dataset=train_dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=4)

    model = ResNet101(feature_extract=args.feature_extract,
                      num_classes=args.num_classes,
                      use_pretrained=args.use_pretrained).get_model().to(device)

    # Snapshot of the best weights seen so far (by training accuracy).
    best_acc = 0.0
    best_model_params = copy.deepcopy(model.state_dict())

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    num_epochs = args.epochs
    train_loss_history = []
    train_acc_history = []

    model.train()
    for epoch in range(1, num_epochs + 1):
        header = f'Epoch: {epoch}/{num_epochs}'
        print(header)
        print('-' * len(header))

        running_loss = 0.0
        running_corrects = 0
        n_batches = 0

        for inputs, labels in tqdm(data_loader, ascii=True, total=len(data_loader)):
            n_batches += 1
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            _, predicted = torch.max(logits.detach(), 1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_corrects += (predicted == labels.detach()).sum().item()

        # Loss is averaged per batch; accuracy is computed over all samples.
        epoch_loss = running_loss / n_batches
        train_loss_history.append(epoch_loss)
        epoch_acc = running_corrects / len(train_dataset)
        train_acc_history.append(epoch_acc)

        print(f'Training loss: {epoch_loss:.4f}\taccuracy: {epoch_acc:.4f}\n')

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_params = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_params)
    torch.save(model, args.model_path)

    history = {
        'train_loss_history': train_loss_history,
        'train_acc_history': train_acc_history,
    }
    with open(args.info_path, 'wb') as info_file:
        pickle.dump(history, info_file)
def main():
    """Evaluate part-center -> landmark regression on CUB-200.

    Loads the run selected by ``args.load``:
      * config from ``../../log/<load>/train_config.json``;
      * weights from ``../../checkpoints/<load>_best.pth.tar``.

    The model's part assignments are turned into centers with
    ``create_centers`` on the train ("fit") and test ("eval") splits. For each
    of the ``num_landmarks`` landmarks, a standardized, intercept-free linear
    regression is fitted from centers to that landmark's (x, y). Only samples
    whose mask entry for the landmark is non-zero are used (presumably: the
    landmark is annotated/visible — confirm).

    Prints the sample-weighted mean ``L2_distance`` over landmarks as a
    percentage. Landmarks with no valid sample in either split are skipped,
    because sklearn cannot fit or transform an empty array.

    Raises:
        RuntimeError: for an unsupported ``config['arch']``, or when no
            landmark has valid samples in both splits.
    """
    # load the config file
    config_file = '../../log/' + args.load + '/train_config.json'
    with open(config_file) as fi:
        config = json.load(fi)
        print(" ".join("\033[96m{}\033[0m: {},".format(k, v)
                       for k, v in config.items()))

    # define data transformation (no crop)
    data_transforms = transforms.Compose([
        transforms.Resize(size=448),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])

    # define dataset and loader
    fit_data = CUB200(root='../../data/cub200',
                      train=True,
                      transform=data_transforms,
                      resize=448)
    eval_data = CUB200(root='../../data/cub200',
                       train=False,
                       transform=data_transforms,
                       resize=448)
    fit_loader = torch.utils.data.DataLoader(fit_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=False,
                                             drop_last=False)
    eval_loader = torch.utils.data.DataLoader(eval_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=False,
                                              drop_last=False)

    # load the model in eval mode
    if config['arch'] == 'resnet101':
        model = nn.DataParallel(
            ResNet101(num_classes, num_parts=config['nparts'])).cuda()
    elif config['arch'] == 'resnet50':
        model = nn.DataParallel(
            ResNet50(num_classes, num_parts=config['nparts'])).cuda()
    else:
        raise RuntimeError(
            "Only support resnet50 or resnet101 for architecture!")

    resume = '../../checkpoints/' + args.load + '_best.pth.tar'
    print("=> loading checkpoint '{}'".format(resume))
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model.eval()

    # convert the assignment to centers for both splits
    print('Evaluating the model for the whole data split...')
    fit_centers, fit_annos, fit_masks = create_centers(fit_loader, model,
                                                       config['nparts'])
    eval_centers, eval_annos, eval_masks = create_centers(
        eval_loader, model, config['nparts'])

    # Fit the linear regressor with sklearn:
    # normalized assignment center coordinates -> normalized landmark annotations.
    print('=> fitting and evaluating the regressor')
    error = 0
    n_valid_samples = 0

    # Tensor -> float64 numpy conversions do not depend on the landmark, so do
    # them once instead of on every loop iteration.
    fit_masks_np = fit_masks.cpu().numpy().astype(np.float64)
    eval_masks_np = eval_masks.cpu().numpy().astype(np.float64)
    fit_centers_all = fit_centers.cpu().numpy().astype(np.float64)
    fit_annos_all = fit_annos.cpu().numpy().astype(np.float64)
    eval_centers_all = eval_centers.cpu().numpy().astype(np.float64)
    eval_annos_all = eval_annos.cpu().numpy().astype(np.float64)

    # Different landmarks have different masks; columns 2*i and 2*i+1 hold
    # landmark i, and the mask is tested on its first column.
    for i in range(num_landmarks):
        fit_selection = np.abs(fit_masks_np[:, i * 2]) > 1e-5
        eval_selection = np.abs(eval_masks_np[:, i * 2]) > 1e-5

        # StandardScaler / LinearRegression reject zero-sample arrays.
        if not fit_selection.any() or not eval_selection.any():
            print('=> skipping landmark %d: no valid samples in the fit or eval split' % i)
            continue

        # keep valid rows and the current landmark's (x, y) columns
        fit_centers_np = fit_centers_all[fit_selection]
        fit_annos_np = fit_annos_all[fit_selection, i * 2:i * 2 + 2]
        eval_centers_np = eval_centers_all[eval_selection]
        eval_annos_np = eval_annos_all[eval_selection, i * 2:i * 2 + 2]
        eval_data_size = eval_centers_np.shape[0]

        # Standardize with statistics from the fitting split only.
        scaler_centers = StandardScaler()
        scaler_landmarks = StandardScaler()
        scaler_centers.fit(fit_centers_np)
        scaler_landmarks.fit(fit_annos_np)
        fit_centers_std = scaler_centers.transform(fit_centers_np)
        fit_annos_std = scaler_landmarks.transform(fit_annos_np)

        # No intercept: both sides are zero-mean after standardization.
        regressor = LinearRegression(fit_intercept=False)
        regressor.fit(fit_centers_std, fit_annos_std)

        # Predict on the evaluation split and map back to landmark coordinates.
        eval_centers_std = scaler_centers.transform(eval_centers_np)
        eval_pred_std = regressor.predict(eval_centers_std)
        eval_pred = scaler_landmarks.inverse_transform(eval_pred_std)

        # Weight by sample count. This assumes L2_distance returns a per-sample
        # mean (defined elsewhere) — confirm.
        eval_pred = eval_pred.reshape((eval_data_size, 1, 2))
        eval_annos_np = eval_annos_np.reshape((eval_data_size, 1, 2))
        error += L2_distance(eval_pred, eval_annos_np) * eval_data_size
        n_valid_samples += eval_data_size

    if n_valid_samples == 0:
        raise RuntimeError(
            'No landmark has valid samples in both the fit and eval splits.')

    error = error * 100 / n_valid_samples
    print('Mean L2 Distance on the test set is %.2f%%.' % error)
    print('Evaluation finished for model \'' + args.load + '\'.')
示例#11
0
def main():
    """Visualize attention maps and part assignments for UBIPr verification pairs.

    Loads ``../../log/<args.load>/train_config.json`` and the checkpoint
    ``../../checkpoints/<args.load>_best.pth.tar``. It then runs the model on
    at most the first 100 samples of the ``pairs_to_explain_verification``
    (``split='train'``) dataset. For each sample it writes into
    ``../../visualization/<args.load>/<id>/``:

    * ``input.png`` -- the de-normalized input, resized to 256x256;
    * ``prediction.txt`` -- per-class sigmoid score and label;
    * ``attentions/<UBIPR_CLASSES[0]>.png`` and ``attention_map.npy`` -- the
      spatial attention map (only class index 0 is rendered);
    * ``assignment.png`` and ``assignments/part_<i>.png`` -- hard and soft part
      assignments (skipped when ``JUST_CARE_ABOUT_THE_SCORES`` is true);
    * ``explanation.png`` -- the assembled pair explanation.

    The per-sample wall-clock time is appended to ``times_by_parts.txt``.

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``num_classes``, ``UBIPR_CLASSES``, ``JUST_CARE_ABOUT_THE_SCORES``,
    ``UBIPr_Verification``, ``ResNet50``/``ResNet101``, ``show_att_on_image``,
    ``plot_assignment`` and ``assemble_explanation``.

    Raises:
        RuntimeError: if ``config['arch']`` is neither resnet50 nor resnet101.
    """
    import itertools

    # Maximum number of samples to visualize.
    max_samples = 100

    # load the config file
    config_file = '../../log/' + args.load + '/train_config.json'
    with open(config_file) as fi:
        config = json.load(fi)
        print(" ".join("\033[96m{}\033[0m: {},".format(k, v)
                       for k, v in config.items()))

    # define data transformation (no crop)
    test_transforms = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    # wrap to dataset
    test_data = UBIPr_Verification("pairs_to_explain_verification",
                                   split='train',
                                   transform=test_transforms)

    # wrap to dataloader
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1,
                                              pin_memory=False,
                                              drop_last=False)

    # define the figure layout
    # NOTE(review): nothing in this function draws into the subplot grid, so
    # the collected figure saved at the end is an empty 5x5 layout.
    fig_rows = 5
    fig_cols = 5
    f_assign, _ = plt.subplots(fig_rows,
                               fig_cols,
                               figsize=(fig_cols * 2, fig_rows * 2))
    f_assign.subplots_adjust(wspace=0, hspace=0)

    # build the model; with batch size = 1, we only support single GPU
    # visualization
    if config['arch'] == 'resnet101':
        model = ResNet101(num_classes, num_parts=config['nparts']).cuda()
    elif config['arch'] == 'resnet50':
        model = ResNet50(num_classes, num_parts=config['nparts']).cuda()
    else:
        raise RuntimeError(
            "Only support resnet50 or resnet101 for architecture!")

    # load model weights
    resume = '../../checkpoints/' + args.load + '_best.pth.tar'
    print("=> loading checkpoint '{}'".format(resume))
    checkpoint = torch.load(resume)
    # Checkpoints saved from a DataParallel-wrapped model prefix every key with
    # 'module.'. Strip it only where present: blindly dropping 7 characters
    # would corrupt keys of a plain checkpoint and make strict loading fail.
    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['state_dict'].items())
    model.load_state_dict(new_state_dict, strict=True)
    model.eval()

    with torch.no_grad():
        # islice stops cleanly when the dataset holds fewer than max_samples
        # items, instead of next() raising StopIteration on an exhausted
        # iterator.
        for current_id, (img_batch, ground_truth, _) in enumerate(
                itertools.islice(test_loader, max_samples), start=1):

            t0 = time.time()

            # inference the model
            inputs = img_batch.cuda()
            target = ground_truth.cuda()

            print("Visualizing %dth image..." % current_id)
            output_list, att_list, assign = model(inputs)

            # define root for saving results and make directories
            # correspondingly
            root = os.path.join('../../visualization', args.load,
                                str(current_id))
            os.makedirs(os.path.join(root, 'attentions'), exist_ok=True)
            if not JUST_CARE_ABOUT_THE_SCORES:
                os.makedirs(os.path.join(root, 'assignments'), exist_ok=True)

            # Undo the Normalize transform (x * std + mean) with two inverse
            # Normalize passes, then save the input as an image.
            save_input = transforms.Normalize(mean=(0, 0, 0),
                                              std=(1 / 0.229, 1 / 0.224, 1 /
                                                   0.225))(inputs.data[0].cpu())
            save_input = transforms.Normalize(mean=(-0.485, -0.456, -0.406),
                                              std=(1, 1, 1))(save_input)
            save_input = torch.nn.functional.interpolate(
                save_input.unsqueeze(0),
                size=(256, 256),
                mode='bilinear',
                align_corners=False).squeeze(0)
            img = torchvision.transforms.ToPILImage()(save_input)
            img.save(os.path.join(root, 'input.png'))

            # collect labels and per-class sigmoid scores
            label = list(target.data[0].cpu().numpy())
            assert len(label) == num_classes
            prediction = [
                torch.sigmoid(output_list[k]).squeeze().data.item()
                for k in range(num_classes)
            ]

            # write the labels and pred
            with open(os.path.join(root, 'prediction.txt'), 'w') as pred_log:
                for k in range(num_classes):
                    pred_log.write('%s pred: %f, label: %d\n' %
                                   (UBIPR_CLASSES[k], prediction[k], label[k]))

            # upsample the assignment and transform the attention
            # correspondingly
            assign_reshaped = torch.nn.functional.interpolate(
                assign.data.cpu(),
                size=(256, 256),
                mode='bilinear',
                align_corners=False)

            # visualize the attention for the first attribute only
            k = 0
            att = att_list[k].view(1, config['nparts'], 1, 1).data.cpu()
            # weight each part's assignment by its attention value, then sum
            # over the part dimension to get the spatial attention map
            attmap_hw = torch.sum(assign_reshaped * att,
                                  dim=1).squeeze(0).numpy()

            # normalize the attention map and merge it onto the input
            img = cv2.imread(os.path.join(root, 'input.png'))
            mask = attmap_hw / attmap_hw.max()
            np.save(os.path.join(root, 'attention_map.npy'), mask)
            img_float = img.astype(float) / 255.
            attention_path = os.path.join(root, 'attentions',
                                          UBIPR_CLASSES[k] + '.png')
            show_att_on_image(img_float, mask, attention_path)

            if not JUST_CARE_ABOUT_THE_SCORES:
                # one-channel hard assignment via argmax, colorized and saved
                _, hard_assign = torch.max(assign_reshaped, 1)
                plot_assignment(root,
                                hard_assign.squeeze(0).numpy(),
                                config['nparts'], None)

                # save the soft assignment of each dictionary vector
                for part_idx in range(config['nparts']):
                    part_img = torch.nn.functional.interpolate(
                        assign_reshaped.data[:, part_idx].cpu().unsqueeze(0),
                        size=(256, 256),
                        mode='bilinear',
                        align_corners=False)
                    part_img = torchvision.transforms.ToPILImage()(
                        part_img.squeeze(0))
                    part_img.save(
                        os.path.join(root, 'assignments',
                                     'part_' + str(part_idx) + '.png'))

            # ---------------------------------------------------------------
            # build the final explanation
            # ---------------------------------------------------------------
            # Crop rows 64..191 and split the columns at 128: left half ->
            # image A, right half -> image B.
            # NOTE(review): this assumes the dataset lays the pair out side by
            # side in the 256x256 input -- confirm against UBIPr_Verification.
            # Read back the attention image written just above, so that the
            # file being read is the one that was actually produced.
            difference_mask_aux = np.asarray(
                Image.open(attention_path).convert("RGBA"))
            difference_mask_1 = difference_mask_aux[64:64 + 128, :128, :]
            difference_mask_2 = difference_mask_aux[64:64 + 128, 128:, :]

            input_aux = np.asarray(
                Image.open(os.path.join(root, 'input.png')).convert("RGBA"))
            image_A = np.asarray(
                Image.fromarray(input_aux[64:64 + 128, :128, :].astype(
                    np.uint8)).resize((127, 127),
                                      Image.LANCZOS).convert("RGBA"))
            image_B = np.asarray(
                Image.fromarray(input_aux[64:64 + 128,
                                          128:, :].astype(np.uint8)).resize(
                                              (127, 127),
                                              Image.LANCZOS).convert("RGBA"))

            # NOTE(review): masks are passed in (B, A) order; confirm against
            # assemble_explanation's signature.
            assemble_explanation(image_A, image_B, difference_mask_2,
                                 difference_mask_1, 0.0, "I",
                                 os.path.join(root, 'explanation.png'))

            if JUST_CARE_ABOUT_THE_SCORES:
                rmtree(os.path.join(root, 'attentions'))

            elapsed_time = time.time() - t0
            print("[INFO] ELAPSED TIME: %.2fs\n" % (elapsed_time))

            with open("times_by_parts.txt", "a") as times_log:
                times_log.write(str(elapsed_time) + "\n")

        # save the array version
        os.makedirs('../../visualization/collected', exist_ok=True)
        f_assign.savefig(
            os.path.join('../../visualization/collected', args.load + '.png'))

        print('Visualization finished!')
def main():
    """Train a part-based ResNet classifier on CUB-200 and checkpoint each epoch.

    Builds ResNet50 or ResNet101 (``args['arch']``) and copies torchvision's
    ImageNet weights into it with ``strict=False``. The non-strict load is
    presumably because the part-based model has layers plain torchvision
    ResNet lacks; TODO confirm.

    It optionally resumes model, optimizer and scheduler state from
    ``args['resume']``. It then alternates ``train`` and ``test`` for the
    remaining epochs up to ``args['epochs']``, calling ``save_checkpoint``
    after every epoch.

    Relies on module-level names defined elsewhere in the file: ``args``,
    ``best_acc`` (rebound here via ``global``), ``data_dir``, ``check_dir``,
    ``writer``, ``models``, ``CUB200``, ``ResNet50``/``ResNet101``, ``train``,
    ``test`` and ``save_checkpoint``.

    Raises:
        RuntimeError: if ``args['arch']`` is neither resnet50 nor resnet101.
    """
    global best_acc

    # create model by architecture and load the pretrained weights
    print("=> creating model...")

    if args['arch'] == 'resnet101':
        model = ResNet101(args['num_classes'], args['nparts'])
        pretrained = models.resnet101(pretrained=True)
    elif args['arch'] == 'resnet50':
        model = ResNet50(args['num_classes'], args['nparts'])
        pretrained = models.resnet50(pretrained=True)
    else:
        raise RuntimeError("Only support ResNet50 or ResNet101!")
    model.load_state_dict(pretrained.state_dict(), strict=False)

    model = torch.nn.DataParallel(model).cuda()

    # optionally resume from a checkpoint; `checkpoint` stays None unless one
    # was actually loaded, so optimizer/scheduler restoring below keys off that
    # instead of re-checking the file system
    start_epoch = 0
    checkpoint = None
    if args['resume'] != '':
        if os.path.isfile(args['resume']):
            print("=> loading checkpoint '{}'".format(args['resume']))
            checkpoint = torch.load(args['resume'])
            start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args['resume'], checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args['resume']))

    # data augmentation
    train_transforms = transforms.Compose([
        transforms.Resize(size=448),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.1),
        transforms.RandomCrop(size=448),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    test_transforms = transforms.Compose([
        transforms.Resize(size=448),
        transforms.CenterCrop(size=448),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    # wrap to dataset
    train_data = CUB200(root=data_dir, train=True, transform=train_transforms)
    test_data = CUB200(root=data_dir, train=False, transform=test_transforms)

    # wrap to dataloader
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args['batch_size'],
                                               shuffle=True,
                                               num_workers=args['workers'],
                                               pin_memory=False,
                                               drop_last=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args['batch_size'],
                                              shuffle=False,
                                              num_workers=args['workers'],
                                              pin_memory=True)

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Split parameters into frozen / fine-tuned / trained-from-scratch groups.
    # The model is wrapped in DataParallel, so every parameter name starts with
    # 'module.'; component [1] is therefore the top-level sub-module name.
    fixed_layers = args['fixed']
    finetune_layers = args['finetune']
    finetune_parameters = []
    scratch_parameters = []
    for name, param in model.named_parameters():
        layer_name = name.split('.')[1]
        if layer_name in fixed_layers:
            param.requires_grad = False
        elif layer_name in finetune_layers:
            finetune_parameters.append(param)
        else:
            scratch_parameters.append(param)

    # define the optimizer: scratch layers get 20x the base learning rate
    optimizer = torch.optim.SGD([
        {
            'params': scratch_parameters,
            'lr': 20 * args['lr']
        },
        {
            'params': finetune_parameters,
            'lr': args['lr']
        },
    ],
                                weight_decay=args['weight_decay'],
                                momentum=0.9)

    # cosine-annealing schedule with T_max = iterations per epoch * epochs
    # (presumably stepped once per iteration inside train(); TODO confirm)
    num_iters = len(train_loader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, num_iters * args['epochs'])

    # restore optimizer and scheduler state if a checkpoint was loaded
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # training part
    for epoch in range(start_epoch, args['epochs']):

        # training
        train(train_loader, model, criterion, optimizer, epoch, scheduler)

        # evaluate on test set
        acc = test(test_loader, model, criterion, epoch)

        # remember best acc and save checkpoint
        is_best = acc > best_acc
        best_acc = max(acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'scheduler': scheduler.state_dict(),
            }, is_best, os.path.join(check_dir, args['save']))

        # print current best acc
        print('Current best average accuracy is: %.4f' % best_acc)

    # print the overall best acc and close the writer
    print('Training finished...')
    print('Best accuracy on test set is: %.4f.' % best_acc)
    writer.close()