Example #1
def plot_Cam(query_img, label, model_name):
    image = io.imread(query_img)
    tensor = data_transforms(image)
    prediction_var = Variable((tensor.unsqueeze(0)).cuda(), requires_grad=True)
    model, _ = get_pretrained_models(model_name)
    model.cuda()
    model.eval()
    final_layer = list(model.model.children())[-3]
    activated_features = SaveFeatures(final_layer)
    prediction = model(prediction_var)
    pred_probabilities = F.softmax(prediction, dim=1).data.squeeze()
    pred_label = topk(pred_probabilities,
                      1).indices[0].item()  # get the top predicted class
    pred_label = label_dict[pred_label]  # to string
    label = "label: " + label
    pred_label = "predict: " + pred_label
    activated_features.remove()
    weight_softmax_params = list(model.model.fc.parameters())
    weight_softmax = np.squeeze(weight_softmax_params[0].cpu().data.numpy())
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(8, 8))
    class_idx = topk(pred_probabilities, 1)[1].int()
    overlay = getCAM(activated_features.features, weight_softmax, class_idx)
    axes[0].imshow(display_transform(image), cmap='gray')
    axes[0].set_title(label)
    axes[1].imshow(overlay[0], alpha=0.9, cmap='gray')
    axes[2].imshow(display_transform(image), cmap='gray')
    axes[2].imshow(skimage.transform.resize(overlay[0], tensor.shape[1:3]),
                   alpha=0.5,
                   cmap='jet')
    axes[2].set_title(pred_label)
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.suptitle("Grad-Cam from model " + model_name, fontsize=15)
    fig.tight_layout()
    fig.subplots_adjust(top=1.55)
    plt.show()
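This example leans on several project helpers that are not shown: `data_transforms`, `display_transform`, `label_dict`, plus the `SaveFeatures` hook wrapper and the `getCAM` weighting step. A minimal sketch of the latter two, assuming they follow the common CAM recipe (a forward hook that caches the chosen conv block's activations, then a weighted sum with the fc weights of the predicted class); the project's real definitions may differ:

import numpy as np

class SaveFeatures():
    # cache the output of a module via a forward hook
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output.data.cpu().numpy()

    def remove(self):
        self.hook.remove()


def getCAM(feature_conv, weight_fc, class_idx):
    # weighted sum of the conv feature maps with the fc weights of class_idx,
    # normalized to [0, 1] so it can be rendered as a heatmap
    _, nc, h, w = feature_conv.shape
    cam = weight_fc[class_idx].dot(feature_conv.reshape((nc, h * w)))
    cam = cam.reshape(h, w)
    cam = cam - np.min(cam)
    return [cam / np.max(cam)]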
Example #2
        img_name = os.path.join(self.root_dir, self.label_raw[idx]['image_id'])
        img_name_raw = self.label_raw[idx]['image_id']
        image = Image.open(img_name)
        label = int(self.label_raw[idx]['label_id'])

        if self.transform:
            image = self.transform(image)

        return image, label, img_name_raw


transformed_dataset_test = SceneDataset(
    json_labels=label_raw_test,
    root_dir=
    '/home/member/fuwang/projects/scene/data/ai_challenger_scene_test_a_20170922/scene_test_a_images_20170922',
    transform=data_transforms('ten_crop', input_size, train_scale, test_scale))
transformed_dataset_val = SceneDataset(
    json_labels=label_raw_val,
    root_dir=
    '/home/member/fuwang/projects/scene/data/ai_challenger_scene_validation_20170908/scene_validation_images_20170908',
    transform=data_transforms('ten_crop', input_size, train_scale, test_scale))
dataloader = {
    'test':
    DataLoader(transformed_dataset_test,
               batch_size=batch_size,
               shuffle=False,
               num_workers=INPUT_WORKERS),
    'val':
    DataLoader(transformed_dataset_val,
               batch_size=batch_size,
               shuffle=False,
               num_workers=INPUT_WORKERS)
}
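The `__getitem__` body at the top of this example belongs to a `SceneDataset` class whose definition is cut off. A minimal sketch of the missing part, inferred from the `json_labels`/`root_dir`/`transform` arguments used above (the original may differ):

from torch.utils.data import Dataset

class SceneDataset(Dataset):
    def __init__(self, json_labels, root_dir, transform=None):
        self.label_raw = json_labels  # list of {'image_id': ..., 'label_id': ...}
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.label_raw)

    # __getitem__ as shown at the top of this example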
Example #3
def run():
    model = load_model(arch,
                       pretrained,
                       use_gpu=use_gpu,
                       num_classes=num_classes,
                       AdaptiveAvgPool=AdaptiveAvgPool,
                       SPP=SPP,
                       num_levels=num_levels,
                       pool_type=pool_type,
                       bilinear=bilinear,
                       stage=stage,
                       SENet=SENet,
                       se_stage=se_stage,
                       se_layers=se_layers,
                       threshold_before_avg=threshold_before_avg)

    if use_gpu:
        if arch.lower().startswith('alexnet') or arch.lower().startswith(
                'vgg'):
            model.features = nn.DataParallel(model.features)
            model.cuda()
        else:
            model = nn.DataParallel(model).cuda()
            # model = nn.DataParallel(model, device_ids=[2]).cuda()

    best_prec3 = 0
    best_loss1 = 10000

    if try_resume:
        if os.path.isfile(latest_check):
            print("=> loading checkpoint '{}'".format(latest_check))
            checkpoint = torch.load(latest_check)
            global start_epoch
            start_epoch = checkpoint['epoch']
            best_prec3 = checkpoint['best_prec3']
            best_loss1 = checkpoint['loss1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(latest_check))
        else:
            print("=> no checkpoint found at '{}'".format(latest_check))

    cudnn.benchmark = True

    if class_aware:
        train_set = jd.ImageDataset(TRAIN_ROOT,
                                    include_target=True,
                                    X_transform=data_transforms(
                                        train_transform, input_size,
                                        train_scale, test_scale))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=ClassAwareSampler.ClassAwareSampler(train_set),
            num_workers=INPUT_WORKERS,
            pin_memory=use_gpu)
    else:
        train_loader = torch.utils.data.DataLoader(jd.ImageDataset(
            TRAIN_ROOT,
            include_target=True,
            X_transform=data_transforms(train_transform, input_size,
                                        train_scale, test_scale)),
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=INPUT_WORKERS,
                                                   pin_memory=use_gpu)

    if isinstance(VALIDATION_ROOT, pd.DataFrame):
        val_loader = torch.utils.data.DataLoader(jd.ImageDataset(
            VALIDATION_ROOT,
            include_target=True,
            X_transform=data_transforms(val_transform, input_size, train_scale,
                                        test_scale)),
                                                 batch_size=BATCH_SIZE,
                                                 shuffle=False,
                                                 num_workers=INPUT_WORKERS,
                                                 pin_memory=use_gpu)

    criterion = nn.CrossEntropyLoss().cuda() if use_gpu else nn.CrossEntropyLoss()

    if if_fc:
        if pretrained == 'imagenet' or arch == 'resnet50' or arch == 'resnet18':
            ignored_params = list(map(id, model.module.fc.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 model.module.parameters())
            lr_dicts = [{
                'params': base_params,
                'lr': lr1
            }, {
                'params': model.module.fc.parameters(),
                'lr': lr2
            }]

        elif pretrained == 'places':
            if arch == 'preact_resnet50':
                lr_dicts = list()
                lr_dicts.append({
                    'params':
                    model.module._modules['12']._modules['1'].parameters(),
                    'lr':
                    lr2
                })
                for _, index in enumerate(model.module._modules):
                    if index != '12':
                        lr_dicts.append({
                            'params':
                            model.module._modules[index].parameters(),
                            'lr':
                            lr1
                        })
                    else:
                        for index2, _ in enumerate(
                                model.module._modules[index]):
                            if index2 != 1:
                                lr_dicts.append({
                                    'params':
                                    model.module._modules[index]._modules[str(
                                        index2)].parameters(),
                                    'lr':
                                    lr1
                                })
            elif arch == 'resnet152':
                lr_dicts = list()
                lr_dicts.append({
                    'params':
                    model.module._modules['10']._modules['1'].parameters(),
                    'lr':
                    lr2
                })
                for _, index in enumerate(model.module._modules):
                    if index != '10':
                        lr_dicts.append({
                            'params':
                            model.module._modules[index].parameters(),
                            'lr':
                            lr1
                        })
                    else:
                        for index2, _ in enumerate(
                                model.module._modules[index]):
                            if index2 != 1:
                                lr_dicts.append({
                                    'params':
                                    model.module._modules[index]._modules[str(
                                        index2)].parameters(),
                                    'lr':
                                    lr1
                                })

        if optim_type == 'Adam':
            optimizer = optim.Adam(lr_dicts,
                                   betas=betas,
                                   eps=eps,
                                   weight_decay=weight_decay)
        elif optim_type == 'SGD':
            optimizer = optim.SGD(lr_dicts,
                                  momentum=momentum,
                                  weight_decay=weight_decay)
    else:
        if optim_type == 'Adam':
            if stage == 1 or se_stage == 1:
                #                optimizer = optim.Adam(model.module.fc.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
                optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                              model.parameters()),
                                       lr=lr,
                                       betas=betas,
                                       eps=eps,
                                       weight_decay=weight_decay)
            else:
                optimizer = optim.Adam(model.parameters(),
                                       lr=lr,
                                       betas=betas,
                                       eps=eps,
                                       weight_decay=weight_decay)
        elif optim_type == 'SGD':
            if stage == 1 or se_stage == 1:
                #                if pretrained == 'places' and arch == 'preact_resnet50':
                #                    optimizer = optim.SGD(model.module._modules['12']._modules['1'].parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
                #                elif pretrained =='places' and arch == 'resnet152':
                #                    optimizer = optim.SGD(model.module._modules['10']._modules['1'].parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
                #                else:
                #                    optimizer = optim.SGD(model.module.fc.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
                optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                             model.parameters()),
                                      lr=lr,
                                      momentum=momentum,
                                      weight_decay=weight_decay)
            else:
                optimizer = optim.SGD(model.parameters(),
                                      lr=lr,
                                      momentum=momentum,
                                      weight_decay=weight_decay)

    if evaluate:
        validate(val_loader, model, criterion)

    else:

        for epoch in range(start_epoch, epochs):

            # train for one epoch
            prec3, loss1 = train(train_loader, model, criterion, optimizer,
                                 epoch)

            # evaluate on validation set
            if isinstance(VALIDATION_ROOT, pd.DataFrame):
                prec3, loss1 = validate(val_loader, model, criterion, epoch)

            # remember best
            is_best = loss1 <= best_loss1
            best_loss1 = min(loss1, best_loss1)
            if epoch % save_freq == 0:
                save_checkpoint_epoch(
                    {
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_prec3': best_prec3,
                        'loss1': loss1
                    }, epoch + 1)
            if is_best:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_prec3': best_prec3,
                        'loss1': loss1
                    }, is_best)
            if use_epoch_decay:
                if_adjust = adjust_learning_rate(optimizer, epoch, if_fc,
                                                 use_epoch_decay)
                if if_adjust:
                    my_check = torch.load(best_check)
                    model.load_state_dict(my_check['state_dict'])
                    best_prec3 = my_check['best_prec3']
                    best_loss1 = my_check['loss1']
            else:
                if not is_best:
                    if lr <= lr_min:  # when lr is already tiny, stop rolling the checkpoint back and forth
                        best_loss1 = loss1
                    else:
                        my_check = torch.load(best_check)
                        model.load_state_dict(my_check['state_dict'])
                        best_loss1 = my_check['loss1']
                        best_prec3 = my_check['best_prec3']
                        adjust_learning_rate(optimizer, epoch, if_fc,
                                             use_epoch_decay)
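The rollback logic above depends on `save_checkpoint`/`save_checkpoint_epoch` writing files that `torch.load(best_check)` can read back. A minimal sketch in the style of the official PyTorch ImageNet example, which these "=> loading checkpoint" prints clearly follow; the filenames here are stand-ins for the project's `latest_check`/`best_check` paths:

import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # always persist the latest state; copy it aside when it is the best so
    # far, so the training loop can roll back from the best checkpoint
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


def save_checkpoint_epoch(state, epoch, template='checkpoint_epoch_{}.pth.tar'):
    # periodic snapshot used together with save_freq above
    torch.save(state, template.format(epoch))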
Example #4
        img_name = os.path.join(self.root_dir, self.label_raw[idx]['image_id'])
        img_name_raw = self.label_raw[idx]['image_id']
        image = Image.open(img_name)
        label = int(self.label_raw[idx]['label_id'])

        if self.transform:
            image = self.transform(image)

        return image, label, img_name_raw


transformed_dataset_test = SceneDataset(
    json_labels=label_raw_test,
    root_dir=
    '/home/member/fuwang/projects/scene/data/ai_challenger_scene_test_a_20170922/scene_test_a_images_20170922',
    transform=data_transforms('ten_crop'))
transformed_dataset_val = SceneDataset(
    json_labels=label_raw_val,
    root_dir=
    '/home/member/fuwang/projects/scene/data/ai_challenger_scene_validation_20170908/scene_validation_images_20170908',
    transform=data_transforms('ten_crop'))
dataloader = {
    'test':
    DataLoader(transformed_dataset_test,
               batch_size=batch_size,
               shuffle=False,
               num_workers=INPUT_WORKERS),
    'val':
    DataLoader(transformed_dataset_val,
               batch_size=batch_size,
               shuffle=False,
               num_workers=INPUT_WORKERS)
}
Example #5
def run():
    model = load_model(arch,
                       pretrained,
                       use_gpu=use_gpu,
                       AdaptiveAvgPool=AdaptiveAvgPool,
                       SPP=SPP,
                       num_levels=num_levels,
                       pool_type=pool_type,
                       stage=stage,
                       use_multi_path=True)

    model = nn.DataParallel(model).cuda()

    best_prec3 = 0
    best_loss1 = 10000

    if try_resume:
        if os.path.isfile(latest_check):
            print("=> loading checkpoint '{}'".format(latest_check))
            checkpoint = torch.load(latest_check)
            global start_epoch
            start_epoch = checkpoint['epoch']
            best_prec3 = checkpoint['best_prec3']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                latest_check, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(latest_check))

    cudnn.benchmark = True

    if class_aware:
        train_set = data.ChallengerSceneFolder(
            data.TRAIN_ROOT,
            data_transforms(train_transform, input_size, train_scale,
                            test_scale))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=ClassAwareSampler.ClassAwareSampler(train_set),
            num_workers=INPUT_WORKERS,
            pin_memory=use_gpu)
    else:
        train_loader = torch.utils.data.DataLoader(data.ChallengerSceneFolder(
            data.TRAIN_ROOT,
            data_transforms(train_transform, input_size, train_scale,
                            test_scale)),
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=INPUT_WORKERS,
                                                   pin_memory=use_gpu)

    val_loader = torch.utils.data.DataLoader(data.ChallengerSceneFolder(
        data.VALIDATION_ROOT,
        data_transforms('validation', input_size, train_scale, test_scale)),
                                             batch_size=BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=INPUT_WORKERS,
                                             pin_memory=use_gpu)

    criterion = nn.CrossEntropyLoss().cuda() if use_gpu else nn.CrossEntropyLoss()

    if if_fc:
        if pretrained == 'imagenet' or arch == 'resnet50' or arch == 'resnet18':
            ignored_params = list(map(id, model.module.fc.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 model.module.parameters())
            lr_dicts = [{
                'params': base_params,
                'lr': lr1
            }, {
                'params': model.module.fc.parameters(),
                'lr': lr2
            }]

        elif pretrained == 'places':
            if arch == 'preact_resnet50':
                lr_dicts = list()
                lr_dicts.append({
                    'params':
                    model.module._modules['12']._modules['1'].parameters(),
                    'lr':
                    lr2
                })
                for _, index in enumerate(model.module._modules):
                    if index != '12':
                        lr_dicts.append({
                            'params':
                            model.module._modules[index].parameters(),
                            'lr':
                            lr1
                        })
                    else:
                        for index2, _ in enumerate(
                                model.module._modules[index]):
                            if index2 != 1:
                                lr_dicts.append({
                                    'params':
                                    model.module._modules[index]._modules[str(
                                        index2)].parameters(),
                                    'lr':
                                    lr1
                                })
            elif arch == 'resnet152':
                lr_dicts = list()
                lr_dicts.append({
                    'params':
                    model.module._modules['10']._modules['1'].parameters(),
                    'lr':
                    lr2
                })
                for _, index in enumerate(model.module._modules):
                    if index != '10':
                        lr_dicts.append({
                            'params':
                            model.module._modules[index].parameters(),
                            'lr':
                            lr1
                        })
                    else:
                        for index2, _ in enumerate(
                                model.module._modules[index]):
                            if index2 != 1:
                                lr_dicts.append({
                                    'params':
                                    model.module._modules[index]._modules[str(
                                        index2)].parameters(),
                                    'lr':
                                    lr1
                                })

        if optim_type == 'Adam':
            optimizer = optim.Adam(lr_dicts,
                                   betas=betas,
                                   eps=eps,
                                   weight_decay=weight_decay)
        elif optim_type == 'SGD':
            optimizer = optim.SGD(lr_dicts,
                                  momentum=momentum,
                                  weight_decay=weight_decay)
    else:
        if optim_type == 'Adam':
            if stage == 1:
                optimizer = optim.Adam(model.module.fc.parameters(),
                                       lr=lr,
                                       betas=betas,
                                       eps=eps,
                                       weight_decay=weight_decay)
            else:
                optimizer = optim.Adam(model.parameters(),
                                       lr=lr,
                                       betas=betas,
                                       eps=eps,
                                       weight_decay=weight_decay)
        elif optim_type == 'SGD':
            if stage == 1:
                if pretrained == 'places' and arch == 'preact_resnet50':
                    optimizer = optim.SGD(
                        model.module._modules['12']._modules['1'].parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=weight_decay)
                elif pretrained == 'places' and arch == 'resnet152':
                    optimizer = optim.SGD(
                        model.module._modules['10']._modules['1'].parameters(),
                        lr=lr,
                        momentum=momentum,
                        weight_decay=weight_decay)
                else:
                    optimizer = optim.SGD(model.module.fc.parameters(),
                                          lr=lr,
                                          momentum=momentum,
                                          weight_decay=weight_decay)
            else:
                optimizer = optim.SGD(model.parameters(),
                                      lr=lr,
                                      momentum=momentum,
                                      weight_decay=weight_decay)

    if evaluate:
        validate(val_loader, model, criterion)

    else:

        for epoch in range(start_epoch, epochs):

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch)

            # evaluate on validation set
            prec3, loss1 = validate(val_loader, model, criterion, epoch)

            # remember best prec@1 and save checkpoint
            is_best = prec3 >= best_prec3
            best_prec3 = max(prec3, best_prec3)
            if is_best:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_prec3': best_prec3,
                        'loss1': loss1
                    }, is_best)
                best_loss1 = loss1
            else:
                is_best_loss = (loss1 <= best_loss1)
                if is_best_loss or lr <= lr_min:  # when lr is already tiny, stop rolling the checkpoint back and forth
                    best_loss1 = loss1
                else:
                    my_check = torch.load(best_check)
                    model.load_state_dict(my_check['state_dict'])
                    best_loss1 = my_check['loss1']
                    # accuracy did not beat the best and loss did not drop since last time, so adjust lr
                    adjust_learning_rate(optimizer, epoch, if_fc)
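Both this example and Example #3 pass `sampler=ClassAwareSampler.ClassAwareSampler(train_set)` together with `shuffle=False` (a DataLoader cannot combine `shuffle=True` with a custom sampler). A minimal sketch of what a class-aware sampler does, assuming the dataset exposes its integer labels as `dataset.targets` (a hypothetical attribute; the project's implementation may differ):

import random
from collections import defaultdict
from torch.utils.data.sampler import Sampler

class ClassAwareSampler(Sampler):
    # draw a class uniformly at random, then an example of that class, so
    # rare classes are sampled as often as common ones
    def __init__(self, dataset):
        self.class_to_indices = defaultdict(list)
        for i, label in enumerate(dataset.targets):
            self.class_to_indices[label].append(i)
        self.classes = list(self.class_to_indices)
        self.num_samples = len(dataset.targets)

    def __iter__(self):
        for _ in range(self.num_samples):
            cls = random.choice(self.classes)
            yield random.choice(self.class_to_indices[cls])

    def __len__(self):
        return self.num_samples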
Example #6
File: train.py  Project: filick/scene
def run():
    model = load_model(arch,
                       pretrained,
                       use_gpu=use_gpu,
                       AdaptiveAvgPool=AdaptiveAvgPool)

    if use_gpu:
        if arch.lower().startswith('alexnet') or arch.lower().startswith(
                'vgg'):
            model.features = nn.DataParallel(model.features)
            model.cuda()
        else:
            model = nn.DataParallel(model).cuda()

    best_prec1 = 0  # must be initialised even when use_gpu is False

    if try_resume:
        if os.path.isfile(latest_check):
            print("=> loading checkpoint '{}'".format(latest_check))
            checkpoint = torch.load(latest_check)
            global start_epoch
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                latest_check, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(latest_check))

    cudnn.benchmark = True

    if class_aware:
        train_set = data.ChallengerSceneFolder(
            data.TRAIN_ROOT,
            data_transforms(train_transform, input_size, train_scale,
                            test_scale))
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=BATCH_SIZE,
            shuffle=False,
            sampler=ClassAwareSampler.ClassAwareSampler(train_set),
            num_workers=INPUT_WORKERS,
            pin_memory=use_gpu)
    else:
        train_loader = torch.utils.data.DataLoader(data.ChallengerSceneFolder(
            data.TRAIN_ROOT,
            data_transforms(train_transform, input_size, train_scale,
                            test_scale)),
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=INPUT_WORKERS,
                                                   pin_memory=use_gpu)

    val_loader = torch.utils.data.DataLoader(data.ChallengerSceneFolder(
        data.VALIDATION_ROOT,
        data_transforms('validation', input_size, train_scale, test_scale)),
                                             batch_size=BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=INPUT_WORKERS,
                                             pin_memory=use_gpu)

    criterion = nn.CrossEntropyLoss().cuda() if use_gpu else nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           betas=betas,
                           eps=eps,
                           weight_decay=weight_decay)
    #optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    if evaluate:
        validate(val_loader, model, criterion)

    else:

        for epoch in range(start_epoch, epochs):

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch)

            # evaluate on validation set
            prec1 = validate(val_loader, model, criterion, epoch)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)
            else:
                my_check = torch.load(best_check)
                model.load_state_dict(my_check['state_dict'])
                adjust_learning_rate(optimizer, epoch)
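`adjust_learning_rate` is called here whenever validation accuracy fails to improve (after rolling the weights back to the best checkpoint). A minimal sketch of the step-decay variant from the official PyTorch ImageNet example; the signatures in Examples #3 and #5 take extra flags (`if_fc`, `use_epoch_decay`) and use the return value to decide on a rollback, so the project's real version is more involved:

def adjust_learning_rate(optimizer, epoch, base_lr=0.01):
    # decay the learning rate by 10x every 30 epochs (base_lr is a
    # stand-in for the script-level lr)
    new_lr = base_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr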
Example #7
        file = 'result/' + phases[0] + '_1_' + epoch_i.split('.')[0].split(
            '_')[-1] + '.csv'
    else:
        file = 'result/' + phases[0] + '_1.csv'
    with open(file, 'w', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile, dialect='excel')
        writer.writerow(["id", "is_iceberg"])
        for item in aug_softmax.keys():
            writer.writerow(
                [item, max(min(aug_softmax[item][1], 0.9999), 0.0001)])


transformed_dataset_test = jd.ImageDataset(test_root,
                                           include_target=True,
                                           X_transform=data_transforms(
                                               val_transform, input_size,
                                               train_scale, test_scale))

dataloader = {
    phases[0]:
    DataLoader(transformed_dataset_test,
               batch_size=batch_size,
               shuffle=False,
               num_workers=INPUT_WORKERS)
}
dataset_sizes = {phases[0]: len(test_root)}


class AverageMeter(object):
    def __init__(self):
        self.reset()
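    # The class is cut off here; the canonical continuation from the official
    # PyTorch ImageNet example, which this helper appears to follow (a sketch
    # under that assumption):
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # track the latest value and a running average over n-weighted updates
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count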
Example #8
best_checkpoint = torch.load(best_check)
# hook the feature extractor
features_blobs = []


def hook_feature(module, input, output):
    features_blobs.append(np.squeeze(output.data.cpu().numpy()))


features_names = ['layer5']  # layer4 for the original resnet, layer5 for the masked resnet
for name in features_names:
    model_conv._modules.get(name).register_forward_hook(hook_feature)

tf = data_transforms('validation', input_size, train_scale, test_scale)


def returnCAM(feature_conv, weight_softmax, class_idx):
    # generate the class activation maps upsample to 256x256
    size_upsample = (256, 256)
    nc, h, w = feature_conv.shape
    output_cam = []
    for idx in class_idx:
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))
        cam = cam.reshape(h, w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(imresize(cam_img, size_upsample))
    return output_cam
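A sketch of how `returnCAM`'s output is typically consumed: run an image through `model_conv` (which fills `features_blobs` via the hook), pick a class index, then blend the resized map over the input. `img`, `logit`, and `weight_softmax` are assumptions standing in for the surrounding evaluation code:

import matplotlib.pyplot as plt
import skimage.transform
import torch.nn.functional as F

probs = F.softmax(logit, dim=1).data.squeeze()  # logit: model_conv output for one image
idx = probs.argmax().item()                     # predicted class
cam = returnCAM(features_blobs[0], weight_softmax, [idx])[0]
plt.imshow(img)                                 # img: the original H x W x 3 array
plt.imshow(skimage.transform.resize(cam, img.shape[:2]), alpha=0.5, cmap='jet')
plt.axis('off')
plt.show()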
Example #9
def run():
    model = load_model(arch, pretrained, use_gpu=use_gpu, num_classes=num_classes,
                       AdaptiveAvgPool=AdaptiveAvgPool, SPP=SPP, num_levels=num_levels,
                       pool_type=pool_type, bilinear=bilinear, stage=stage,
                       SENet=SENet, se_stage=se_stage, se_layers=se_layers,
                       threshold_before_avg=threshold_before_avg, triplet=triplet)

    if use_gpu:
        if arch.lower().startswith('alexnet') or arch.lower().startswith('vgg'):
            model.features = nn.DataParallel(model.features)
            model.cuda()
        else:
            model = nn.DataParallel(model).cuda()
            # model = nn.DataParallel(model, device_ids=[2]).cuda()

    best_prec3 = 0
    best_loss1 = 10000

    if try_resume:
        if os.path.isfile(latest_check):
            print("=> loading checkpoint '{}'".format(latest_check))
            checkpoint = torch.load(latest_check)
            global start_epoch 
            start_epoch = checkpoint['epoch']
            print(start_epoch)
            best_prec3 = checkpoint['best_prec3']
            best_loss1 = checkpoint['loss1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'"
                  .format(latest_check))
        else:
            print("=> no checkpoint found at '{}'".format(latest_check))

    cudnn.benchmark = True

    with open(TRAIN_ROOT + '/pig_test_annotations.json', 'r') as f:  # label file
        label_raw_train = json.load(f)
    with open(VALIDATION_ROOT + '/pig_test_annotations.json', 'r') as f:  # label file
        label_raw_val = json.load(f)
    if class_aware:
        raise ValueError('Class-aware sampling has not been tested with this dataset class; it has only been verified with ImageFolder')
        train_set = triplet_image_dataset.TripletImageDataset(
            json_labels=label_raw_train, root_dir=TRAIN_ROOT,
            transform=data_transforms(train_transform, input_size, train_scale, test_scale),
            distance=train_distance, frames=train_frames)
        train_loader = torch.utils.data.DataLoader(
                train_set,
                batch_size=BATCH_SIZE, shuffle=False,
                sampler=ClassAwareSampler.ClassAwareSampler(train_set),
                num_workers=INPUT_WORKERS, pin_memory=use_gpu)
    else:
        train_loader = torch.utils.data.DataLoader(
            triplet_image_dataset.TripletImageDataset(
                json_labels=label_raw_train, root_dir=TRAIN_ROOT,
                transform=data_transforms(train_transform, input_size, train_scale, test_scale),
                distance=train_distance, frames=train_frames),
            batch_size=BATCH_SIZE, shuffle=True,
            num_workers=INPUT_WORKERS, pin_memory=use_gpu)
        
    val_loader = torch.utils.data.DataLoader(
        triplet_image_dataset.TripletImageDataset(
            json_labels=label_raw_val, root_dir=VALIDATION_ROOT,
            transform=data_transforms('validation', input_size, train_scale, test_scale),
            distance=val_distance, frames=val_frames),
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=INPUT_WORKERS, pin_memory=use_gpu)

    if if_fc:
        if pretrained == 'imagenet' or arch == 'resnet50' or arch == 'resnet18':
            ignored_params = list(map(id, model.module.fc.parameters()))
            base_params = filter(lambda p: id(p) not in ignored_params,
                                 model.module.parameters())
            lr_dicts = [{'params': base_params, 'lr': lr1},
                        {'params': model.module.fc.parameters(), 'lr': lr2}]
            
        elif pretrained == 'places':
            if arch == 'preact_resnet50':
                lr_dicts = list()
                lr_dicts.append({'params': model.module._modules['12']._modules['1'].parameters(), 'lr': lr2})
                for _, index in enumerate(model.module._modules):
                    if index != '12':
                        lr_dicts.append({'params': model.module._modules[index].parameters(), 'lr': lr1})
                    else:
                        for index2, _ in enumerate(model.module._modules[index]):
                            if index2 != 1:
                                lr_dicts.append({'params': model.module._modules[index]._modules[str(index2)].parameters(), 'lr': lr1})
            elif arch == 'resnet152':
                lr_dicts = list()
                lr_dicts.append({'params': model.module._modules['10']._modules['1'].parameters(), 'lr': lr2})
                for _, index in enumerate(model.module._modules):
                    if index != '10':
                        lr_dicts.append({'params': model.module._modules[index].parameters(), 'lr': lr1})
                    else:
                        for index2, _ in enumerate(model.module._modules[index]):
                            if index2 != 1:
                                lr_dicts.append({'params': model.module._modules[index]._modules[str(index2)].parameters(), 'lr': lr1})

        if optim_type == 'Adam':
            optimizer = optim.Adam(lr_dicts, betas=betas, eps=eps, weight_decay=weight_decay)
        elif optim_type == 'SGD':
            optimizer = optim.SGD(lr_dicts, momentum=momentum, weight_decay=weight_decay)
    else:
        if optim_type == 'Adam':
            if stage == 1 or se_stage == 1:
#                optimizer = optim.Adam(model.module.fc.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 
                optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
            else:
                optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 
        elif optim_type == 'SGD':
            if stage == 1 or se_stage == 1:
#                if pretrained == 'places' and arch == 'preact_resnet50':
#                    optimizer = optim.SGD(model.module._modules['12']._modules['1'].parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
#                elif pretrained =='places' and arch == 'resnet152':
#                    optimizer = optim.SGD(model.module._modules['10']._modules['1'].parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
#                else:
#                    optimizer = optim.SGD(model.module.fc.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
                optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, momentum=momentum, weight_decay=weight_decay)
            else:
                optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    
    if evaluate:
        validate(val_loader, model, criterion)

    else:

        for epoch in range(start_epoch, epochs):

            # train for one epoch
            prec3, loss1 = train(train_loader, model, criterion, optimizer, epoch)

            # evaluate on validation set
            if VALIDATION_ROOT is not None:
                prec3, loss1 = validate(val_loader, model, criterion, epoch)

            # remember best 
            is_best = loss1 <= best_loss1
            best_loss1 = min(loss1, best_loss1)
            if epoch % save_freq == 0:
                save_checkpoint_epoch({
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_prec3': best_prec3,
                        'loss1': loss1
                        }, epoch + 1)
            save_checkpoint({
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_prec3': best_prec3,
                        'loss1': loss1
                        }, is_best)
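One thing to note: unlike the earlier examples, this snippet never defines `criterion` before handing it to train/validate, so it presumably lives at module level in the original file. For the triplet branch enabled by `triplet=triplet`, PyTorch's built-in margin loss would be a natural fit (an assumption; the snippet does not say which loss is used):

import torch
import torch.nn as nn

# hypothetical stand-in for the module-level criterion: a margin-based loss
# over (anchor, positive, negative) embeddings
criterion = nn.TripletMarginLoss(margin=1.0, p=2)
if torch.cuda.is_available():
    criterion = criterion.cuda()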