Example #1
import os

import h5py
import skimage.io
import torch
import tqdm

from models.encoder import Encoder
from opts import parse_opt

opt = parse_opt()


def extract_imgs_feat():
    encoder = Encoder(opt.resnet101_file)
    encoder.to(opt.device)
    encoder.eval()

    imgs = os.listdir(opt.imgs_dir)
    imgs.sort()

    if not os.path.exists(opt.out_feats_dir):
        os.makedirs(opt.out_feats_dir)
    # open in write mode: without an explicit mode, h5py >= 3.0 opens files
    # read-only and the create_dataset calls below would fail
    with h5py.File(os.path.join(opt.out_feats_dir, '%s_fc.h5' % opt.dataset_name), 'w') as file_fc, \
            h5py.File(os.path.join(opt.out_feats_dir, '%s_att.h5' % opt.dataset_name), 'w') as file_att:
        # the context managers close both files even if an exception escapes
        # the loop, so no manual close/re-raise wrapper is needed
        for img_nm in tqdm.tqdm(imgs, ncols=100):
            img = skimage.io.imread(os.path.join(opt.imgs_dir, img_nm))
            with torch.no_grad():
                img = encoder.preprocess(img)
                img = img.to(opt.device)
                img_fc, img_att = encoder(img)
            file_fc.create_dataset(img_nm,
                                   data=img_fc.cpu().float().numpy())
            file_att.create_dataset(img_nm,
                                    data=img_att.cpu().float().numpy())

Example #2

def detect(path, encoder=None, decoder=None):
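    # cudnn.benchmark lets cuDNN autotune convolution algorithms, which helps
    # when input shapes stay constant across iterations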
    torch.backends.cudnn.benchmark = True

    dataset = LoadImages(path, img_size=config.IMAGE_SIZE, used_layers=config.USED_LAYERS)

    # build and load a default model when none is passed in
    if not encoder or not decoder:
        in_channels = num_channels(config.USED_LAYERS)
        encoder = Encoder(in_channels=in_channels)
        decoder = Decoder(num_classes=config.NUM_CLASSES+1)
        encoder = encoder.to(config.DEVICE)
        decoder = decoder.to(config.DEVICE)

        _, encoder, decoder = load_checkpoint(encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE)

    encoder.eval()
    decoder.eval()

    # note: `path` is reused as the loop variable; inside the loop it holds
    # each image file's path rather than the dataset root
    for _, layers, path in dataset:
        with torch.no_grad():
            layers = torch.from_numpy(layers).to(config.DEVICE, non_blocking=True)
            if layers.ndimension() == 3:
                layers = layers.unsqueeze(0)  # add a batch dimension

            features = encoder(layers)
            predictions = decoder(features)
            out = predictions.sigmoid()

            plot_volumes(to_volume(out, config.VOXEL_THRESH).cpu(), [path], config.NAMES)

Example #3

def train():
    torch.backends.cudnn.benchmark = True

    _, dataloader = create_dataloader(config.IMG_DIR + "/train", config.MESH_DIR + "/train",
                                      batch_size=config.BATCH_SIZE, used_layers=config.USED_LAYERS,
                                      img_size=config.IMAGE_SIZE, map_size=config.MAP_SIZE,
                                      augment=config.AUGMENT, workers=config.NUM_WORKERS,
                                      pin_memory=config.PIN_MEMORY, shuffle=True)

    in_channels = num_channels(config.USED_LAYERS)
    encoder = Encoder(in_channels=in_channels)
    decoder = Decoder(num_classes=config.NUM_CLASSES+1)
    encoder.apply(init_weights)
    decoder.apply(init_weights)
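    # only parameters with requires_grad=True are handed to the encoder
    # optimizer (parts of the backbone may be frozen)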
    encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),
                                      lr=config.ENCODER_LEARNING_RATE,
                                      betas=config.BETAS)
    decoder_solver = torch.optim.Adam(decoder.parameters(),
                                      lr=config.DECODER_LEARNING_RATE,
                                      betas=config.BETAS)
    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,
                                                                milestones=config.ENCODER_LR_MILESTONES,
                                                                gamma=config.GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,
                                                                milestones=config.DECODER_LR_MILESTONES,
                                                                gamma=config.GAMMA)
    encoder = encoder.to(config.DEVICE)
    decoder = decoder.to(config.DEVICE)

    loss_fn = LossFunction()

    init_epoch = 0
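    # optionally resume training from a saved checkpoint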
    if config.CHECKPOINT_FILE and config.LOAD_MODEL:
        init_epoch, encoder, decoder = load_checkpoint(encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE)

    # one timestamped output directory per run; non-alphanumeric characters in
    # the ISO timestamp are replaced with '-'
    output_dir = os.path.join(config.OUT_PATH, re.sub("[^0-9a-zA-Z]+", "-", dt.now().isoformat()))

    for epoch_idx in range(init_epoch, config.NUM_EPOCHS):
        encoder.train()
        decoder.train()
        train_one_epoch(encoder, decoder, dataloader, loss_fn, encoder_solver, decoder_solver, epoch_idx)
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()

        if config.TEST:
            test(encoder, decoder)
        if config.SAVE_MODEL:
            save_checkpoint(epoch_idx, encoder, decoder, output_dir)

    # if per-epoch testing/saving was disabled above, still run each once
    # after the final epoch
    if not config.TEST:
        test(encoder, decoder)
    if not config.SAVE_MODEL:
        save_checkpoint(config.NUM_EPOCHS - 1, encoder, decoder, output_dir)

Example #4

def test(encoder=None, decoder=None):
    torch.backends.cudnn.benchmark = True

    _, dataloader = create_dataloader(config.IMG_DIR + "/test", config.MESH_DIR + "/test",
                                      batch_size=config.BATCH_SIZE, used_layers=config.USED_LAYERS,
                                      img_size=config.IMAGE_SIZE, map_size=config.MAP_SIZE,
                                      augment=config.AUGMENT, workers=config.NUM_WORKERS,
                                      pin_memory=config.PIN_MEMORY, shuffle=False)
    # build and load a default model when none is passed in
    if not encoder or not decoder:
        in_channels = num_channels(config.USED_LAYERS)
        encoder = Encoder(in_channels=in_channels)
        decoder = Decoder(num_classes=config.NUM_CLASSES+1)
        encoder = encoder.to(config.DEVICE)
        decoder = decoder.to(config.DEVICE)

        _, encoder, decoder = load_checkpoint(encoder, decoder, config.CHECKPOINT_FILE, config.DEVICE)

    loss_fn = LossFunction()

    loop = tqdm(dataloader, leave=True)
    losses = []
    ious = []

    encoder.eval()
    decoder.eval()

    for i, (_, layers, volumes, img_files) in enumerate(loop):
        with torch.no_grad():
            layers = layers.to(config.DEVICE, non_blocking=True)
            volumes = volumes.to(config.DEVICE, non_blocking=True)

            features = encoder(layers)
            predictions = decoder(features)

            loss = loss_fn(predictions, volumes)
            losses.append(loss.item())

            iou = predictions_iou(to_volume(predictions, config.VOXEL_THRESH), volumes)
            ious.append(iou)

            # running means over all batches seen so far, shown in the
            # progress bar
            mean_iou = sum(ious) / len(ious)
            mean_loss = sum(losses) / len(losses)
            loop.set_postfix(loss=mean_loss, mean_iou=mean_iou)

            if i == 0 and config.PLOT:
                plot_volumes(to_volume(predictions, config.VOXEL_THRESH).cpu(), img_files, config.NAMES)
                plot_volumes(volumes.cpu(), img_files, config.NAMES)

Example #5

# coding:utf8
import torch
import skimage.io

from opts import parse_opt
from models.decoder import Decoder
from models.encoder import Encoder

opt = parse_opt()
assert opt.test_model, 'please specify test_model'
assert opt.image_file, 'please specify image_file'

encoder = Encoder(opt.resnet101_file)
encoder.to(opt.device)
encoder.eval()

img = skimage.io.imread(opt.image_file)
with torch.no_grad():
    img = encoder.preprocess(img)
    img = img.to(opt.device)
    fc_feat, att_feat = encoder(img)

print("====> loading checkpoint '{}'".format(opt.test_model))
chkpoint = torch.load(opt.test_model, map_location=lambda s, l: s)
decoder = Decoder(chkpoint['idx2word'], chkpoint['settings'])
decoder.load_state_dict(chkpoint['model'])
print("====> loaded checkpoint '{}', epoch: {}, train_mode: {}".format(
    opt.test_model, chkpoint['epoch'], chkpoint['train_mode']))
decoder.to(opt.device)
decoder.eval()

Example #6

                       num_fre=args.NUM_FRE)
    VOCAB_SIZE = vocab.num_words
    SEQ_LEN = vocab.max_sentence_len

    encoder = Encoder(args.ENCODER_OUTPUT_SIZE)
    decoder = Decoder(embed_size=args.EMBED_SIZE,
                      hidden_size=args.HIDDEN_SIZE,
                      attention_size=args.ATTENTION_SIZE,
                      vocab_size=VOCAB_SIZE,
                      encoder_size=2048,
                      device=device,
                      seq_len=SEQ_LEN + 2)
    encoder.load_state_dict(torch.load(args.ENCODER_MODEL_LOAD_PATH))
    decoder.load_state_dict(torch.load(args.DECODER_MODEL_LOAD_PATH))

    encoder.to(device)
    decoder.to(device)
    encoder.eval()
    decoder.eval()
    result_json = {"images": []}
    for path in IMG_PATH:
        img_name = path.split("/")[-1]
        img = Image.open(path)
        img = transform(img).unsqueeze(0).to(
            device)  # [BATCH_SIZE(1) * CHANNEL * INPUT_SIZE * INPUT_SIZE]

        num_sentence = args.NUM_TOP_PROB
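        # seed the decoder with one SOS token and a zero log-probability per
        # candidate sentence (beam-search style initialization)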
        top_prev_prob = torch.zeros((num_sentence, 1)).to(device)
        words = torch.Tensor([vocab.SOS_token]).long().expand(
            num_sentence, -1).to(device)

Example #7

def main():

    # parsing the arguments
    args, _ = parse_arguments()

    # setup logging
    #output_dir = Path('/content/drive/My Drive/image-captioning/output')
    output_dir = Path(args.output_directory)
    output_dir.mkdir(parents=True, exist_ok=True)
    logfile_path = Path(output_dir / "output.log")
    setup_logging(logfile=logfile_path)

    # setup and read config.ini
    #config_file = Path('/content/drive/My Drive/image-captioning/config.ini')
    config_file = Path('../config.ini')
    reading_config(config_file)

    # tensorboard
    tensorboard_logfile = Path(output_dir / 'tensorboard')
    tensorboard_writer = SummaryWriter(tensorboard_logfile)

    # load dataset
    #dataset_dir = Path('/content/drive/My Drive/Flickr8k_Dataset')
    dataset_dir = Path(args.dataset)
    images_path = Path(dataset_dir / Config.get("images_dir"))
    captions_path = Path(dataset_dir / Config.get("captions_dir"))
    training_loader, validation_loader, testing_loader = data_loaders(
        images_path, captions_path)

    # load the model (encoder, decoder, optimizer)
    embed_size = Config.get("encoder_embed_size")
    hidden_size = Config.get("decoder_hidden_size")
    batch_size = Config.get("training_batch_size")
    epochs = Config.get("epochs")
    feature_extraction = Config.get("feature_extraction")
    raw_captions = read_captions(captions_path)
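    # build the word<->id dictionaries; threshold=5 likely drops rare words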
    id_to_word, word_to_id = dictionary(raw_captions, threshold=5)
    vocab_size = len(id_to_word)
    encoder = Encoder(embed_size, feature_extraction)
    decoder = Decoder(embed_size, hidden_size, vocab_size, batch_size)

    # load pretrained embeddings
    #pretrained_emb_dir = Path('/content/drive/My Drive/word2vec')
    pretrained_emb_dir = Path(args.pretrained_embeddings)
    pretrained_emb_file = Path(pretrained_emb_dir /
                               Config.get("pretrained_emb_path"))
    pretrained_embeddings = load_pretrained_embeddings(pretrained_emb_file,
                                                       id_to_word)

    # load the optimizer
    learning_rate = Config.get("learning_rate")
    optimizer = adam_optimizer(encoder, decoder, learning_rate)

    # loss function
    criterion = cross_entropy

    # load checkpoint
    checkpoint_file = Path(output_dir / Config.get("checkpoint_file"))
    checkpoint_captioning = load_checkpoint(checkpoint_file)

    # using the available device (gpu/cpu)
    encoder = encoder.to(Config.get("device"))
    decoder = decoder.to(Config.get("device"))
    pretrained_embeddings = pretrained_embeddings.to(Config.get("device"))

    start_epoch = 1
    if checkpoint_captioning is not None:
        start_epoch = checkpoint_captioning['epoch'] + 1
        encoder.load_state_dict(checkpoint_captioning['encoder'])
        decoder.load_state_dict(checkpoint_captioning['decoder'])
        optimizer.load_state_dict(checkpoint_captioning['optimizer'])
        logger.info(
            'Initialized encoder, decoder and optimizer from loaded checkpoint'
        )

    # the weights now live in the models and optimizer, so the raw checkpoint
    # dict can be freed
    del checkpoint_captioning

    # image captioning model
    model = ImageCaptioning(encoder, decoder, optimizer, criterion,
                            training_loader, validation_loader, testing_loader,
                            pretrained_embeddings, output_dir,
                            tensorboard_writer)

    # training and testing the model
    if args.training:
        validate_every = Config.get("validate_every")
        model.train(epochs, validate_every, start_epoch)
    elif args.testing:
        # images_path already includes the images directory (set above), so
        # pass it to testing directly instead of re-joining images_dir
        model.testing(id_to_word, images_path)