Example #1
0
def setup(args):
    """Build the VisionTransformer described by ``args`` and move it to ``args.device``.

    The base weights are loaded from ``args.pretrained_dir`` via
    ``model.load_from``. If ``args.pretrained_model`` is given, the
    ``'model'`` entry of that checkpoint then overrides them.

    Note: ``CONFIGS[args.model_type]`` is mutated in place (``split`` and
    ``slide_step`` are assigned on the shared config object).

    Args:
        args: parsed command-line namespace. Reads ``model_type``, ``split``,
            ``slide_step``, ``dataset``, ``img_size``, ``smoothing_value``,
            ``pretrained_dir``, ``pretrained_model`` and ``device``.

    Returns:
        tuple: ``(args, model)``, i.e. the same ``args`` object and the
        constructed model.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported datasets.
    """
    # Prepare model
    config = CONFIGS[args.model_type]
    config.split = args.split
    config.slide_step = args.slide_step

    # Number of output classes for each supported dataset name.
    # NOTE(review): the CUB-200-2011 benchmark has 200 classes; the value 4
    # is kept unchanged in case this copy targets a reduced/custom subset.
    # Confirm before relying on it.
    dataset_num_classes = {
        "CUB_200_2011": 4,
        "car": 196,
        "nabirds": 555,
        "dog": 120,
        "INat2017": 5089,
    }
    try:
        num_classes = dataset_num_classes[args.dataset]
    except KeyError:
        # Previously an unknown dataset left num_classes unbound and failed
        # later with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported dataset %r; expected one of: %s"
            % (args.dataset, ", ".join(dataset_num_classes))
        ) from None

    model = VisionTransformer(config,
                              args.img_size,
                              zero_head=True,
                              num_classes=num_classes,
                              smoothing_value=args.smoothing_value)
    model.load_from(np.load(args.pretrained_dir))
    if args.pretrained_model is not None:
        # Load onto the CPU so a checkpoint saved on a GPU can be opened
        # anywhere; model.to(args.device) below moves the weights as needed.
        pretrained_model = torch.load(args.pretrained_model,
                                      map_location="cpu")['model']
        model.load_state_dict(pretrained_model)
    model.to(args.device)
    num_params = count_parameters(model)

    # Lazy %-style arguments: the messages are identical to eager formatting.
    logger.info("%s", config)
    logger.info("Training parameters %s", args)
    logger.info("Total Parameter: \t%2.1fM", num_params)
    return args, model
Example #2
0
def visualize_attn(args):
    """Visualize learned attention maps for one test sample, then exit.

    The function proceeds in four steps:

    1. Build a VisionTransformer with ``vis=True``.
    2. Load the checkpoint ``<args.output_dir>/<args.name>_checkpoint.bin``.
    3. Run the first sample of the first test batch through the model.
    4. Pass head ``0`` of layers 0, 3, 6 and 9 to ``vis_attn``.

    Args:
        args: parsed command-line namespace. Reads ``model_type``,
            ``dataset``, ``img_size``, ``norm_type``, ``output_dir``,
            ``name`` and ``device`` here, and is forwarded to
            ``get_loader`` and ``vis_attn``.

    Raises:
        SystemExit: always raised with status 0 once visualization is done
            (the same effect the original ``exit(0)`` had).
    """
    config = CONFIGS[args.model_type]
    # Any dataset other than "cifar10" is treated as 100-class.
    num_classes = 10 if args.dataset == "cifar10" else 100

    model = VisionTransformer(config, args.img_size,
                              norm_type=args.norm_type,
                              zero_head=True,
                              num_classes=num_classes,
                              vis=True)

    ckpt_file = os.path.join(args.output_dir, args.name + "_checkpoint.bin")
    # Load onto the CPU so a GPU-saved checkpoint opens on any machine;
    # model.to(args.device) below moves the weights where they are needed.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # use single card for visualize attn map
    model.load_state_dict(ckpt)
    model.to(args.device)
    model.eval()

    _, test_loader = get_loader(args)
    sample_idx = 0
    layer_ids = [0, 3, 6, 9]
    head_id = 0
    with torch.no_grad():
        # Only the first sample of the first batch is visualized, so fetch a
        # single batch instead of looping; None when the loader is empty.
        batch = next(iter(test_loader), None)
        if batch is not None:
            x, _ = batch
            # Move just the selected sample to the device, not the whole batch.
            select_x = x[sample_idx].unsqueeze(0).to(args.device)
            _, attn_weights = model(select_x)
            # attn_weights is List[(1, number_of_head, len_h, len_h)]
            for layer_id in layer_ids:
                vis_attn(args, attn_weights[layer_id].squeeze(0)[head_id],
                         layer_id=layer_id)
    print("done.")
    # The site-provided exit() is meant for the interactive shell only;
    # SystemExit(0) is what sys.exit(0) raises.
    raise SystemExit(0)
