Esempio n. 1
0
    def __init__(self, params, tf_path=None, change_way=True):
        """Wrap a metric-based few-shot model for learning-to-learn training.

        Turns on the ``maml`` class-level switches (and the ``feature_augment``
        switch of ``backbone.BatchNorm2d_fa``), instantiates the few-shot
        model selected by ``params.method`` and creates two Adam optimizers
        over the two parameter groups returned by
        ``self.split_model_parameters()``.

        Args:
            params: experiment options. Read here: ``method``, ``model``,
                ``train_n_way``, ``n_shot``, ``feat_aug`` (protonet only),
                ``tf_dir`` and ``stop_epoch``.
            tf_path: log directory for a ``SummaryWriter``. When ``None`` no
                writer is created and ``self.tf_writer`` is ``None``.
            change_way: accepted but not used by this constructor.

        Raises:
            ValueError: if ``params.method`` is not one of ``'protonet'``,
                ``'matchingnet'``, ``'relationnet'``,
                ``'relationnet_softmax'`` or ``'gnnnet'``.
        """
        super(LFTNet, self).__init__()

        # The summary writer is optional: it is only built when a log
        # directory is supplied.
        if tf_path is None:
            self.tf_writer = None
        else:
            self.tf_writer = SummaryWriter(log_dir=tf_path)

        # Episode configuration shared by every few-shot model below.
        few_shot_kwargs = {'n_way': params.train_n_way,
                           'n_support': params.n_shot}
        method = params.method

        # Class-level switches on the shared backbone blocks. They are flipped
        # before any model is instantiated (presumably read at construction
        # time -- confirm in backbone).
        backbone.BatchNorm2d_fa.feature_augment = True
        for block_cls in (backbone.ConvBlock, backbone.SimpleBlock,
                          backbone.BottleneckBlock, backbone.ResNet):
            block_cls.maml = True

        if method == 'protonet':
            model = protonet.ProtoNet(model_dict[params.model],
                                      feat_aug=params.feat_aug,
                                      tf_path=params.tf_dir,
                                      **few_shot_kwargs)
        elif method == 'matchingnet':
            backbone.LSTMCell.maml = True
            model = matchingnet.MatchingNet(model_dict[params.model],
                                            tf_path=params.tf_dir,
                                            **few_shot_kwargs)
        elif method in ('relationnet', 'relationnet_softmax'):
            for block_cls in (relationnet.RelationConvBlock,
                              relationnet.RelationModule):
                block_cls.maml = True
            # Conv4 / Conv6 are replaced by their ``*NP`` backbone variants;
            # every other architecture is taken from ``model_dict``.
            if params.model in ('Conv4', 'Conv6'):
                feature_model = getattr(backbone, params.model + 'NP')
            else:
                feature_model = model_dict[params.model]
            model = relationnet.RelationNet(
                feature_model,
                loss_type='mse' if method == 'relationnet' else 'softmax',
                tf_path=params.tf_dir,
                **few_shot_kwargs)
        elif method == 'gnnnet':
            for block_cls in (gnnnet.GnnNet, gnn.Gconv, gnn.Wcompute):
                block_cls.maml = True
            model = gnnnet.GnnNet(model_dict[params.model],
                                  tf_path=params.tf_dir,
                                  **few_shot_kwargs)
        else:
            raise ValueError('Unknown method')
        self.model = model
        print('  train with {} framework'.format(method))

        # One Adam optimizer per parameter group. The second group uses a
        # fixed learning rate and a small weight decay.
        model_params, ft_params = self.split_model_parameters()
        self.model_optim = torch.optim.Adam(model_params)
        self.ft_optim = torch.optim.Adam(ft_params, lr=1e-3, weight_decay=1e-8)

        # Number of training epochs.
        self.total_epoch = params.stop_epoch
