Exemplo n.º 1
0
def main(args):
    """Generate captions for the test split and print them to stdout.

    Restores the model weights and training arguments from
    ``args.checkpoint_path``, runs the sequence generator over every test
    batch on the GPU and prints, for each sample, the reference caption
    (``T-``), up to ``args.num_hypo`` hypotheses (``H-``) and, when present,
    their positional scores (``P-``) and alignments (``A-``).

    Args:
        args: Parsed command-line namespace. Entries stored in the checkpoint
            override same-named command-line values.
    """
    torch.manual_seed(args.seed)

    # Load arguments from checkpoint.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(args.checkpoint_path, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    # The checkpoint's arguments are unpacked last, so they win over the CLI values.
    args = argparse.Namespace(**{**vars(args), **vars(state_dict['args'])})
    utils.init_logging(args)

    # Load dictionary
    dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
    logging.info('Loaded a dictionary of {} words'.format(len(dictionary)))

    # Load dataset
    test_dataset = CaptionDataset(
        os.path.join(args.data, 'test-tokens.p'), os.path.join(args.data, 'test-features'), dictionary)
    logging.info('Created a test dataset of {} examples'.format(len(test_dataset)))
    test_loader = torch.utils.data.DataLoader(
        test_dataset, num_workers=args.num_workers, collate_fn=test_dataset.collater, pin_memory=True,
        batch_sampler=BatchSampler(test_dataset, args.max_tokens, args.batch_size, shuffle=False, seed=args.seed))

    # Build model
    # NOTE(review): the model is never switched to eval mode here; confirm that
    # SequenceGenerator does so before relying on dropout being disabled.
    model = models.build_model(args, dictionary).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(args.checkpoint_path))

    # NOTE(review): eval() on argument strings executes arbitrary code; these
    # values are presumably the strings 'True'/'False' - consider a safer parser.
    generator = SequenceGenerator(
        model, dictionary, beam_size=args.beam_size, maxlen=args.max_len, stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores), len_penalty=args.len_penalty, unk_penalty=args.unk_penalty,
    )

    progress_bar = tqdm(test_loader, desc='| Generation', leave=False)
    for sample in progress_bar:
        sample = utils.move_to_cuda(sample)
        with torch.no_grad():
            batch_hypos = generator.generate(sample['image_features'])

        # Distinct names per loop level: the row index must keep pointing at the
        # current sample while its hypotheses are being iterated.
        for row, (sample_id, sample_hypos) in enumerate(zip(sample['id'].data, batch_hypos)):
            if sample['caption_tokens'] is not None:
                target_tokens = sample['caption_tokens'].data[row, :]
                # Drop padding before converting the reference back to text.
                target_tokens = target_tokens[target_tokens.ne(dictionary.pad_idx)].int().cpu()
                target_str = dictionary.string(target_tokens)
                print('T-{:<6}\t{}'.format(sample_id, colored(target_str, 'green')))

            # Process top predictions (slicing already clamps to the list length)
            for hypo in sample_hypos[:args.num_hypo]:
                hypo_tokens = hypo['tokens'].int().cpu()
                hypo_str = dictionary.string(hypo_tokens)
                alignment = hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None

                print('H-{:<6}\t{}'.format(sample_id, colored(hypo_str, 'blue')))
                if hypo['positional_scores'] is not None:
                    scores = hypo['positional_scores'].tolist()
                    print('P-{:<6}\t{}'.format(sample_id, ' '.join('{:.4f}'.format(score) for score in scores)))
                if alignment is not None:
                    print('A-{:<6}\t{}'.format(sample_id, ' '.join(str(pos.item()) for pos in alignment)))