def main(args):
    """Grow and train a simplex of ViT-B/16 models, one vertex at a time.

    The run proceeds as follows:

    1. Create ``./saved-outputs/model<base_idx>/`` and record the command
       line in ``base_command.sh``.
    2. Load the base checkpoint
       ``./cifar100-100_500_seed_<base_idx>/cifar100-100_500_seed_<base_idx>_checkpoint.bin``
       into a 100-class model.
    3. Wrap that model in a one-vertex ``BasicSimplex`` on the GPU.
    4. For each of ``args.n_verts`` new vertices:
       - add the vertex;
       - train for ``args.epochs`` epochs with SGD and cosine annealing;
       - evaluate every ``args.eval_freq`` epochs;
       - print a results-table row per epoch;
       - save the simplex ``state_dict`` as
         ``lr_<lr_init>simplex_vertex<vv>.pt``.

    Requires CUDA (the simplex is moved with ``.cuda()``).

    Args:
        args: parsed command-line namespace. Reads ``base_idx``,
            ``img_size``, ``n_verts``, ``lr_init``, ``wd``, ``epochs``,
            ``n_sample``, ``gradient_accumulation_steps`` and ``eval_freq``,
            and is forwarded to ``get_loader``.
    """
    savedir = "./saved-outputs/model" + str(args.base_idx) + "/"
    print('Preparing directory %s' % savedir)
    os.makedirs(savedir, exist_ok=True)
    # Record the exact command line next to this run's outputs.
    with open(os.path.join(savedir, 'base_command.sh'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')

    trainloader, testloader = get_loader(args)

    config = CONFIGS['ViT-B_16']
    num_classes = 100
    model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
    modeldir = "./cifar100-100_500_seed_" + str(args.base_idx) + "/"
    modelname = "cifar100-100_500_seed_" + str(args.base_idx) + "_checkpoint.bin"
    # Load onto the CPU; .cuda() below moves the weights to the GPU.
    model.load_state_dict(torch.load(os.path.join(modeldir, modelname),
                                     map_location="cpu"))

    simplex_model = BasicSimplex(model, num_vertices=1, fixed_points=[False]).cuda()
    del model

    # These do not depend on the vertex being trained, so build them once.
    criterion = torch.nn.CrossEntropyLoss()
    columns = ['vert', 'ep', 'lr', 'tr_loss',
               'tr_acc', 'te_loss', 'te_acc', 'time']

    ## add a new points and train ##
    for vv in range(1, args.n_verts + 1):
        simplex_model.add_vert()
        # Re-apply .cuda() so the parameters of the newly added vertex are
        # on the GPU too.
        simplex_model = simplex_model.cuda()
        # Fresh optimizer/scheduler per vertex: the parameter set has changed.
        optimizer = torch.optim.SGD(
            simplex_model.parameters(),
            lr=args.lr_init,
            momentum=0.9,
            weight_decay=args.wd
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                               T_max=args.epochs)
        for epoch in range(args.epochs):
            time_ep = time.time()
            train_res = simp_utils.train_transformer_epoch(
                trainloader,
                simplex_model,
                criterion,
                optimizer,
                args.n_sample,
                vol_reg=1e-4,
                gradient_accumulation_steps=args.gradient_accumulation_steps,
            )

            # Evaluate on the last epoch of every eval_freq-sized window.
            eval_ep = epoch % args.eval_freq == args.eval_freq - 1
            if eval_ep:
                test_res = simp_utils.eval(testloader, simplex_model, criterion)
            else:
                test_res = {'loss': None, 'accuracy': None}

            time_ep = time.time() - time_ep

            # Read the lr actually used this epoch before stepping the schedule.
            lr = optimizer.param_groups[0]['lr']
            scheduler.step()

            values = [vv, epoch + 1, lr,
                      train_res['loss'], train_res['accuracy'],
                      test_res['loss'], test_res['accuracy'], time_ep]

            table = tabulate.tabulate([values], columns,
                                      tablefmt='simple', floatfmt='8.4f')
            # 'simple' output is [header, separator, row]. Every 40 epochs,
            # print separator + header + separator + row; otherwise just the row.
            if epoch % 40 == 0:
                table = table.split('\n')
                table = '\n'.join([table[1]] + table)
            else:
                table = table.split('\n')[2]
            print(table, flush=True)

        checkpoint = simplex_model.state_dict()
        fname = "lr_" + str(args.lr_init) + "simplex_vertex" + str(vv) + ".pt"
        torch.save(checkpoint, os.path.join(savedir, fname))
Example #4
0
from models.modeling import VisionTransformer, CONFIGS

# Preprocessing pipeline applied to each input image:
#   1. Resize to 224x224 (the model below is built with img_size=224).
#   2. Convert to a tensor.
#   3. Normalize every channel as (x - 0.5) / 0.5.
# Per the torchvision docs, ToTensor scales 8-bit images to [0, 1], so the
# normalized values should fall in [-1, 1].
# NOTE(review): confirm this matches the normalization used at training time.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

if __name__ == "__main__":
    config = CONFIGS["ViT-B_16"]
    model = VisionTransformer(config,
                              num_classes=5,
                              zero_head=False,
                              img_size=224,
                              vis=True)
    model.load_state_dict(torch.load("output/cassava_trial1_checkpoint.pth"))
    model.eval()

    labels = pd.read_csv(
        "/home/ligia/Documents/facultate/master-1.1/modelare_numerica/leaf_classification/dataset/train.csv"
    )

    accuracy = 0.0
    count = 0

    for path in glob(
            "/home/ligia/Documents/facultate/master-1.1/modelare_numerica/leaf_classification/dataset/train/*.jpg"
    ):
        img_name = path.split("/")[-1]

        im = Image.open(path)