def __init__(self, input_size, hidden_size, loss_type='mse'):
    """Build two relation conv blocks followed by a two-layer linear head.

    Args:
        input_size: Indexable with at least three entries.
            ``input_size[0]`` sizes the conv blocks, while
            ``input_size[1]`` and ``input_size[2]`` are used as the two
            spatial extents (presumably (C, H, W) -- confirm with caller).
        hidden_size: Output width of ``fc1`` and input width of ``fc2``.
        loss_type: Stored unchanged on ``self.loss_type``.
    """
    super(RelationModule, self).__init__()
    self.loss_type = loss_type

    channels, height, width = input_size[0], input_size[1], input_size[2]

    # With a ResNet backbone the conv map without avg-pooling is 7x7, so the
    # blocks need padding for their pooling to work on such small maps.
    if height < 10 and width < 10:
        padding = 1
    else:
        padding = 0

    self.layer1 = RelationConvBlock(channels * 2, channels, padding=padding)
    self.layer2 = RelationConvBlock(channels, channels, padding=padding)

    def _shrink_once(extent):
        # Size bookkeeping for a single block stage.
        return int((extent - 2 + 2 * padding) / 2)

    # Two stacked blocks -> apply the per-stage shrink twice per axis.
    flat_features = (channels
                     * _shrink_once(_shrink_once(height))
                     * _shrink_once(_shrink_once(width)))

    # Use the project's Linear_fw layer when the class is flagged as maml.
    linear_cls = backbone.Linear_fw if self.maml else nn.Linear
    self.fc1 = linear_cls(flat_features, hidden_size)
    self.fc2 = linear_cls(hidden_size, 1)
def __init__(self, model_func, n_way, n_support, tf_path=None, approx=False,
             dropout_method='none', dropout_rate=0.,
             dropout_schedule='constant'):
    """Set up the MAML classifier head, hyper-parameters and optimizer.

    Args:
        model_func, n_way, n_support, tf_path: Forwarded to the base-class
            initializer (together with ``change_way=False``).
        approx: Stored on ``self.approx``; selects the first-order
            approximation.
        dropout_method, dropout_rate, dropout_schedule: Forwarded
            positionally to ``DropGrad``.
    """
    super(MAML, self).__init__(
        model_func, n_way, n_support, tf_path=tf_path, change_way=False)

    self.loss_fn = nn.CrossEntropyLoss()

    # Linear head over the base features; its bias starts at zero.
    head = backbone.Linear_fw(self.feat_dim, n_way)
    head.bias.data.fill_(0)
    self.classifier = head

    # Fixed training hyper-parameters (presumably tasks per meta-update,
    # inner-loop steps and inner-loop learning rate -- confirm against the
    # training loop).
    self.batch_size, self.task_update_num, self.train_lr = 4, 5, 0.01
    self.approx = approx  # first order approx.

    self.dropout = DropGrad(dropout_method, dropout_rate, dropout_schedule)

    # Created last so self.parameters() sees everything registered above.
    self.optimizer = torch.optim.Adam(self.parameters())
def __init__(self, model_func, n_way, n_support, tf_path=None):
    """Build the embedding layer, the GNN metric and the fixed label tensor.

    Args:
        model_func, n_way, n_support: Forwarded to the base-class
            initializer.
        tf_path: Forwarded to the base-class initializer as ``tf_path``.
    """
    super(GnnNet, self).__init__(model_func, n_way, n_support, tf_path=tf_path)

    # loss function
    self.loss_fn = nn.CrossEntropyLoss()

    # metric function: linear projection to 128 features + batch norm,
    # using the project's *_fw layers when either maml flag is set.
    embed_dim = 128
    if self.maml or self.maml_adain:
        self.fc = nn.Sequential(
            backbone.Linear_fw(self.feat_dim, embed_dim),
            backbone.BatchNorm1d_fw(embed_dim, track_running_stats=False))
    else:
        self.fc = nn.Sequential(
            nn.Linear(self.feat_dim, embed_dim),
            nn.BatchNorm1d(embed_dim, track_running_stats=False))
    self.gnn = GNN_nl(embed_dim + self.n_way, 96, self.n_way)
    self.method = 'GnnNet'

    # Fixed label tensor for training the metric function.
    # Final shape: 1 x n_way*(n_support + 1) x n_way.
    ways, shots = self.n_way, self.n_support

    # Column of class ids, each id repeated n_support times.
    class_ids = torch.from_numpy(np.repeat(range(ways), shots)).unsqueeze(1)

    # One-hot rows, regrouped as (n_way, n_support, n_way).
    one_hot = torch.zeros(ways * shots, ways).scatter(1, class_ids, 1)
    one_hot = one_hot.view(ways, shots, ways)

    # Append one all-zero row per class (presumably the slot for the
    # unlabeled query sample -- confirm against the forward pass).
    zero_row = torch.zeros(ways, 1, ways)
    padded = torch.cat([one_hot, zero_row], dim=1)

    self.support_label = padded.view(1, -1, ways)
def __init__(self, input_size, output_size):
    """Create a single linear classifier layer with a zero-initialized bias.

    Args:
        input_size: Number of input features of the linear layer.
        output_size: Number of output features of the linear layer.
    """
    super(ClassifierModule, self).__init__()

    # Use the project's Linear_fw layer when the class is flagged as maml.
    linear_cls = backbone.Linear_fw if self.maml else nn.Linear
    self.classifier = linear_cls(input_size, output_size)

    # Both variants start from a zero bias.
    self.classifier.bias.data.fill_(0)
def __init__(self, model_func, n_way, n_support):
    """Build the embedding layer and the GNN metric module.

    Args:
        model_func, n_way, n_support: Forwarded to the base-class
            initializer.
    """
    super(GnnNet, self).__init__(model_func, n_way, n_support)

    # loss function
    self.loss_fn = nn.CrossEntropyLoss()

    # metric function: linear projection to 128 features + batch norm,
    # using the project's *_fw layers when the FWT flag is set.
    embed_dim = 128
    if self.FWT:
        projection = backbone.Linear_fw(self.feat_dim, embed_dim)
        norm = backbone.BatchNorm1d_fw(embed_dim, track_running_stats=False)
    else:
        projection = nn.Linear(self.feat_dim, embed_dim)
        norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
    self.fc = nn.Sequential(projection, norm)

    self.gnn = GNN_nl(embed_dim + self.n_way, 96, self.n_way)
    self.method = 'GnnNet'