Exemplo n.º 2
0
def extract_features(args, model, image_dataset, output_dir):
    """Run ``model`` over ``image_dataset`` and pickle one feature array per caption id.

    Each batch is moved to the GPU, passed through ``model``, flattened over
    its last two dimensions and transposed so that the flattened dimension
    comes before the channel dimension. Every sample's features are written to
    ``<output_dir>/<caption_id>.p``; a ``metadata.p`` file mapping each caption
    id to ``(image_path, feature_filename)`` is written at the end.

    Args:
        args: Namespace providing ``batch_size`` and ``num_workers``.
        model: Callable feature extractor applied to each batch.
        image_dataset: Dataset yielding ``(caption_ids, image_paths, sample)`` batches
            through the default collate function.
        output_dir: Directory to write the pickles to (created if missing).
    """
    # Local import: needed only for the no_grad context below.
    import torch

    os.makedirs(output_dir, exist_ok=True)
    data_loader = DataLoader(image_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False)
    progress_bar = tqdm(data_loader, desc='| Feature Extraction', leave=False)

    filenames = {}
    for caption_ids, image_paths, sample in progress_bar:
        # Pure inference: skip building the autograd graph to save GPU memory.
        with torch.no_grad():
            image_features = model(utils.move_to_cuda(sample))
        image_features = image_features.view(*image_features.size()[:-2], -1)
        # B x C x (H x W) -> B x (H x W) x C
        image_features = image_features.transpose(1, 2)
        image_features = image_features.cpu().detach().numpy().astype(np.float32)

        batch_ids = caption_ids.cpu().numpy().astype(np.int32)
        for caption_id, image_path, features in zip(batch_ids, image_paths, image_features):
            filename = os.path.join(output_dir, '{}.p'.format(str(caption_id)))
            filenames[caption_id] = (image_path, filename)
            with open(filename, 'wb') as feature_file:
                pickle.dump(features, feature_file, protocol=pickle.HIGHEST_PROTOCOL)

    with open(os.path.join(output_dir, 'metadata.p'), 'wb') as metadata_file:
        pickle.dump(filenames, metadata_file)
Exemplo n.º 3
0
def validate(args, model, criterion, valid_dataset, epoch):
    """Evaluate ``model`` on ``valid_dataset`` and return the mean validation loss.

    The model is switched to eval mode and left in it; callers that continue
    training must call ``model.train()`` again.

    Args:
        args: Namespace providing ``num_workers``, ``max_tokens``,
            ``batch_size`` and ``seed``.
        model: Captioning model called as ``model(image_features, caption_inputs)``.
        criterion: Loss function; its result is divided by the batch token
            count here, so it is presumably a summed (not averaged) loss.
        valid_dataset: Dataset exposing a ``collater`` method.
        epoch: Epoch number, used only for progress and log output.

    Returns:
        The per-token loss of each batch, averaged over the number of batches.
    """
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        num_workers=args.num_workers,
        collate_fn=valid_dataset.collater,
        pin_memory=True,
        batch_sampler=BatchSampler(valid_dataset,
                                   args.max_tokens,
                                   args.batch_size,
                                   shuffle=False,
                                   seed=args.seed))

    model.eval()
    stats = {'valid_loss': 0, 'num_tokens': 0, 'batch_size': 0}
    progress_bar = tqdm(valid_loader,
                        desc='| Epoch {:03d}'.format(epoch),
                        leave=False)

    for i, sample in enumerate(progress_bar):
        sample = utils.move_to_cuda(sample)
        # The forward pass must be inside no_grad as well; otherwise the whole
        # autograd graph is built and kept alive for every validation batch.
        with torch.no_grad():
            output, _ = model(sample['image_features'], sample['caption_inputs'])
            loss = criterion(output.view(-1, output.size(-1)),
                             sample['caption_tokens'].view(-1))

        batch_len = len(sample['caption_inputs'])
        stats['valid_loss'] += loss.item() / sample['num_tokens']
        stats['num_tokens'] += sample['num_tokens'] / batch_len
        stats['batch_size'] += batch_len
        # Show running averages over the batches seen so far.
        progress_bar.set_postfix(
            {
                key: '{:.4g}'.format(value / (i + 1))
                for key, value in stats.items()
            },
            refresh=True)

    num_batches = len(progress_bar)
    logging.info('Epoch {:03d}: {}'.format(
        epoch, ' | '.join(key + ' {:.4g}'.format(value / num_batches)
                          for key, value in stats.items())))
    return stats['valid_loss'] / num_batches
