Exemplo n.º 1
0
import argparse
import logging
import numpy as np
from sklearn import linear_model

from attalos.dataset.dataset import Dataset
from attalos.dataset.transformers.onehot import OneHot
from attalos.evaluation.evaluation import Evaluation

import attalos.util.log.log as l

logger = l.getLogger(__name__)


def get_xy(dataset, tag_transformer=None):
    """Gather image features and text features from a dataset as arrays.

    Iterating ``dataset`` yields indices; each index is passed to
    ``dataset.get_index`` which must return an
    ``(image_feats, text_feats)`` pair.

    Args:
        dataset: Iterable of indices that also exposes ``get_index(idx)``.
        tag_transformer: Optional object with a ``get_multiple`` method.
            When truthy, each item's text features are replaced by
            ``tag_transformer.get_multiple(text_feats)``.

    Returns:
        A tuple ``(x, y)`` of ``np.asarray`` results: ``x`` built from the
        image features and ``y`` from the (optionally transformed) text
        features, both in dataset iteration order.
    """
    def _iter_pairs():
        # Fetch and transform one item at a time so the calls into the
        # dataset and the transformer stay interleaved per item.
        for key in dataset:
            img_vec, tags = dataset.get_index(key)
            if tag_transformer:
                tags = tag_transformer.get_multiple(tags)
            yield img_vec, tags

    pairs = list(_iter_pairs())
    features = [img_vec for img_vec, _ in pairs]
    labels = [tags for _, tags in pairs]
    return np.asarray(features), np.asarray(labels)


def train(train_dataset, test_dataset=None, tag_transformer=None, n_jobs=-1):
    """Fit a scikit-learn ``LinearRegression`` on ``train_dataset``.

    Args:
        train_dataset: Dataset handed to ``get_xy`` to produce the inputs
            and regression targets.
        test_dataset: Accepted but not used by the code shown here.
        tag_transformer: Forwarded to ``get_xy`` as ``tag_transformer``.
        n_jobs: Forwarded to ``LinearRegression`` as ``n_jobs``.

    NOTE(review): this listing appears to stop right after the fit call;
    as written the function returns ``None`` -- confirm against the full
    source.
    """
    # Build the (inputs, targets) pair before constructing the estimator.
    training_pair = get_xy(train_dataset, tag_transformer=tag_transformer)
    regressor = linear_model.LinearRegression(n_jobs=n_jobs)
    logger.info("Training model.")
    regressor.fit(*training_pair)
Exemplo n.º 2
0
import numpy as np
import tensorflow as tf

from attalos.imgtxt_algorithms.approaches.base import AttalosModel
from attalos.util.transformers.onehot import OneHot
from attalos.imgtxt_algorithms.correlation.correlation import construct_W
from attalos.imgtxt_algorithms.util.negsamp import NegativeSampler

import attalos.util.log.log as l
logger = l.getLogger(__name__)

class FastZeroTagModel(AttalosModel):
    """
    Create a tensorflow graph that finds the principal direction of the target word embeddings 
    (with negative sampling), using the loss function from "Fast Zero-Shot Image Tagging".
    """
    def __init__(self, wv_model, datasets, **kwargs):
        self.wv_model = wv_model
        self.one_hot = OneHot(datasets, valid_vocab=wv_model.vocab)
        word_counts = NegativeSampler.get_wordcount_from_datasets(datasets, self.one_hot)
        self.negsampler = NegativeSampler(word_counts)
        self.w = construct_W(wv_model, self.one_hot.get_key_ordering()).T
        self.learning_rate = kwargs.get("learning_rate", 0.0001)
        self.optim_words = kwargs.get("optim_words", True)
        self.hidden_units = kwargs.get("hidden_units", "200")
        self.use_batch_norm = kwargs.get("use_batch_norm",False)
        self.opt_type = kwargs.get("opt_type","adam")
        if self.hidden_units=='0':                                                                                                  
            self.hidden_units=[]
        else:
            self.hidden_units = [int(x) for x in self.hidden_units.split(',')]