Beispiel #1
0
def switch_backbones(bone_name):
    """Instantiate the backbone network registered under ``bone_name``.

    The matching constructor from ``nets.resnet`` is called with no
    arguments and its result is returned.

    Args:
        bone_name: Name of the backbone, e.g. ``"resnet50"`` or
            ``"wide_resnet101_2"``.

    Returns:
        The object produced by the selected constructor.

    Raises:
        NotImplementedError: ``bone_name`` matches none of the known names.
    """
    # Function-scope import: all constructors are resolved on every call,
    # before the name is looked at.
    from nets.resnet import (
        resnet18, resnet34, resnet50, resnet101, resnet152,
        resnext50_32x4d, resnext101_32x8d,
        wide_resnet50_2, wide_resnet101_2,
    )

    # Ordered (name, constructor) pairs; scanned top to bottom.
    known_backbones = (
        ("resnet18", resnet18),
        ("resnet34", resnet34),
        ("resnet50", resnet50),
        ("resnet101", resnet101),
        ("resnet152", resnet152),
        ("resnext50_32x4d", resnext50_32x4d),
        ("resnext101_32x8d", resnext101_32x8d),
        ("wide_resnet50_2", wide_resnet50_2),
        ("wide_resnet101_2", wide_resnet101_2),
    )
    for known_name, build in known_backbones:
        if bone_name == known_name:
            return build()

    raise NotImplementedError(bone_name)
def _build_eval_loader(data_root, num_retrieval_per_class):
    """Build an ``sop.loader`` over ``data_root`` with the settings every split shares.

    Args:
        data_root: Dataset directory handed to ``sop.loader``.
        num_retrieval_per_class: Forwarded as ``eval_num_retrieval``.

    Returns:
        Whatever ``sop.loader`` returns for these settings.
    """
    return sop.loader(
        data_root,
        data_transform=torch_transforms.Resize((225, 225)),
        eval_mode=True,
        eval_num_retrieval=num_retrieval_per_class,
        neg_neighbor=False,
        pos_neighbor=False,
    )


def _main(args):
    """Restore a saved ResNet checkpoint and pass it to ``plot_representation``.

    The train, test and validation loaders are built first, then the model
    selected by ``args.resnet_type`` is constructed, its weights are restored
    through ``resnet.load`` from ``./reports/<analysis>/models`` (checkpoint
    name ``resnet_epoch_<start_epoch>``), the model is optionally moved to
    GPU, and finally ``plot_representation`` is called on the train loader.

    Args:
        args: Parsed options. Attributes read: ``device``, ``gpu``,
            ``start_epoch``, ``num_workers``, ``analysis``, ``pretrained``
            and ``resnet_type``.

    Raises:
        AssertionError: ``args.device`` is not ``'cpu'``, ``'multi'`` or
            ``'cuda:<index>'`` with ``<index>`` below
            ``torch.cuda.device_count()``.
        NotImplementedError: ``args.resnet_type`` is not one of the
            supported architectures.
    """
    num_retrieval_per_class = 10
    train_pca_n_components = 2

    #### Preparing Datasets ####
    train_dataloader = _build_eval_loader(
        './datasets/standford_online_products/train', num_retrieval_per_class)
    # NOTE(review): the test loader is never used below. It is still built
    # in case sop.loader has side effects -- confirm and drop if it has none.
    test_dataloader = _build_eval_loader(
        './datasets/standford_online_products/test', num_retrieval_per_class)
    # The validation loader is only used to obtain the number of classes.
    val_dataloader = _build_eval_loader(
        './datasets/standford_online_products/val', num_retrieval_per_class)

    #### Preparing Pytorch ####
    device = args.device
    assert (device in [
        'cpu', 'multi'
    ]) or (len(device.split(':')) == 2 and device.split(':')[0] == 'cuda'
           and int(device.split(':')[1]) < torch.cuda.device_count()
           ), 'Unknown device: {}'.format(device)
    torch.manual_seed(0)
    # 'multi' is not a real torch device, so it is kept as a plain string.
    if args.device != 'multi':
        device = torch.device(args.device)
    # Same value the original recomputed at each use site (falsy args.gpu is
    # passed through unchanged rather than coerced to bool).
    use_cuda = args.gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.manual_seed_all(0)

    #### Training Parameters ####
    start_epoch = args.start_epoch
    num_workers = args.num_workers

    #### Reports Address ####
    reports_root = './reports'
    reports_path = '{}/{}'.format(reports_root, args.analysis)
    loading_model_path = '{}/models'.format(reports_path)

    #### Constructing Model ####
    pretrained = args.pretrained
    num_classes = val_dataloader.num_classes()

    # Each supported name is also the constructor's attribute name on
    # ``resnet``. 'resnext101_32x8d' was disabled (commented out) in the
    # original code and therefore stays unsupported here.
    supported_resnet_types = (
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'resnext50_32x4d',
    )
    if args.resnet_type not in supported_resnet_types:
        # Previously the model silently stayed None and was handed to
        # resnet.load; fail here with the offending name instead.
        raise NotImplementedError(args.resnet_type)
    build_model = getattr(resnet, args.resnet_type)
    model = build_model(pretrained=pretrained, num_classes=num_classes)

    # resnet.load also returns an optimizer, which is not needed here.
    model, _ = resnet.load(loading_model_path,
                           'resnet_epoch_{}'.format(start_epoch),
                           model)

    if use_cuda:
        if args.device == 'multi':
            model = nn.DataParallel(model)
        else:
            model = model.cuda(device=device)

    plot_representation(model, train_dataloader, device, use_cuda,
                        num_workers, train_pca_n_components)