Example #1
def load_model(model_num):
    # Build the captioning model and load the checkpoint saved for the given epoch number.
    model, _ = caption.build_model(config)
    checkpoint = torch.load(config.checkpoint + str(model_num),
                            map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.to(config.device)
    return model
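A hypothetical call site for the helper above, assuming a global config = Config() and checkpoints named config.checkpoint + str(epoch), the convention used in the training loop further down:

# Hypothetical usage: restore the checkpoint written after epoch 19 and switch to inference mode.
model = load_model(19)
model.eval()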
Example #2
def v1(pretrained=False):
    # torch.hub entry point: build the CATR model and optionally load the released v1 weights.
    config = Config()
    model, _ = caption.build_model(config)
    
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url='https://github.com/saahiluppal/catr/releases/download/0.1/weights_9348032.pth',
            map_location='cpu'
        )
        model.load_state_dict(checkpoint['model'])
    
    return model
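In practice this entry point is resolved through torch.hub, as the commented-out snippet in Example #4 shows; calling it directly is a hypothetical equivalent:

# Hypothetical direct usage of the hub entry point defined above.
model = v1(pretrained=True)
model.eval()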
def main(config):
    device = torch.device(config.device)
    print(f'Initializing Device: {device}')

    seed = config.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    model, criterion = caption.build_model(config)
    model.load_state_dict(torch.load("pretrained_wts/my_model.pth"))
    model.to(device)
    print("Model Loaded")

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print(f"Number of params: {n_parameters}")

    param_dicts = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            config.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=config.lr,
                                  weight_decay=config.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop)

    dataset_train = coco.build_dataset(config, mode='training')
    dataset_val = coco.build_dataset(config, mode='validation')
    print(f"Train: {len(dataset_train)}")
    print(f"Valid: {len(dataset_val)}")

    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        config.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   num_workers=config.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 config.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 num_workers=config.num_workers)

    if os.path.exists(config.checkpoint + "19"):
        print("Loading Checkpoint...19")
        checkpoint = torch.load(config.checkpoint + "19", map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.start_epoch = checkpoint['epoch'] + 1
        print("Loaded Checkpoint:", config.checkpoint + "19")

    print("Start Training..")

    # epoch starts from 0
    for epoch in range(config.start_epoch, config.epochs):
        print(f"Epoch: {epoch}")
        epoch_loss = train_one_epoch(model, criterion, data_loader_train,
                                     optimizer, device, epoch,
                                     config.clip_max_norm)
        lr_scheduler.step()
        print(f"Training Loss: {epoch_loss}")

        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,  # zero-based; 'epoch': 0 in the checkpoint corresponds to epoch number 1
            },
            config.checkpoint + str(epoch + 1)
        )  # e.g. checkpoint.pth1 holds the state saved after epoch number 1

        validation_loss = evaluate(model, criterion, data_loader_val, device)
        print(f"Validation Loss: {validation_loss}")

        print()
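A minimal driver for this routine, assuming the usual pattern of building a Config and passing it in (hypothetical, not part of the snippet above):

if __name__ == "__main__":
    config = Config()
    main(config)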
Example #4
import argparse

import torch
from PIL import Image
from transformers import BertTokenizer

from datasets import coco
from configuration import Config
from models import caption

parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
parser.add_argument('--v', type=str, help='version', default='v3')
args = parser.parse_args()
image_path = args.path
version = args.v
"""
if version == 'v1':
    model = torch.hub.load('saahiluppal/catr', 'v1', pretrained=True)
elif version == 'v2':
    model = torch.hub.load('saahiluppal/catr', 'v2', pretrained=True)
elif version == 'v3':
    model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)
else:
    raise NotImplementedError('Version not implemented')
"""
config = Config()
model, criterion = caption.build_model(config)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
model.eval()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

start_token = tokenizer.convert_tokens_to_ids(tokenizer._cls_token)
end_token = tokenizer.convert_tokens_to_ids(tokenizer._sep_token)

image = Image.open(image_path)
image = coco.val_transform(image)
image = image.unsqueeze(0)


def create_caption_and_mask(start_token, max_length):
    # Caption template seeded with the start ([CLS]) token; the mask hides positions not yet generated.
    caption_template = torch.zeros((1, max_length), dtype=torch.long)
    mask_template = torch.ones((1, max_length), dtype=torch.bool)
    caption_template[:, 0] = start_token
    mask_template[:, 0] = False
    return caption_template, mask_template


# Assumes a loaded spaCy model, e.g. spacy_eng = spacy.load("en_core_web_sm").
def tokenize_eng(text):
    return [tok.text for tok in spacy_eng.tokenizer(text)]
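create_caption_and_mask is typically paired with a greedy decoding loop over the transformer outputs. A minimal sketch, assuming the model is called as model(image, caption, cap_mask) and that config.max_position_embeddings bounds the caption length:

caption, cap_mask = create_caption_and_mask(start_token, config.max_position_embeddings)

@torch.no_grad()
def greedy_decode():
    model.eval()
    for i in range(config.max_position_embeddings - 1):
        predictions = model(image, caption, cap_mask)   # (1, seq_len, vocab_size)
        predicted_id = torch.argmax(predictions[:, i, :], dim=-1)
        if predicted_id[0] == end_token:                # stop once [SEP] is produced
            return caption
        caption[:, i + 1] = predicted_id[0]
        cap_mask[:, i + 1] = False
    return caption

output = greedy_decode()
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))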


import argparse

import torch

from datasets import coco
from configuration import Config
from models import caption

parser = argparse.ArgumentParser(description='Image Captioning')
parser.add_argument('--path', type=str, help='path to image', required=True)
args = parser.parse_args()
image_path = args.path

config = Config()
model, _ = caption.build_model(config)

# To load model weights, use the code below
model.backbone.load_state_dict(
    torch.load("checkpoints/checkpoint-breakdown/backbone.pth",
               map_location='cpu'))
model.input_proj.load_state_dict(
    torch.load("checkpoints/checkpoint-breakdown/input_proj.pth",
               map_location='cpu'))
model.transformer.load_state_dict(
    torch.load("checkpoints/checkpoint-breakdown/transformer.pth",
               map_location='cpu'))
model.mlp.load_state_dict(
    torch.load("checkpoints/checkpoint-breakdown/mlp.pth", map_location='cpu'))
model.to("cuda" if torch.cuda.is_available() else "cpu")
Example #6
def finetune(config):
    device = torch.device(config.device)
    print(f'Initializing Device: {device}')

    seed = config.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    model, criterion = caption.build_model(config)
    checkpoint = torch.hub.load_state_dict_from_url(
        url="https://github.com/saahiluppal/catr/releases/download/0.2/weight493084032.pth",
        map_location=device)
    model.to(device)
    model.load_state_dict(checkpoint['model'])

    config.lr = 1e-5
    config.epochs = 10
    config.lr_drop = 8

    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print(f"Number of params: {n_parameters}")

    param_dicts = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr":
            config.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(param_dicts,
                                  lr=config.lr,
                                  weight_decay=config.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.lr_drop)

    dataset_train = coco.build_dataset(config, mode='training')
    dataset_val = coco.build_dataset(config, mode='validation')
    print(f"Train: {len(dataset_train)}")
    print(f"Valid: {len(dataset_val)}")

    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                        config.batch_size,
                                                        drop_last=True)

    data_loader_train = DataLoader(dataset_train,
                                   batch_sampler=batch_sampler_train,
                                   num_workers=config.num_workers)
    data_loader_val = DataLoader(dataset_val,
                                 config.batch_size,
                                 sampler=sampler_val,
                                 drop_last=False,
                                 num_workers=config.num_workers)

    if os.path.exists(config.checkpoint):
        print("Loading Checkpoint...")
        checkpoint = torch.load(config.checkpoint, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.start_epoch = checkpoint['epoch'] + 1

    print("Start Training..")
    for epoch in range(config.start_epoch, config.epochs):
        print(f"Epoch: {epoch}")
        epoch_loss = train_one_epoch(model, criterion, data_loader_train,
                                     optimizer, device, epoch,
                                     config.clip_max_norm)
        lr_scheduler.step()
        print(f"Training Loss: {epoch_loss}")

        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
            }, config.checkpoint)

        validation_loss = evaluate(model, criterion, data_loader_val, device)
        print(f"Validation Loss: {validation_loss}")

        print()