Esempio n. 2
0
  def __init__(self, params, tf_path=None, change_way=True):
    """Wrap a metric-based few-shot model for learning-to-learn training.

    Besides the few-shot model selected by ``params.method``, this builds an
    auxiliary MLP classifier on the model's feature dimension, a
    cross-entropy loss for it, and two Adam optimizers.

    Args:
      params: experiment options. Read here: ``method``, ``model``,
        ``train_n_way``, ``n_shot``, ``tf_dir`` and ``stop_epoch``.
      tf_path: log directory for a ``SummaryWriter``; when ``None`` no writer
        is created and ``self.tf_writer`` is ``None``.
      change_way: accepted but not used by this constructor.

    Raises:
      ValueError: if ``params.method`` is not one of ``'protonet'``,
        ``'matchingnet'``, ``'relationnet'``, ``'relationnet_softmax'`` or
        ``'gnnnet'``.
    """
    super(LFTNet, self).__init__()

    # tf writer (optional: only built when a log directory is given)
    self.tf_writer = SummaryWriter(log_dir=tf_path) if tf_path is not None else None

    # get metric-based model and enable L2L(maml) training
    train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
    # Class-level switches, flipped before the model is instantiated
    # (presumably read at construction time -- confirm in backbone).
    backbone.FeatureWiseTransformation2d_fw.feature_augment = True
    backbone.ConvBlock.maml = True
    backbone.SimpleBlock.maml = True
    backbone.ResNet.maml = True
    if params.method == 'protonet':
      model = protonet.ProtoNet(model_dict[params.model], tf_path=params.tf_dir, **train_few_shot_params)
    elif params.method == 'matchingnet':
      backbone.LSTMCell.maml = True
      model = matchingnet.MatchingNet(model_dict[params.model], tf_path=params.tf_dir, **train_few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
      relationnet.RelationConvBlock.maml = True
      relationnet.RelationModule.maml = True
      # Conv4 / Conv6 use their ``*NP`` backbone variants here.
      if params.model == 'Conv4':
        feature_model = backbone.Conv4NP
      elif params.model == 'Conv6':
        feature_model = backbone.Conv6NP
      else:
        feature_model = model_dict[params.model]
      loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
      model = relationnet.RelationNet(feature_model, loss_type=loss_type, tf_path=params.tf_dir, **train_few_shot_params)
    elif params.method == 'gnnnet':
      gnnnet.GnnNet.maml = True
      gnn.Gconv.maml = True
      gnn.Wcompute.maml = True
      model = gnnnet.GnnNet(model_dict[params.model], tf_path=params.tf_dir, **train_few_shot_params)
    else:
      raise ValueError('Unknown method')
    self.model = model
    print('  train with {} framework'.format(params.method))

    # for auxiliary training
    # ``feat_dim`` may be a sequence; only its first entry sizes the MLP.
    # isinstance accepts tuples and list subclasses too. A bare
    # ``type(...) is list`` check let those reach nn.Linear unchanged and crash.
    feat_dim = self.model.feat_dim
    if isinstance(feat_dim, (list, tuple)):
      feat_dim = feat_dim[0]
    # NOTE(review): 64 output logits -- presumably the number of auxiliary
    # (base) training classes; confirm against the label set.
    self.aux_classifier = nn.Sequential(
        nn.Linear(feat_dim, feat_dim),
        nn.ReLU(inplace=True),
        nn.Linear(feat_dim, feat_dim),
        nn.ReLU(inplace=True),
        nn.Linear(feat_dim, 64))
    self.aux_loss_fn = nn.CrossEntropyLoss()

    # optimizer: the auxiliary classifier is trained together with the model
    # parameters; the second parameter group gets its own Adam instance.
    model_params, ft_params = self.split_model_parameters()
    self.model_optim = torch.optim.Adam(model_params + list(self.aux_classifier.parameters()))
    self.ft_optim = torch.optim.Adam(ft_params, weight_decay=1e-8, lr=1e-3)

    # total epochs
    self.total_epoch = params.stop_epoch
Esempio n. 3
0
    def __init__(self, params):
        """Wrap a metric-based few-shot model for training with FWT enabled.

        Turns on the ``FWT`` class-level switches (and the ``feature_augment``
        switch of ``backbone.FeatureWiseTransformation2d_fw``), instantiates
        the model selected by ``params.method`` and creates one Adam
        optimizer over ``self.split_model_parameters()``.

        Args:
            params: experiment options. Read here: ``method``, ``model``,
                ``train_n_way``, ``n_shot`` and ``stop_epoch``.

        Raises:
            ValueError: if ``params.method`` is not one of ``'ProtoNet'``,
                ``'MatchingNet'``, ``'RelationNet'``, ``'GNN'`` or ``'TPN'``.
        """
        super(LFTNet, self).__init__()

        # Class-level switches on the shared backbone blocks, flipped before
        # the model is instantiated.
        backbone.FeatureWiseTransformation2d_fw.feature_augment = True
        for block_cls in (backbone.ConvBlock, backbone.SimpleBlock,
                          backbone.ResNet):
            block_cls.FWT = True

        def episode_kwargs():
            # Episode configuration passed to every few-shot model. It is read
            # lazily, only in the branch that actually builds a model.
            return {'n_way': params.train_n_way, 'n_support': params.n_shot}

        method = params.method
        if method == 'ProtoNet':
            model = ProtoNet(model_dict[params.model], **episode_kwargs())
        elif method == 'MatchingNet':
            backbone.LSTMCell.FWT = True
            model = MatchingNet(model_dict[params.model], **episode_kwargs())
        elif method == 'RelationNet':
            for block_cls in (relationnet.RelationConvBlock,
                              relationnet.RelationModule):
                block_cls.FWT = True
            model = relationnet.RelationNet(model_dict[params.model],
                                            **episode_kwargs())
        elif method == 'GNN':
            for block_cls in (gnnnet.GnnNet, gnn.Gconv, gnn.Wcompute):
                block_cls.FWT = True
            model = gnnnet.GnnNet(model_dict[params.model],
                                  **episode_kwargs())
        elif method == 'TPN':
            tpn.RelationNetwork.FWT = True
            # NOTE(review): only this branch calls ``.cuda()`` at construction
            # time; the other models are left where they were created.
            model = tpn.TPN(model_dict[params.model],
                            **episode_kwargs()).cuda()
        else:
            raise ValueError('Unknown method')
        self.model = model
        print('\ttrain with {} framework'.format(method))

        # A single Adam optimizer over the parameters selected for training.
        model_params = self.split_model_parameters()
        self.model_optim = torch.optim.Adam(model_params)

        # Number of training epochs.
        self.total_epoch = params.stop_epoch