Exemplo n.º 4
0
def main(args):
    """Train the captioning model and dump per-epoch statistics to ``logs_dict.p``.

    Builds the datasets, model, summed cross-entropy criterion, Adam optimizer
    and a ReduceLROnPlateau schedule, resumes from the last checkpoint when
    ``utils.load_checkpoint`` returns one, then trains until ``args.max_epoch``
    or until the learning rate has decayed to ``args.min_lr``.

    Args:
        args: Parsed command-line namespace with data paths and hyper-parameters.

    Raises:
        NotImplementedError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported.')
    torch.manual_seed(args.seed)
    utils.init_logging(args)

    # Load dictionary
    dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
    logging.info('Loaded a dictionary of {} words'.format(len(dictionary)))

    # Load datasets
    train_dataset = CaptionDataset(
        os.path.join(args.data, 'train-tokens.p'), os.path.join(args.data, 'train-features'), dictionary)
    logging.info('Created a train dataset of {} examples'.format(len(train_dataset)))
    valid_dataset = CaptionDataset(
        os.path.join(args.data, 'valid-tokens.p'), os.path.join(args.data, 'valid-features'), dictionary)
    logging.info('Created a validation dataset of {} examples'.format(len(valid_dataset)))

    # Build model and criterion
    model = models.build_model(args, dictionary).cuda()
    logging.info('Built a model with {} parameters'.format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    criterion = nn.CrossEntropyLoss(ignore_index=dictionary.pad_idx, reduction='sum').cuda()

    # Build an optimizer and a learning rate schedule
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(trainable_params, args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, min_lr=args.min_lr, factor=args.lr_shrink)

    # Load last checkpoint if one exists
    state_dict = utils.load_checkpoint(args, model, optimizer, lr_scheduler)
    last_epoch = state_dict['last_epoch'] if state_dict is not None else -1
    # NOTE(review): this resets the learning rate to args.lr even when resuming,
    # discarding any decay restored from the checkpoint - confirm this is intended.
    optimizer.param_groups[0]['lr'] = args.lr

    # epoch -> [loss, lr, num_tokens, batch_size, grad_norm, clip, valid_loss]
    # (index 0 -> loss, 1 -> lr, 4 -> grad_norm, 6 -> valid_loss)
    dic = {}

    for epoch in range(last_epoch + 1, args.max_epoch):
        train_loader = torch.utils.data.DataLoader(
            train_dataset, num_workers=args.num_workers, collate_fn=train_dataset.collater, pin_memory=True,
            batch_sampler=BatchSampler(train_dataset, args.max_tokens, args.batch_size, shuffle=True, seed=args.seed))

        # validate() leaves the model in eval mode, so re-enable training mode every epoch.
        model.train()
        stats = {'loss': 0., 'lr': 0., 'num_tokens': 0., 'batch_size': 0., 'grad_norm': 0., 'clip': 0.}
        progress_bar = tqdm(train_loader, desc='| Epoch {:03d}'.format(epoch), leave=False)

        for i, sample in enumerate(progress_bar):
            # Forward and backward pass
            sample = utils.move_to_cuda(sample)
            output, _ = model(sample['image_features'], sample['caption_inputs'])

            loss = criterion(output.view(-1, output.size(-1)), sample['caption_tokens'].view(-1))
            optimizer.zero_grad()
            loss.backward()

            # Normalize gradients by number of tokens and perform clipping
            for param in model.parameters():
                # Workaround: parameters that received no gradient have grad None.
                if param.grad is not None:
                    param.grad.data.div_(sample['num_tokens'])
            # clip_grad_norm_ may return a tensor; convert to a plain float so the
            # accumulated statistics (and the pickled log) hold no device tensors.
            grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm))
            optimizer.step()

            # Update statistics for progress bar
            batch_len = len(sample['caption_inputs'])
            stats['loss'] += loss.item() / sample['num_tokens']
            stats['lr'] += optimizer.param_groups[0]['lr']
            stats['num_tokens'] += sample['num_tokens'] / batch_len
            stats['batch_size'] += batch_len
            stats['grad_norm'] += grad_norm
            stats['clip'] += 1 if grad_norm > args.clip_norm else 0
            progress_bar.set_postfix(
                {key: '{:.4g}'.format(value / (i + 1)) for key, value in stats.items()}, refresh=True)

        num_batches = len(progress_bar)
        logging.info('Epoch {:03d}: {}'.format(epoch, ' | '.join(
            key + ' {:.4g}'.format(value / num_batches) for key, value in stats.items())))

        dic[epoch] = [value / num_batches for value in stats.values()]

        # Adjust learning rate based on validation loss
        valid_loss = validate(args, model, criterion, valid_dataset, epoch)
        lr_scheduler.step(valid_loss)

        dic[epoch].append(valid_loss)

        # Save checkpoints
        if epoch % args.save_interval == 0:
            utils.save_checkpoint(args, model, optimizer, lr_scheduler, epoch, valid_loss)
        if optimizer.param_groups[0]['lr'] <= args.min_lr:
            logging.info('Done training!')
            break

    # Persist the per-epoch statistics gathered during this run.
    with open('logs_dict.p', 'wb') as log_file:
        pickle.dump(dic, log_file)
def main(args):
    """Caption a handful of test images and save predictions plus attention plots.

    Reads the test CSV (columns ``uId``, ``imgId`` and ``report``), picks the
    requested caption ids (or 10 random ones), extracts VGG19 features for the
    matching images, generates captions with the checkpointed model and, for
    every caption id, writes ``images/<id>.p`` (prediction and reference text)
    and calls ``utils.plot_image_caption`` with ``images/<id>.png``.

    Args:
        args: Parsed command-line namespace. Command-line values override those
            stored in the checkpoint; ``embed_path`` and ``log_file`` are
            forced to ``None``.
    """
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load arguments from checkpoint (no need to load pretrained embeddings or write to log file)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(
        **{
            **vars(state_dict['args']),
            **vars(args), 'embed_path': None,
            'log_file': None
        })
    utils.init_logging(args)

    # Load dictionary
    dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
    logging.info('Loaded a dictionary of {} words'.format(len(dictionary)))

    # Load dataset
    test = pd.read_csv(os.path.join(args.dataset_path, args.test_caption))
    if args.caption_ids is None:
        args.caption_ids = random.sample(test['uId'].tolist(), 10)
    # For every caption id take the first matching row of the CSV.
    image_ids = [
        test.loc[test['uId'] == caption_id]['imgId'].tolist()[0]
        for caption_id in args.caption_ids
    ]
    reference_captions = [
        test.loc[test['uId'] == caption_id]['report'].tolist()[0]
        for caption_id in args.caption_ids
    ]
    image_names = [
        os.path.join(args.dataset_path, args.test_image, image_id + '.png')
        for image_id in image_ids
    ]

    # Transform image. The crop and flip are random; the seeds set above keep
    # them reproducible between runs.
    pil_transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
    ])
    images = [pil_transform(Image.open(filename)) for filename in image_names]
    tensor_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    sample = torch.stack(
        [tensor_transform(image.convert('RGB')) for image in images], dim=0)

    # Extract image features (all VGG19 feature layers except the last two)
    vgg = vgg19(pretrained=True).eval().cuda()
    feature_extractor = nn.Sequential(*list(vgg.features.children())[:-2])
    # Pure inference: skip building the autograd graph to save GPU memory.
    with torch.no_grad():
        image_features = feature_extractor(utils.move_to_cuda(sample))
    image_features = image_features.view(*image_features.size()[:-2], -1)
    # B x C x (H x W) -> B x (H x W) x C
    image_features = image_features.transpose(1, 2)

    # Load model and build generator
    model = models.build_model(args, dictionary).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(
        args.checkpoint_path))
    # NOTE(review): eval() on argument strings executes arbitrary code; these
    # values are presumably the strings 'True'/'False' - consider a safer parser.
    generator = SequenceGenerator(
        model,
        dictionary,
        beam_size=args.beam_size,
        maxlen=args.max_len,
        stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores),
        len_penalty=args.len_penalty,
        unk_penalty=args.unk_penalty,
    )

    # Generate captions
    with torch.no_grad():
        hypos = generator.generate(image_features)

    # The pickles below are opened for writing directly, so the output
    # directory has to exist beforehand.
    os.makedirs('images', exist_ok=True)
    for i, (caption_id, image, reference_caption) in enumerate(
            zip(args.caption_ids, images, reference_captions)):
        output_image = os.path.join('images', '{}.png'.format(caption_id))
        best_hypo = hypos[i][0]
        # assumes the attention covers a 14 x 14 spatial grid - TODO confirm
        attention = best_hypo['attention'].view(14, 14, -1).cpu().numpy()
        system_tokens = [
            dictionary.words[tok] for tok in best_hypo['tokens']
            if tok != dictionary.eos_idx
        ]
        # Dump prediction and reference side by side to make them easier to
        # compare. Tokens are joined with spaces, except punctuation and tokens
        # starting with an apostrophe, which attach to the previous token.
        prediction = "".join([
            token if token.startswith("'") or token in string.punctuation
            else " " + token
            for token in system_tokens
        ]).strip()
        dic = {'prediction': prediction, 'reference': reference_caption}
        with open('images/{}.p'.format(caption_id), 'wb') as f:
            pickle.dump(dic, f)
        utils.plot_image_caption(image, output_image, system_tokens,
                                 reference_caption, attention)
Exemplo n.º 6
0
def main(args):
    """Caption a sample of COCO test images and plot them with their attention maps.

    Picks the requested annotation ids (or 50 random ones), extracts VGG19
    features for the matching images, generates captions with the checkpointed
    model and calls ``utils.plot_image_caption`` for every annotation with the
    output path ``images/<id>.jpg``.

    Args:
        args: Parsed command-line namespace. Command-line values override those
            stored in the checkpoint; ``embed_path`` and ``log_file`` are
            forced to ``None``.
    """
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load arguments from checkpoint (no need to load pretrained embeddings or write to log file)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(
        args.checkpoint_path,
        map_location=lambda s, l: default_restore_location(s, 'cpu'))
    args = argparse.Namespace(
        **{
            **vars(state_dict['args']),
            **vars(args), 'embed_path': None,
            'log_file': None
        })
    utils.init_logging(args)

    # Load dictionary
    dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
    logging.info('Loaded a dictionary of {} words'.format(len(dictionary)))

    # Load dataset
    coco = COCO(os.path.join(args.coco_path, args.test_caption))
    if args.caption_ids is None:
        args.caption_ids = random.sample(list(coco.anns.keys()), 50)
    annotations = [coco.anns[caption_id] for caption_id in args.caption_ids]
    image_ids = [annotation['image_id'] for annotation in annotations]
    reference_captions = [annotation['caption'] for annotation in annotations]
    image_names = [
        os.path.join(args.coco_path, args.test_image,
                     coco.loadImgs(image_id)[0]['file_name'])
        for image_id in image_ids
    ]

    # Transform image. The crop and flip are random; the seeds set above keep
    # them reproducible between runs.
    pil_transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
    ])
    images = [pil_transform(Image.open(filename)) for filename in image_names]
    tensor_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    sample = torch.stack(
        [tensor_transform(image.convert('RGB')) for image in images], dim=0)

    # Extract image features (all VGG19 feature layers except the last two)
    vgg = vgg19(pretrained=True).eval().cuda()
    feature_extractor = nn.Sequential(*list(vgg.features.children())[:-2])
    # Pure inference: skip building the autograd graph to save GPU memory.
    with torch.no_grad():
        image_features = feature_extractor(utils.move_to_cuda(sample))
    image_features = image_features.view(*image_features.size()[:-2], -1)
    # B x C x (H x W) -> B x (H x W) x C
    image_features = image_features.transpose(1, 2)

    # Load model and build generator
    model = models.build_model(args, dictionary).cuda()
    model.load_state_dict(state_dict['model'])
    logging.info('Loaded a model from checkpoint {}'.format(
        args.checkpoint_path))
    # NOTE(review): eval() on argument strings executes arbitrary code; these
    # values are presumably the strings 'True'/'False' - consider a safer parser.
    generator = SequenceGenerator(
        model,
        dictionary,
        beam_size=args.beam_size,
        maxlen=args.max_len,
        stop_early=eval(args.stop_early),
        normalize_scores=eval(args.normalize_scores),
        len_penalty=args.len_penalty,
        unk_penalty=args.unk_penalty,
    )

    # Generate captions
    with torch.no_grad():
        hypos = generator.generate(image_features)

    # NOTE(review): nothing here creates the 'images' directory; confirm that
    # utils.plot_image_caption does so or that it exists before running.
    for i, (caption_id, image, reference_caption) in enumerate(
            zip(args.caption_ids, images, reference_captions)):
        output_image = os.path.join('images', '{}.jpg'.format(caption_id))
        best_hypo = hypos[i][0]
        # assumes the attention covers a 14 x 14 spatial grid - TODO confirm
        attention = best_hypo['attention'].view(14, 14, -1).cpu().numpy()
        system_tokens = [
            dictionary.words[tok] for tok in best_hypo['tokens']
            if tok != dictionary.eos_idx
        ]
        utils.plot_image_caption(image, output_image, system_tokens,
                                 reference_caption, attention)