def __init__(self, args, emb_size):
    """Build the metric (GNN) sub-network selected by ``args.metric_network``.

    Args:
        args: Configuration object. This method reads ``metric_network``,
            ``train_N_way``, ``test_N_way`` and ``dataset`` from it, and
            forwards it unchanged to the ``gnn_iclr`` constructors.
        emb_size: Embedding size; stored on ``self.emb_size`` and used to
            size the GNN input.

    Raises:
        NotImplementedError: If ``args.metric_network`` is neither
            ``'gnn_iclr_nl'`` nor ``'gnn_iclr_active'``, or if
            ``'gnn_iclr_nl'`` is requested for a dataset that is neither
            ``'mini_imagenet'`` nor one whose name contains ``'omniglot'``.
        ValueError: If ``args.train_N_way != args.test_N_way`` for a
            supported metric network.
    """
    # NOTE(review): MetricNN.__init__ appears to be defined more than once in
    # this file; in a class body the last definition wins, so confirm this
    # version is the one actually in use.
    super(MetricNN, self).__init__()
    self.metric_network = args.metric_network
    self.emb_size = emb_size
    self.args = args

    if self.metric_network not in ('gnn_iclr_nl', 'gnn_iclr_active'):
        raise NotImplementedError(
            'Unsupported metric_network: {!r}'.format(self.metric_network))

    # The GNN input width is derived from train_N_way, so test episodes must
    # use the same number of classes. Raised explicitly (not via ``assert``)
    # so the check is not stripped under ``python -O``.
    if self.args.train_N_way != self.args.test_N_way:
        raise ValueError(
            'train_N_way ({}) must equal test_N_way ({})'.format(
                self.args.train_N_way, self.args.test_N_way))
    # NOTE(review): presumably the embedding plus an N_way-sized label
    # encoding per node -- confirm against the gnn_iclr modules.
    num_inputs = self.emb_size + self.args.train_N_way

    if self.metric_network == 'gnn_iclr_nl':
        if self.args.dataset == 'mini_imagenet':
            # Only mini_imagenet additionally builds the patch GNN, which is
            # sized from the raw embedding rather than num_inputs.
            self.gnn_obj_patch = gnn_iclr.GNN_nl_patch(args, self.emb_size, nf=96, J=1)
            self.gnn_obj = gnn_iclr.GNN_nl(args, num_inputs, nf=96, J=1)
        elif 'omniglot' in self.args.dataset:
            self.gnn_obj = gnn_iclr.GNN_nl_omniglot(args, num_inputs, nf=96, J=1)
        else:
            # Previously fell through silently, leaving self.gnn_obj unset.
            raise NotImplementedError(
                'metric_network {!r} does not support dataset {!r}'.format(
                    self.metric_network, self.args.dataset))
    else:
        # metric_network == 'gnn_iclr_active' (validated above).
        self.gnn_obj = gnn_iclr.GNN_active(args, num_inputs, 96, J=1)
def __init__(self, args, emb_size):
    """Build the metric (GNN) sub-network selected by ``args.metric_network``.

    Args:
        args: Configuration object. This method reads ``metric_network``,
            ``train_N_way``, ``test_N_way`` and ``dataset`` from it, and
            forwards it unchanged to the ``gnn_iclr`` constructors.
        emb_size: Embedding size; stored on ``self.emb_size`` and used to
            size the GNN input.

    Raises:
        ValueError: If ``args.train_N_way != args.test_N_way`` (checked
            before the network type is inspected).
        NotImplementedError: If ``args.metric_network`` is neither
            ``'gnn_iclr_nl'`` nor ``'gnn_iclr_active'``, or if
            ``'gnn_iclr_nl'`` is requested for a dataset that is neither
            ``'mini_imagenet'`` nor one whose name contains ``'omniglot'``.
    """
    super(MetricNN, self).__init__()
    self.metric_network = args.metric_network
    self.emb_size = emb_size
    self.args = args

    # The GNN input width below is derived from train_N_way, so test episodes
    # must use the same number of classes. Raised explicitly (not via
    # ``assert``) so the check is not stripped under ``python -O``.
    if self.args.train_N_way != self.args.test_N_way:
        raise ValueError(
            'train_N_way ({}) must equal test_N_way ({})'.format(
                self.args.train_N_way, self.args.test_N_way))
    # Per the original author's note, this is the length of x in the paper.
    # NOTE(review): presumably the embedding plus an N_way-sized label
    # encoding per node -- confirm against the gnn_iclr modules.
    num_inputs = self.emb_size + self.args.train_N_way

    # NOTE(review): the meaning of nf=96 is not evident from this block (the
    # original author also asked); check the gnn_iclr constructors.
    if self.metric_network == 'gnn_iclr_nl':
        if self.args.dataset == 'mini_imagenet':
            # NOTE(review): the original author remarked that J seems to have
            # no effect ("it is 2 everywhere inside, why set 1 here?") --
            # confirm in gnn_iclr.GNN_nl.
            self.gnn_obj = gnn_iclr.GNN_nl(args, num_inputs, nf=96, J=1)
        elif 'omniglot' in self.args.dataset:
            self.gnn_obj = gnn_iclr.GNN_nl_omniglot(args, num_inputs, nf=96, J=1)
        else:
            # Previously fell through silently, leaving self.gnn_obj unset.
            raise NotImplementedError(
                'metric_network {!r} does not support dataset {!r}'.format(
                    self.metric_network, self.args.dataset))
    elif self.metric_network == 'gnn_iclr_active':
        self.gnn_obj = gnn_iclr.GNN_active(args, num_inputs, nf=96, J=1)
    else:
        raise NotImplementedError(
            'Unsupported metric_network: {!r}'.format(self.metric_network))