Code example #1
def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step

    if args.dataset == "CUB_200_2011":
        num_classes = 200  # CUB-200-2011 contains 200 bird classes
    elif args.dataset == "car":
        num_classes = 196
    elif args.dataset == "nabirds":
        num_classes = 555
    elif args.dataset == "dog":
        num_classes = 120
    elif args.dataset == "INat2017":
        num_classes = 5089
    else:
        raise ValueError("Unknown dataset: %s" % args.dataset)

    model = VisionTransformer(config,
                              args.img_size,
                              zero_head=True,
                              num_classes=num_classes,
                              smoothing_value=args.smoothing_value)
    # Initialize from the official pre-trained .npz weights
    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        # Optionally resume from a previously fine-tuned torch checkpoint
        pretrained_model = torch.load(args.pretrained_model)['model']
        model.load_state_dict(pretrained_model)
    model.to(args.device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    return args, model
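For reference, a minimal sketch of how this setup might be invoked. The flag names mirror the args attributes read above; the default values, the ViT-B_16 config key, and the checkpoint path are assumptions, not taken from the project:

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--model_type", default="ViT-B_16")            # key into CONFIGS (assumed default)
parser.add_argument("--dataset", default="CUB_200_2011")
parser.add_argument("--img_size", type=int, default=448)           # assumed default
parser.add_argument("--split", default="overlap")                  # assumed default
parser.add_argument("--slide_step", type=int, default=12)          # assumed default
parser.add_argument("--smoothing_value", type=float, default=0.0)
parser.add_argument("--pretrained_dir", default="ViT-B_16.npz")    # assumed path
parser.add_argument("--pretrained_model", default=None)
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args, model = setup(args)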
Code example #2
def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]
    config.mixup = args.mixup
    config.mixup_layer = args.mixup_layer
    config.mixup_alpha = args.mixup_alpha

    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
    try:
        # First try the official .npz weight archive
        model.load_from(np.load(args.pretrained_dir))
    except Exception:
        # Fall back to the project's custom loader for other checkpoint formats
        load_model(args.pretrained_dir, model)
    model.to(args.device)
    num_params = count_parameters(model)

    print("{}".format(config))

    for key, value in vars(args).items():
        logging.info(f"{key}: {value}")
    print("#Trainable Parameter: \t%2.2fM" % num_params)
    return args, model
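load_model, the fallback loader used above, is not shown in this excerpt. A plausible sketch of such a helper, assuming a standard torch-saved checkpoint that may or may not wrap the weights in a 'model' key (the whole implementation is hypothetical):

import torch

def load_model(path, model):
    # Hypothetical fallback: treat the file as a torch checkpoint instead of an .npz archive
    ckpt = torch.load(path, map_location="cpu")
    state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state_dict, strict=False)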
Code example #3
File: train.py, Project: gregrolwes/ViT-pytorch
def setup(args, device):
    # Prepare model
    config = CONFIGS[args.model_type]

    output_size = 256

    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=output_size)
    model.load_from(np.load(args.pretrained_dir))
    model.to(device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    return args, model
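Every example relies on count_parameters, which is also not shown. Since the result is logged with a %2.1fM format, a minimal sketch consistent with that convention (an assumption about the helper, not the project's verbatim code):

def count_parameters(model):
    # Count trainable parameters and report the total in millions
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000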
Code example #4
File: train.py, Project: LongJohnCoder/ViT-pytorch
def setup(args):
    # Prepare model
    config = CONFIGS[args.model_type]

    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config,
                              args.img_size,
                              zero_head=True,
                              num_classes=num_classes)
    model.load_from(np.load(args.pretrained_dir))
    model.to(args.device)
    num_params = count_parameters(model)

    logger.info("{}".format(config))
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM" % num_params)
    return args, model
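All of the setup variants above initialize the backbone with model.load_from(np.load(args.pretrained_dir)), i.e. from an .npz weight archive such as the officially released ViT checkpoints. A hypothetical download step for such a file (the exact URL pattern is an assumption):

import os
import urllib.request

os.makedirs("checkpoint", exist_ok=True)
url = "https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz"  # assumed URL pattern
urllib.request.urlretrieve(url, "checkpoint/ViT-B_16.npz")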
Code example #5
def visualize_attn(args):
    """ 
    Visualization of learned attention map.
    """
    config = CONFIGS[args.model_type]
    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size, 
                            norm_type=args.norm_type, 
                            zero_head=True, 
                            num_classes=num_classes,
                            vis=True)

    ckpt_file = os.path.join(args.output_dir, args.name + "_checkpoint.bin")
    ckpt = torch.load(ckpt_file)
    # load on a single device for attention-map visualization
    model.load_state_dict(ckpt)
    model.to(args.device)
    model.eval()
    
    _, test_loader = get_loader(args)
    sample_idx = 0
    layer_ids = [0, 3, 6, 9]
    head_id = 0
    with torch.no_grad():
        for step, batch in enumerate(test_loader):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch
            select_x = x[sample_idx].unsqueeze(0)
            output, attn_weights = model(select_x)
            # attn_weights is a list of tensors, one per layer, each of shape
            # (1, num_heads, seq_len, seq_len)
            for layer_id in layer_ids:
                vis_attn(args, attn_weights[layer_id].squeeze(0)[head_id], layer_id=layer_id)
            break # visualize the first sample in the first batch
    print("done.")
    exit(0)
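vis_attn is not defined in this excerpt either. One possible sketch, assuming matplotlib and a (seq_len, seq_len) attention tensor for a single head (the file naming and figure layout are assumptions):

import os
import matplotlib.pyplot as plt

def vis_attn(args, attn, layer_id=0):
    # attn: (seq_len, seq_len) attention weights for one head of one layer
    plt.figure(figsize=(6, 6))
    plt.imshow(attn.cpu().numpy(), cmap="viridis")
    plt.colorbar()
    plt.title("Attention map, layer %d" % layer_id)
    plt.savefig(os.path.join(args.output_dir, "attn_layer%d.png" % layer_id))
    plt.close()

Each saved figure shows how strongly every token attends to every other token at the chosen layer.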