Beispiel #1
0
def main():
    """Embed the COCO validation split with trained image/caption encoders,
    save the stacked embedding matrix, and plot a 2-D reduction of it.

    Settings are read from ``arguments/<args.arg>.yaml``; a trained
    checkpoint is mandatory (this script only evaluates).
    """
    args = get_arguments()
    # Fix: close the YAML file deterministically instead of leaking the handle.
    with open(os.path.join('arguments', args.arg + '.yaml'), encoding='utf8') as f:
        SETTING = Dict(yaml.safe_load(f))
    print(args)
    args.device = list(map(str, args.device))
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(args.device)

    # image transformer (ImageNet normalization statistics)
    transform = transforms.Compose([
        transforms.Resize((SETTING.imsize, SETTING.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        ])

    if args.dataset == 'coco':
        val_dset = CocoDset(root=SETTING.root_path, img_dir='val2017', ann_dir='annotations/captions_val2017.json', transform=transform)
    val_loader = DataLoader(val_dset, batch_size=SETTING.batch_size, shuffle=False, num_workers=SETTING.n_cpu, collate_fn=collater)

    vocab = Vocabulary(max_len=SETTING.max_len)
    vocab.load_vocab(args.vocab_path)

    imenc = ImageEncoder(SETTING.out_size, SETTING.cnn_type)
    capenc = CaptionEncoder(len(vocab), SETTING.emb_size, SETTING.out_size, SETTING.rnn_type)

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # Evaluation requires trained weights.
    assert args.checkpoint is not None
    print("loading model and optimizer checkpoint from {} ...".format(args.checkpoint), flush=True)
    ckpt = torch.load(args.checkpoint, map_location=device)
    imenc.load_state_dict(ckpt["encoder_state"])
    capenc.load_state_dict(ckpt["decoder_state"])

    begin = time.time()
    dset = EmbedDset(val_loader, imenc, capenc, vocab, args)
    print("database created | {} ".format(sec2str(time.time()-begin)), flush=True)

    savedir = os.path.join("out", args.config_name)
    # Fix: exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(savedir, 0o777, exist_ok=True)

    image = dset.embedded["image"]
    caption = dset.embedded["caption"]
    # n_i (number of images) marks the image/caption boundary in the stack.
    n_i = image.shape[0]
    # Fix: don't shadow the builtin `all`; the unused caption count is dropped.
    embeddings = np.concatenate([image, caption], axis=0)

    emb_file = os.path.join(savedir, "embedding_{}.npy".format(n_i))
    save_file = os.path.join(savedir, "{}.npy".format(SETTING.method))
    vis_file = os.path.join(savedir, "{}.png".format(SETTING.method))
    np.save(emb_file, embeddings)
    print("saved embeddings to {}".format(emb_file), flush=True)
    dimension_reduction(emb_file, save_file, method=SETTING.method)
    plot_embeddings(save_file, n_i, vis_file, method=SETTING.method)
def main():
    """Load trained encoders from a checkpoint and run image->caption and
    caption->image retrieval demos on the COCO validation split.

    Settings come from ``arguments/<args.arg>.yaml``; ``SETTING.checkpoint``
    must point at a saved checkpoint.
    """
    args = get_arguments()
    # Fix: use a context manager so the YAML file handle is closed promptly.
    with open(os.path.join('arguments', args.arg + '.yaml'),
              encoding='utf8') as f:
        SETTING = Dict(yaml.safe_load(f))
    print(args)
    args.device = list(map(str, args.device))
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(args.device)

    # ImageNet-style preprocessing.
    transform = transforms.Compose([
        transforms.Resize((SETTING.imsize, SETTING.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    if args.dataset == 'coco':
        val_dset = CocoDset(root=SETTING.root_path,
                            img_dir='val2017',
                            ann_dir='annotations/captions_val2017.json',
                            transform=transform)
    val_loader = DataLoader(val_dset,
                            batch_size=SETTING.batch_size,
                            shuffle=False,
                            num_workers=SETTING.n_cpu,
                            collate_fn=collater)

    vocab = Vocabulary(max_len=SETTING.max_len)
    vocab.load_vocab(args.vocab_path)

    imenc = ImageEncoder(SETTING.out_size, SETTING.cnn_type)
    capenc = CaptionEncoder(len(vocab), SETTING.emb_size, SETTING.out_size,
                            SETTING.rnn_type)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # Retrieval needs trained weights.
    assert SETTING.checkpoint is not None
    print("loading model and optimizer checkpoint from {} ...".format(
        SETTING.checkpoint),
          flush=True)
    # Fix: map_location lets CUDA-saved checkpoints load on CPU-only hosts
    # (the sibling evaluation script already does this).
    ckpt = torch.load(SETTING.checkpoint, map_location=device)
    imenc.load_state_dict(ckpt["encoder_state"])
    capenc.load_state_dict(ckpt["decoder_state"])

    begin = time.time()
    dset = EmbedDset(val_loader, imenc, capenc, vocab, args)
    print("database created | {} ".format(sec2str(time.time() - begin)),
          flush=True)

    retrieve_i2c(dset, val_dset, args.image_path, imenc, transform)
    retrieve_c2i(dset, val_dset, args.output_dir, args.caption, capenc, vocab)
def main():
    """Train and evaluate an image-captioning encoder/decoder pair.

    Parses CLI arguments, logs the run configuration, splits the dataset
    80/20 into train/test, trains, saves both state dicts (on CPU), then
    evaluates on the held-out split.
    """
    args = parse_args()

    # Echo the hyperparameters for the run log.
    print(f'BATCH_SIZE: {args.batch_size}')
    print(f'EMBEDDING_DIM: {args.embedding_dim}')
    print(f'DEC_HIDDEN_DIM: {args.dec_hidden_dim}')
    print(f'LR: {args.lr}')
    print(f'ENCODER DROPOUT: {args.enc_dropout}')
    print(f'DECODER DROPOUT: {args.dec_dropout}')
    print(f'EPOCHS: {args.epochs}')
    print(f'LOG_INTERVAL: {args.log_interval}')
    print(f'USE PRETRAINED: {args.use_pretrained}')
    print(f'USE CURRICULUM LEARNING: {args.use_curriculum_learning}')

    # Data: one dataset, split into train/test loaders.
    dataset = ImageCaptionDataset(args.image_folder, args.caption_path)
    train_set, test_set = dataset.random_split(train_portion=0.8)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=args.batch_size)
    print(f'Training set size: {len(train_set)}')
    print(f'Test set size: {len(test_set)}')
    print(f'Vocab size: {len(dataset.vocab)}')
    print('----------------------------')

    # Models and one Adam optimizer per model.
    # (`device` is expected to be defined at module level — not visible here.)
    encoder = ImageEncoder(device, pretrained=args.use_pretrained).to(device)
    decoder = CaptionDecoder(
        device,
        len(dataset.vocab),
        embedding_dim=args.embedding_dim,
        enc_hidden_dim=encoder.hidden_dim,
        dec_hidden_dim=args.dec_hidden_dim,
        dropout=args.dec_dropout,
        use_pretrained_emb=args.use_pretrained,
        word_to_int=dataset.word_to_int,
    ).to(device)
    enc_optimizer = torch.optim.Adam(encoder.parameters(), lr=args.lr)
    dec_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.lr)

    # Train, persist CPU state dicts, then move the models back to the device.
    train(encoder, decoder, enc_optimizer, dec_optimizer, train_loader, dataset, args)
    torch.save(encoder.cpu().state_dict(), args.output_encoder)
    torch.save(decoder.cpu().state_dict(), args.output_decoder)
    encoder.to(device)
    decoder.to(device)

    # Final evaluation on the held-out split.
    test(encoder, decoder, test_loader, dataset, args)
Beispiel #4
0
def main():
    """Retrieval evaluation on COCO val: embed the split with trained
    encoders and run image->caption / caption->image retrieval.

    A trained checkpoint (``args.checkpoint``) is required.
    """
    args = parse_args()

    # ImageNet normalization statistics.
    transform = transforms.Compose([
        transforms.Resize((args.imsize, args.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    if args.dataset == 'coco':
        val_dset = CocoDataset(root=args.root_path,
                               imgdir='val2017',
                               jsonfile='annotations/captions_val2017.json',
                               transform=transform,
                               mode='all')
    val_loader = DataLoader(val_dset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.n_cpu,
                            collate_fn=collater_eval)

    vocab = Vocabulary(max_len=args.max_len)
    vocab.load_vocab(args.vocab_path)

    imenc = ImageEncoder(args.out_size, args.cnn_type)
    capenc = CaptionEncoder(len(vocab), args.emb_size, args.out_size,
                            args.rnn_type)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # Evaluation requires trained weights.
    assert args.checkpoint is not None
    print("loading model and optimizer checkpoint from {} ...".format(
        args.checkpoint),
          flush=True)
    # Fix: map_location lets CUDA-saved checkpoints load on CPU-only hosts
    # (the sibling script below already passes it).
    ckpt = torch.load(args.checkpoint, map_location=device)
    imenc.load_state_dict(ckpt["encoder_state"])
    capenc.load_state_dict(ckpt["decoder_state"])

    begin = time.time()
    dset = EmbedDataset(val_loader, imenc, capenc, vocab, args)
    print("database created | {} ".format(sec2str(time.time() - begin)),
          flush=True)

    retrieve_i2c(dset, val_dset, imenc, vocab, args)
    retrieve_c2i(dset, val_dset, capenc, vocab, args)
Beispiel #5
0
class Similarity(nn.Module):
    """Scores how well a product (text + image) matches a dialog context.

    Product text and image are encoded, concatenated, projected into the
    context-vector space, and compared to the context with cosine similarity.
    """

    def __init__(self, config: SimilarityConfig):
        super(Similarity, self).__init__()
        dev = GlobalConfig.device
        # Encoders for the product's text and image, plus the projection
        # from the fused multimodal feature to the context-vector space.
        self.text_encoder = TextEncoder(config.product_text_encoder_config).to(dev)
        self.image_encoder = ImageEncoder(config.product_image_encoder_config).to(dev)
        self.linear = nn.Linear(config.mm_size, config.context_vector_size).to(dev)

    def forward(self, context, text, text_length, image):
        """Forward.
        Args:
            context: Context (batch_size, ContextEncoderConfig.output_size).
            text: Product text (batch_size, product_text_max_len).
            text_length: Product text length (batch_size, ).
            image: Product image (batch_size, 3, image_size, image_size).
        Returns:
            Cosine similarity between the context and the projected
            multimodal product feature.
        """
        dev = GlobalConfig.device

        # One SOS token id per batch element, as a (batch_size, 1) column.
        batch_size = context.size(0)
        sos = SOS_ID * torch.ones(batch_size, dtype=torch.long).view(-1, 1).to(dev)

        # Prepend SOS and account for it in the lengths.
        # NOTE: += mutates the caller's length tensor in place.
        text = torch.cat((sos, text), 1).to(dev)
        text_length += 1

        # Encode text, then image conditioned on the encoded text.
        encoded_text, _ = self.text_encoder(text, text_length)
        encoded_image = self.image_encoder(image, encoded_text)

        # Fuse, project, and compare to the context.
        fused = torch.cat((encoded_text, encoded_image), 1).to(dev)
        fused = self.linear(fused)
        return cosine_similarity(context, fused)
Beispiel #6
0
def train(task: int, model_file_name: str):
    """Train the model for one task.

    Loads the preprocessed common data and the train/valid/test dialog
    splits, wraps them in task-specific ``Dataset`` objects, builds the
    shared context encoders (optionally resuming from ``model_file_name``),
    and dispatches to the task-specific training routine.

    Args:
        task (int): Task.
        model_file_name (str): Model file name (saved or to be saved).

    Raises:
        ValueError: If any required preprocessed data file is missing.
    """

    # Check if data exists.
    if not isfile(DatasetConfig.common_raw_data_file):
        raise ValueError('No common raw data.')

    # Load extracted common data.
    common_data: CommonData = load_pkl(DatasetConfig.common_raw_data_file)

    # Dialog data files.
    train_dialog_data_file = DatasetConfig.get_dialog_filename(
        task, TRAIN_MODE)
    valid_dialog_data_file = DatasetConfig.get_dialog_filename(
        task, VALID_MODE)
    test_dialog_data_file = DatasetConfig.get_dialog_filename(task, TEST_MODE)
    if not isfile(train_dialog_data_file):
        raise ValueError('No train dialog data file.')
    if not isfile(valid_dialog_data_file):
        raise ValueError('No valid dialog data file.')
    # Fix: the test split was loaded without the existence check that the
    # other splits get; fail with the same clear error, not a pickle error.
    if not isfile(test_dialog_data_file):
        raise ValueError('No test dialog data file.')

    # Load extracted dialogs.
    train_dialogs: List[TidyDialog] = load_pkl(train_dialog_data_file)
    valid_dialogs: List[TidyDialog] = load_pkl(valid_dialog_data_file)
    test_dialogs: List[TidyDialog] = load_pkl(test_dialog_data_file)

    # Knowledge data is only needed (and only built) for the knowledge task.
    knowledge_data = KnowledgeData() if task == KNOWLEDGE_TASK else None

    def _wrap(dialogs):
        # Wrap one dialog split in the task-specific Dataset.
        return Dataset(
            task,
            common_data.dialog_vocab,
            None,  #common_data.obj_id,
            dialogs,
            knowledge_data)

    # Dataset wrap.
    train_dataset = _wrap(train_dialogs)
    valid_dataset = _wrap(valid_dialogs)
    test_dataset = _wrap(test_dialogs)

    print('Train dataset size:', len(train_dataset))
    print('Valid dataset size:', len(valid_dataset))
    print('Test dataset size:', len(test_dataset))

    # Get initial embedding (GloVe-initialized rows for the dialog vocab).
    vocab_size = len(common_data.dialog_vocab)
    embed_init = get_embed_init(common_data.glove,
                                vocab_size).to(GlobalConfig.device)

    # Context model configurations.
    context_text_encoder_config = ContextTextEncoderConfig(
        vocab_size, embed_init)
    context_image_encoder_config = ContextImageEncoderConfig()
    context_encoder_config = ContextEncoderConfig()

    # Context models shared by all tasks.
    context_text_encoder = TextEncoder(context_text_encoder_config)
    context_text_encoder = context_text_encoder.to(GlobalConfig.device)
    context_image_encoder = ImageEncoder(context_image_encoder_config)
    context_image_encoder = context_image_encoder.to(GlobalConfig.device)
    context_encoder = ContextEncoder(context_encoder_config)
    context_encoder = context_encoder.to(GlobalConfig.device)

    # Resume shared-encoder weights if a saved model file exists.
    model_file = join(DatasetConfig.dump_dir, model_file_name)
    if isfile(model_file):
        state = torch.load(model_file)
        # if task != state['task']:
        #     raise ValueError("Task doesn't match.")
        context_text_encoder.load_state_dict(state['context_text_encoder'])
        context_image_encoder.load_state_dict(state['context_image_encoder'])
        context_encoder.load_state_dict(state['context_encoder'])

    # Task-specific parts.
    if task == INTENTION_TASK:
        intention_train(context_text_encoder, context_image_encoder,
                        context_encoder, train_dataset, valid_dataset,
                        test_dataset, model_file)
    elif task == TEXT_TASK:
        text_train(context_text_encoder, context_image_encoder,
                   context_encoder, train_dataset, valid_dataset, test_dataset,
                   model_file, common_data.dialog_vocab, embed_init)
    elif task == RECOMMEND_TASK:
        recommend_train(context_text_encoder, context_image_encoder,
                        context_encoder, train_dataset, valid_dataset,
                        test_dataset, model_file, vocab_size, embed_init)
    elif task == KNOWLEDGE_TASK:
        knowledge_attribute_train(context_text_encoder, context_image_encoder,
                                  context_encoder, train_dataset,
                                  valid_dataset, test_dataset, model_file,
                                  knowledge_data.attribute_data,
                                  common_data.dialog_vocab, embed_init)
def main():
    """End-to-end training loop for the image/caption retrieval model.

    Reads hyperparameters from ``arguments/<args.arg>.yaml``, builds COCO
    train/val loaders, trains ImageEncoder + CaptionEncoder with a pairwise
    ranking loss, checkpoints whenever the summed recall score improves,
    and early-stops after ``SETTING.es_cnt`` epochs without improvement.
    """

    # ignore warnings
    #warnings.simplefilter('ignore')

    args = get_arguments()
    # NOTE(review): the file handle from open() is never closed — a `with`
    # block would be safer.
    SETTING = Dict(yaml.safe_load(open(os.path.join('arguments',args.arg+'.yaml'), encoding='utf8')))
    print(args)
    args.device = list (map(str,args.device))
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(args.device)

    # image transformers: random crop/flip augmentation for training,
    # deterministic center crop for validation (ImageNet normalization).
    train_transform = transforms.Compose([
        transforms.Resize(SETTING.imsize_pre),
        transforms.RandomCrop(SETTING.imsize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        ])
    val_transform = transforms.Compose([
        transforms.Resize(SETTING.imsize_pre),
        transforms.CenterCrop(SETTING.imsize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        ])

    # data load
    if args.dataset == 'coco':
        train_dset = CocoDset(root=SETTING.root_path,img_dir='train2017', ann_dir='annotations/captions_train2017.json', transform=train_transform)
        val_dset = CocoDset(root=SETTING.root_path, img_dir='val2017', ann_dir='annotations/captions_val2017.json', transform=val_transform)
    train_loader = DataLoader(train_dset, batch_size=SETTING.batch_size, shuffle=True, num_workers=SETTING.n_cpu, collate_fn=collater)
    val_loader = DataLoader(val_dset, batch_size=SETTING.batch_size, shuffle=False, num_workers=SETTING.n_cpu, collate_fn=collater)

    # setup vocab dict
    vocab = Vocabulary(max_len=SETTING.max_len)
    vocab.load_vocab(args.vocab_path)

    # setup encoders (image CNN + caption RNN)
    imenc = ImageEncoder(SETTING.out_size, SETTING.cnn_type)
    capenc = CaptionEncoder(len(vocab), SETTING.emb_size, SETTING.out_size, SETTING.rnn_type, vocab.padidx)
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # per-module learning rates; only imenc.fc trains at first — the CNN
    # backbone is added to the optimizer later at SETTING.freeze_epoch.
    cfgs = [{'params' : imenc.fc.parameters(), 'lr' : float(SETTING.lr_cnn)},
            {'params' : capenc.parameters(), 'lr' : float(SETTING.lr_rnn)}]

    # optimizer
    # NOTE(review): an unrecognized SETTING.optimizer / SETTING.scheduler
    # string leaves `optimizer` / `scheduler` unbound (NameError later).
    if SETTING.optimizer == 'SGD':
        optimizer = optim.SGD(cfgs, momentum=SETTING.momentum, weight_decay=SETTING.weight_decay)
    elif SETTING.optimizer == 'Adam':
        optimizer = optim.Adam(cfgs, betas=(SETTING.beta1, SETTING.beta2), weight_decay=SETTING.weight_decay)
    elif SETTING.optimizer == 'RMSprop':
        optimizer = optim.RMSprop(cfgs, alpha=SETTING.alpha, weight_decay=SETTING.weight_decay)
    if SETTING.scheduler == 'Plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=SETTING.dampen_factor, patience=SETTING.patience, verbose=True)
    elif SETTING.scheduler == 'Step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=SETTING.patience, gamma=SETTING.dampen_factor)

    # loss
    lossfunc = PairwiseRankingLoss(margin=SETTING.margin, method=SETTING.method, improved=args.improved, intra=SETTING.intra, lamb=SETTING.imp_weight)


    # if start from checkpoint: restore model/optimizer/scheduler state and
    # recompute the best score from the saved recall stats.
    if args.checkpoint is not None:
        print("loading model and optimizer checkpoint from {} ...".format(args.checkpoint), flush=True)
        # NOTE(review): no map_location — assumes the checkpoint's saved
        # device is available; confirm against the sibling eval scripts.
        ckpt = torch.load(args.checkpoint)
        imenc.load_state_dict(ckpt["encoder_state"])
        capenc.load_state_dict(ckpt["decoder_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        if SETTING.scheduler != 'None':
            scheduler.load_state_dict(ckpt["scheduler_state"])
        offset = ckpt["epoch"]
        data = ckpt["stats"]
        # best score = sum of i2c and c2i recall@{1,5,10,20} at save time
        bestscore = 0
        for rank in [1, 5, 10, 20]:
            bestscore += data["i2c_recall@{}".format(rank)] + data["c2i_recall@{}".format(rank)]
        bestscore = int(bestscore)
    # start new training
    else:
        offset = 0
        bestscore = -1

    if args.dataparallel:
        print("Using Multiple GPU . . . ")
        imenc = nn.DataParallel(imenc)
        capenc = nn.DataParallel(capenc)

    # per-metric history for visualization; es_cnt counts non-improving epochs
    metrics = {}
    es_cnt = 0

    # training
    assert offset < SETTING.max_epochs
    for ep in range(offset, SETTING.max_epochs):

        epoch = ep+1

        # unfreeze cnn parameters: from this epoch on, the CNN backbone is
        # also optimized (DataParallel wraps the model in .module)
        if epoch == SETTING.freeze_epoch:
            if args.dataparallel:
                optimizer.add_param_group({'params': imenc.module.cnn.parameters(), 'lr': float(SETTING.lr_cnn)})
            else:
                optimizer.add_param_group({'params': imenc.cnn.parameters(), 'lr': float(SETTING.lr_cnn)})


        #train(1epoch)
        train(epoch, train_loader, imenc, capenc, optimizer, lossfunc, vocab, args, SETTING)

        #validate: sum of recall@{1,5,10,20} in both directions is the score
        data = validate(epoch, val_loader, imenc, capenc, vocab, args, SETTING)
        totalscore = 0
        for rank in [1, 5, 10, 20]:
            totalscore += data["i2c_recall@{}".format(rank)] + data["c2i_recall@{}".format(rank)]
        totalscore = int(totalscore)

        #scheduler update (Plateau needs the score; Step is time-based)
        if SETTING.scheduler == 'Plateau':
            scheduler.step(totalscore)
        if SETTING.scheduler == 'Step':
            scheduler.step()

        # update checkpoint (unwrap .module when using DataParallel so the
        # saved state dict is loadable without DataParallel)
        if args.dataparallel:
            ckpt = {
                    "stats": data,
                    "epoch": epoch,
                    "encoder_state": imenc.module.state_dict(),
                    "decoder_state": capenc.module.state_dict(),
                    "optimizer_state": optimizer.state_dict()
                    }
        else:
            ckpt = {
                    "stats": data,
                    "epoch": epoch,
                    "encoder_state": imenc.state_dict(),
                    "decoder_state": capenc.state_dict(),
                    "optimizer_state": optimizer.state_dict()
                    }


        if SETTING.scheduler != 'None':
            ckpt['scheduler_state'] = scheduler.state_dict()

        # make savedir
        savedir = os.path.join("models", args.arg)
        if not os.path.exists(savedir):
            os.makedirs(savedir)

        # append this epoch's metrics to the per-metric history
        for k, v in data.items():
            if k not in metrics.keys():
                metrics[k] = [v]
            else:
                metrics[k].append(v)

        # save checkpoint only on improvement; otherwise count toward early stop
        savepath = os.path.join(savedir, "epoch_{:04d}_score_{:03d}.ckpt".format(epoch, totalscore))
        if int(totalscore) > int(bestscore):
            print("score: {:03d}, saving model and optimizer checkpoint to {} ...".format(totalscore, savepath), flush=True)
            bestscore = totalscore
            torch.save(ckpt, savepath)
            es_cnt = 0
        else:
            print("score: {:03d}, no improvement from best score of {:03d}, not saving".format(totalscore, bestscore), flush=True)
            es_cnt += 1
            # early stopping
            if es_cnt == SETTING.es_cnt:
                print("early stopping at epoch {} because of no improvement for {} epochs".format(epoch, SETTING.es_cnt))
                break


        print("done for epoch {:04d}".format(epoch), flush=True)

    visualize(metrics, args, SETTING)
    print("complete training")
Beispiel #8
0
    """ 9) Load Weights """
    LOAD_WEIGHTS = True
    EMBEDDING_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-embedding-epoch-3.pt'
    ENCODER_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-encoder-epoch-3.pt'
    DECODER_WEIGHT_FILE = 'checkpoints/BIGDATASET-weights-decoder-epoch-3.pt'
    if LOAD_WEIGHTS:
        print("Loading pretrained weights...")
        word_embedding.load_state_dict(torch.load(EMBEDDING_WEIGHT_FILE))
        image_encoder.load_state_dict(torch.load(ENCODER_WEIGHT_FILE))
        image_decoder.load_state_dict(torch.load(DECODER_WEIGHT_FILE))
    """ 10) Device Setup"""
    device = 'cuda:1'
    device = torch.device(device)

    word_embedding = word_embedding.to(device)
    image_encoder = image_encoder.to(device)
    image_decoder = image_decoder.to(device)

    print(vocab.word_to_index('yooooo'))

    for i, batch in enumerate(val_loader):
        image_batch, word_ids_batch = batch[0].to(device), batch[1].to(device)

        for image in image_batch:
            sentence = generate_caption(image, image_encoder, image_decoder,
                                        word_embedding, vocab, device)
            image = denormalize(image.cpu())
            plt.imshow(image)
            plt.title(sentence)
            plt.show()
            plt.pause(1)
Beispiel #9
0
def main():
    """Embed the COCO validation split with trained encoders, save the
    stacked embedding matrix, and plot a 2-D reduction of it.

    A trained checkpoint (``args.checkpoint``) is required.
    """
    args = parse_args()

    # ImageNet normalization statistics.
    transform = transforms.Compose([
        transforms.Resize((args.imsize, args.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    if args.dataset == "coco":
        val_dset = CocoDataset(
            root=args.root_path,
            split="val",
            transform=transform,
        )
    val_loader = DataLoader(
        val_dset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.n_cpu,
        collate_fn=collater,
    )

    vocab = Vocabulary(max_len=args.max_len)
    vocab.load_vocab(args.vocab_path)

    imenc = ImageEncoder(args.out_size, args.cnn_type)
    capenc = CaptionEncoder(len(vocab), args.emb_size, args.out_size,
                            args.rnn_type)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # Evaluation requires trained weights.
    assert args.checkpoint is not None
    print("loading model and optimizer checkpoint from {} ...".format(
        args.checkpoint),
          flush=True)
    ckpt = torch.load(args.checkpoint, map_location=device)
    imenc.load_state_dict(ckpt["encoder_state"])
    capenc.load_state_dict(ckpt["decoder_state"])

    begin = time.time()
    dset = EmbedDataset(val_loader, imenc, capenc, vocab, args)
    print("database created | {} ".format(sec2str(time.time() - begin)),
          flush=True)

    savedir = os.path.join("out", args.config_name)
    # Fix: exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(savedir, 0o777, exist_ok=True)

    image = dset.embedded["image"]
    caption = dset.embedded["caption"]
    # n_i (number of images) marks the image/caption boundary in the stack.
    n_i = image.shape[0]
    # Fix: don't shadow the builtin `all`; the unused caption count is dropped.
    embeddings = np.concatenate([image, caption], axis=0)

    emb_file = os.path.join(savedir, "embedding_{}.npy".format(n_i))
    save_file = os.path.join(savedir, "{}.npy".format(args.method))
    vis_file = os.path.join(savedir, "{}.png".format(args.method))
    np.save(emb_file, embeddings)
    print("saved embeddings to {}".format(emb_file), flush=True)
    dimension_reduction(emb_file, save_file, method=args.method)
    plot_embeddings(save_file, n_i, vis_file, method=args.method)
Beispiel #10
0
def main():
    """Train the image/caption retrieval encoders on COCO.

    Builds train/val loaders, optimizes ImageEncoder + CaptionEncoder with
    SGD under a pairwise ranking loss, steps an LR scheduler on the summed
    recall score, and checkpoints after every epoch. Optionally resumes
    from ``args.checkpoint``.
    """
    args = parse_args()

    # ImageNet normalization statistics (shared by train and val).
    transform = transforms.Compose([
        transforms.Resize((args.imsize, args.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    if args.dataset == 'coco':
        train_dset = CocoDataset(root=args.root_path,
                                 transform=transform,
                                 mode='one')
        val_dset = CocoDataset(root=args.root_path,
                               imgdir='val2017',
                               jsonfile='annotations/captions_val2017.json',
                               transform=transform,
                               mode='all')
    train_loader = DataLoader(train_dset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.n_cpu,
                              collate_fn=collater_train)
    val_loader = DataLoader(val_dset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.n_cpu,
                            collate_fn=collater_eval)

    vocab = Vocabulary(max_len=args.max_len)
    vocab.load_vocab(args.vocab_path)

    imenc = ImageEncoder(args.out_size, args.cnn_type)
    capenc = CaptionEncoder(len(vocab), args.emb_size, args.out_size,
                            args.rnn_type)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # Separate LR/momentum for the CNN and RNN parameter groups.
    optimizer = optim.SGD([{
        'params': imenc.parameters(),
        'lr': args.lr_cnn,
        'momentum': args.mom_cnn
    }, {
        'params': capenc.parameters(),
        'lr': args.lr_rnn,
        'momentum': args.mom_rnn
    }])
    # LR drops when the recall score plateaus (mode='max': higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     factor=0.1,
                                                     patience=args.patience,
                                                     verbose=True)
    lossfunc = PairwiseRankingLoss(margin=args.margin,
                                   method=args.method,
                                   improved=args.improved,
                                   intra=args.intra)

    # Resume model/optimizer/scheduler state if a checkpoint is given.
    if args.checkpoint is not None:
        print("loading model and optimizer checkpoint from {} ...".format(
            args.checkpoint),
              flush=True)
        # Fix: map_location lets CUDA-saved checkpoints load on CPU-only
        # hosts (matches the sibling evaluation script).
        ckpt = torch.load(args.checkpoint, map_location=device)
        imenc.load_state_dict(ckpt["encoder_state"])
        capenc.load_state_dict(ckpt["decoder_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        scheduler.load_state_dict(ckpt["scheduler_state"])
        offset = ckpt["epoch"]
    else:
        offset = 0
    imenc = nn.DataParallel(imenc)
    capenc = nn.DataParallel(capenc)

    # Per-metric history, for visualize() at the end.
    metrics = {}

    assert offset < args.max_epochs
    for ep in range(offset, args.max_epochs):
        imenc, capenc, optimizer = train(ep + 1, train_loader, imenc, capenc,
                                         optimizer, lossfunc, vocab, args)
        # Score = sum of i2c and c2i recall@{1,5,10,20} on the val split.
        data = validate(ep + 1, val_loader, imenc, capenc, vocab, args)
        totalscore = 0
        for rank in [1, 5, 10, 20]:
            totalscore += data["i2c_recall@{}".format(rank)] + data[
                "c2i_recall@{}".format(rank)]
        scheduler.step(totalscore)

        # save checkpoint (unwrap .module so the state dicts are loadable
        # without DataParallel)
        ckpt = {
            "stats": data,
            "epoch": ep + 1,
            "encoder_state": imenc.module.state_dict(),
            "decoder_state": capenc.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict()
        }
        # Fix: exist_ok avoids the check-then-create race.
        os.makedirs(args.model_save_path, exist_ok=True)
        savepath = os.path.join(
            args.model_save_path,
            "epoch_{:04d}_score_{:05d}.ckpt".format(ep + 1,
                                                    int(100 * totalscore)))
        print(
            "saving model and optimizer checkpoint to {} ...".format(savepath),
            flush=True)
        torch.save(ckpt, savepath)
        print("done for epoch {}".format(ep + 1), flush=True)

        # Append this epoch's metrics to the history.
        for k, v in data.items():
            if k not in metrics.keys():
                metrics[k] = [v]
            else:
                metrics[k].append(v)

    visualize(metrics, args)