Example no. 1
0
# Build the training dataloader. Batch composition is fully delegated to the
# supplied batch sampler, so no batch_size/shuffle arguments are passed here.
train_loader = torch.utils.data.DataLoader(
    datasets['training'],
    num_workers=opt.kernels,
    batch_sampler=train_data_sampler)
dataloaders['training'] = train_loader

# The number of classes is inferred directly from the training dataset.
opt.n_classes = len(train_loader.dataset.avail_classes)
"""============================================================================"""
#################### CREATE LOGGING FILES ###############
sub_loggers = ['Train', 'Test', 'Model Grad']
LOG = logger.LOGGER(opt,
                    sub_loggers=sub_loggers,
                    start_new=True,
                    log_online=opt.log_online)
"""============================================================================"""
#################### LOSS SETUP ####################
batchminer = bmine.select(opt.batch_mining, opt)
criterion_dict = {}

for key in opt.diva_features:
    if 'discriminative' in key:
        criterion_dict[key], to_optim = criteria.select(
            opt.loss, opt, to_optim, batchminer)

if len(opt.diva_decorrelations):
    criterion_dict['separation'], to_optim = criteria.select(
        'adversarial_separation', opt, to_optim, None)
# Self-similarity criterion; the concrete self-supervised loss is chosen via
# opt.diva_ssl.
if 'selfsimilarity' in opt.diva_features:
    criterion_dict['selfsimilarity'], to_optim = criteria.select(
        opt.diva_ssl, opt, to_optim, None)
# NOTE(review): the arguments below look like a copy/paste splice from the
# DataLoader construction earlier in this file — criteria.select() is being
# handed a dataset plus num_workers/batch_sampler keywords, which does not
# match any of the other criteria.select() call sites. Verify against the
# original source before relying on this branch.
if 'invariantspread' in opt.diva_features:
    criterion_dict['invariantspread'], to_optim = criteria.select(
    datasets['training'],
    num_workers=opt.kernels,
    batch_sampler=train_data_sampler)
# The number of classes is read off the training dataset.
opt.n_classes = len(dataloaders['training'].dataset.avail_classes)
"""============================================================================"""
#################### CREATE LOGGING FILES ###############
# One sub-logger per phase; a validation logger is added only when a
# train/val split is in use.
sub_loggers = ['Train', 'Test', 'Model Grad']
if opt.use_tv_split:
    sub_loggers.append('Val')
LOG = logger.LOGGER(
    opt, sub_loggers=sub_loggers, start_new=True, log_online=opt.log_online)
"""============================================================================"""
#################### LOSS SETUP ####################
batchminer = bmine.select(opt.batch_mining, opt)
criterion, to_optim = criteria.select(opt.loss, opt, to_optim, batchminer)
_ = criterion.to(opt.device)

if 'criterion' in train_data_sampler.name:
    train_data_sampler.internal_criterion = criterion
"""============================================================================"""
#################### OPTIM SETUP ####################
if opt.optim == 'adam':
    optimizer = torch.optim.Adam(to_optim)
elif opt.optim == 'sgd':
    optimizer = torch.optim.SGD(to_optim, momentum=0.9)
else:
    raise Exception('Optimizer <{}> not available!'.format(opt.optim))
# LR schedule: decay the learning rate at the epoch milestones in opt.tau.
# NOTE(review): this statement is truncated in this snippet — the remaining
# arguments (presumably gamma=opt.gamma) and the closing parenthesis are
# missing from the scrape.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=opt.tau,
Example no. 3
0
    def __init__(self, opt):
        """Initialize the S2SD criterion.

        Builds the target MLP heads (one per entry in
        ``opt.loss_s2sd_target_dims``), one training criterion per target
        dimensionality, and the source criterion, while collecting every
        trainable parameter group into ``self.optim_dict_list``.

        Args:
            opt: Namespace containing all relevant parameters.
        """
        super().__init__()

        self.opt = opt

        #### Some base flags and parameters
        self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
        self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
        self.REQUIRES_OPTIM = REQUIRES_OPTIM
        self.name = 'S2SD'
        self.d_mode = 'cosine'
        self.iter_count = 0
        self.embed_dim = opt.embed_dim

        ### Will contain all parameters to be optimized, e.g. the target MLPs
        ### and potential parameters of training criteria.
        self.optim_dict_list = []

        ### All S2SD-specific parameters.
        self.T = opt.loss_s2sd_T
        self.w = opt.loss_s2sd_w
        self.feat_w = opt.loss_s2sd_feat_w
        self.pool_aggr = opt.loss_s2sd_pool_aggr
        self.match_feats = opt.loss_s2sd_feat_distill
        self.max_feat_iter = opt.loss_s2sd_feat_distill_delay

        ### Initialize all target networks as two-layer MLPs.
        # NOTE(review): assumes the backbone feature dim is 1024 for
        # BN-Inception and 2048 otherwise (ResNet50-style) — confirm for
        # other architectures.
        f_dim = 1024 if 'bninception' in opt.arch else 2048
        self.target_nets = torch.nn.ModuleList([
            nn.Sequential(nn.Linear(f_dim, t_dim), nn.ReLU(),
                          nn.Linear(t_dim, t_dim))
            for t_dim in opt.loss_s2sd_target_dims
        ])
        self.optim_dict_list.append({
            'params': self.target_nets.parameters(),
            'lr': opt.lr
        })

        ### Initialize all target criteria. As each criterion may require its
        ### separate set of trainable parameters, several instances have to be
        ### created. opt.embed_dim is temporarily overwritten so each criterion
        ### is constructed for its target dimensionality; it is restored below.
        old_embed_dim = copy.deepcopy(opt.embed_dim)
        self.target_criteria = nn.ModuleList()
        for t_dim in opt.loss_s2sd_target_dims:
            opt.embed_dim = t_dim

            batchminer = bmine.select(opt.batch_mining, opt)
            target_criterion = criteria.select(opt.loss_s2sd_target,
                                               opt,
                                               batchminer=batchminer)
            self.target_criteria.append(target_criterion)
            self._register_optim_params(target_criterion)

        ### Initialize the source objective. By default the same as the
        ### target objective(s).
        opt.embed_dim = old_embed_dim
        batchminer = bmine.select(opt.batch_mining, opt)
        self.source_criterion = criteria.select(opt.loss_s2sd_source,
                                                opt,
                                                batchminer=batchminer)
        self._register_optim_params(self.source_criterion)

    def _register_optim_params(self, criterion):
        """Collect *criterion*'s trainable parameter groups into
        ``self.optim_dict_list``.

        Criteria that manage their own parameter groups expose an
        ``optim_dict_list``; otherwise all of the criterion's parameters are
        registered as one group at the base learning rate. (This replaces the
        identical stanza that was previously duplicated for every target
        criterion and again for the source criterion.)
        """
        if hasattr(criterion, 'optim_dict_list'):
            self.optim_dict_list.extend(criterion.optim_dict_list)
        else:
            self.optim_dict_list.append({
                'params': criterion.parameters(),
                'lr': self.opt.lr
            })