Beispiel #1
0
    def __init__(self):
        """Populate this parameter object with its configuration values.

        The base-class initializer runs first; any attribute of the same
        name that it sets is overwritten by the assignments below.
        """
        BaseModelParams.__init__(self)

        # Names identifying the dataset and the model.
        self.dataset_name = 'wikipedia_dataset'
        self.model_name = 'adv_semantic_zsl'

        # Feature / embedding sizes. These three values are also baked into
        # the model directory name further down.
        self.visual_feat_dim = 4096
        self.word_vec_dim = 5000  # (alternative setting kept for reference: 300)
        self.semantic_emb_dim = 40

        # Epoch count, batch size and the remaining scalar settings.
        self.epoch = 500
        self.batch_size = 64
        self.margin = .1
        self.alpha = 5
        self.top_k = 50
        self.r_domain = 1.0
        self.r_pair = 1.0

        # Learning rates, one attribute per name suffix
        # (presumably one per optimizer -- confirm against the training code).
        self.lr_total = 0.0001
        self.lr_emb = 0.0001
        self.lr_domain = 0.0001
        self.lr_pair = 0.0001

        # Save-related counters (names suggest a save interval in epochs and
        # a cap on kept checkpoints -- TODO confirm where they are consumed).
        self.n_save_epoch = 10
        self.n_max_save = 10

        # Directory name derived from the three dimensions set above.
        dims = (self.visual_feat_dim, self.word_vec_dim, self.semantic_emb_dim)
        self.model_dir = 'adv_semantic_zsl_pair_%d_%d_%d' % dims

        # Locations for checkpoints, samples, input data and logs.
        self.checkpoint_dir = 'checkpoint'
        self.sample_dir = 'samples'
        self.dataset_dir = './data'
        self.log_dir = 'logs'
Beispiel #2
0
    def __init__(self):
        """Populate this parameter object with its configuration values.

        The base-class initializer runs first; any attribute of the same
        name that it sets is overwritten by the assignments below.
        """
        BaseModelParams.__init__(self)

        # Names identifying the dataset and the model.
        self.dataset_name = 'nuswide'
        self.model_name = 'adv_semantic_zsl'

        # Feature / embedding sizes; they also form part of model_dir below.
        self.visual_feat_dim = 4096
        self.word_vec_dim = 1000  # (alternative setting kept for reference: 200)
        self.semantic_emb_dim = 40

        # Batch size, top-k value and the two learning rates.
        self.batch_size = 64
        self.top_k = 50
        self.lr_emb = 0.0001
        self.lr_domain = 0.0001

        # Directory name derived from the three dimensions set above.
        dims = (self.visual_feat_dim, self.word_vec_dim, self.semantic_emb_dim)
        self.model_dir = 'adv_semantic_zsl_%d_%d_%d' % dims

        # Locations for checkpoints, samples, input data and logs.
        self.checkpoint_dir = 'checkpoint'
        self.sample_dir = 'samples'
        self.dataset_dir = './data'
        self.log_dir = 'logs'
    def __init__(self):
        """Set the individual parameters of the model.

        The base-class initializer runs first; any attribute of the same
        name that it sets is overwritten by the assignments below.
        """
        BaseModelParams.__init__(self)

        # Name of the dataset to use.
        # NOTE(review): 'wikipedia_datasete' ends in an extra 'e' and looks
        # like a misspelling of 'wikipedia_dataset'. It is kept unchanged
        # because the value may be used to build paths -- confirm against
        # the data/checkpoint layout before correcting it.
        self.dataset_name = 'wikipedia_datasete'
        # Model name, used when storing the model.
        self.model_name = 'adv_semantic_zsl'

        # Feature dimensions.
        self.visual_feat_dim = 4096  # visual feature dimension
        self.word_vec_dim = 5000  # text word-feature dimension (alternative kept for reference: 300)
        self.semantic_emb_dim = 40  # semantic embedding dimension

        # Training settings.
        self.epoch = 200  # number of epochs
        self.batch_size = 64  # batch size
        self.top_k = 50  # top-k

        # Two learning rates, because the emb and adv (domain) parts are
        # optimized in opposing directions (adversarial training).
        self.lr_emb = 0.0001
        self.lr_domain = 0.0001

        # Directory name derived from the three dimensions set above.
        dims = (self.visual_feat_dim, self.word_vec_dim, self.semantic_emb_dim)
        self.model_dir = 'adv_semantic_zsl_%d_%d_%d' % dims

        # Paths for the various stored files.
        self.checkpoint_dir = 'checkpoint'
        self.sample_dir = 'samples'
        self.dataset_dir = './data'
        self.log_dir = 'logs'