def __init__(self, config, device):
        """Set up the ESRGAN-EESN GAN model for training.

        Builds the generator and discriminator, configures the pixel /
        feature (perceptual) / GAN losses, creates the Adam optimizers for
        G and D, and attaches the learning-rate schedulers — all driven by
        the nested ``config`` dict.

        Args:
            config: dict with 'network_G', 'network_D', 'train',
                'optimizer' (with an 'args' sub-dict) and 'lr_scheduler'
                sections.
            device: torch device the networks are moved to.
        """
        super(ESRGAN_EESN_Model, self).__init__(config, device)
        # Cache the individual config sections for shorter access below.
        self.configG = config['network_G']
        self.configD = config['network_D']
        self.configT = config['train']
        self.configO = config['optimizer']['args']
        self.configS = config['lr_scheduler']
        self.device = device
        # Generator
        self.netG = model.ESRGAN_EESN(in_nc=self.configG['in_nc'],
                                      out_nc=self.configG['out_nc'],
                                      nf=self.configG['nf'],
                                      nb=self.configG['nb'])
        self.netG = self.netG.to(self.device)
        # NOTE(review): GPU ids are hard-coded to [1, 0] here and for netD /
        # netF below — confirm this matches the target machine's GPUs.
        self.netG = DataParallel(self.netG, device_ids=[1, 0])

        # Discriminator (VGG-style)
        self.netD = model.Discriminator_VGG_128(in_nc=self.configD['in_nc'],
                                                nf=self.configD['nf'])
        self.netD = self.netD.to(self.device)
        self.netD = DataParallel(self.netD, device_ids=[1, 0])

        # Both networks start in training mode.
        self.netG.train()
        self.netD.train()
        #print(self.configT['pixel_weight'])
        # G CharbonnierLoss for final output SR and GT HR
        self.cri_charbonnier = CharbonnierLoss().to(device)
        # G pixel loss: optional, selected by 'pixel_criterion' ('l1'/'l2');
        # disabled entirely when 'pixel_weight' is 0.
        if self.configT['pixel_weight'] > 0.0:
            l_pix_type = self.configT['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = self.configT['pixel_weight']
        else:
            self.cri_pix = None

        # G feature (perceptual) loss: same l1/l2 selection as the pixel loss.
        #print(self.configT['feature_weight']+1)
        if self.configT['feature_weight'] > 0:
            l_fea_type = self.configT['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = self.configT['feature_weight']
        else:
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            # Feature extractor is frozen (eval mode) — only used to compare
            # deep features of SR and HR images.
            self.netF = model.VGGFeatureExtractor(feature_layer=34,
                                                  use_input_norm=True,
                                                  device=self.device)
            self.netF = self.netF.to(self.device)
            self.netF = DataParallel(self.netF, device_ids=[1, 0])
            self.netF.eval()

        # GD gan loss (real label = 1.0, fake label = 0.0)
        self.cri_gan = GANLoss(self.configT['gan_type'], 1.0,
                               0.0).to(self.device)
        self.l_gan_w = self.configT['gan_weight']
        # D_update_ratio and D_init_iters (fall back to 1 / 0 when unset)
        self.D_update_ratio = self.configT['D_update_ratio'] if self.configT[
            'D_update_ratio'] else 1
        self.D_init_iters = self.configT['D_init_iters'] if self.configT[
            'D_init_iters'] else 0

        # optimizers
        # G
        wd_G = self.configO['weight_decay_G'] if self.configO[
            'weight_decay_G'] else 0
        optim_params = []
        for k, v in self.netG.named_parameters(
        ):  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)

        self.optimizer_G = torch.optim.Adam(optim_params,
                                            lr=self.configO['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(self.configO['beta1_G'],
                                                   self.configO['beta2_G']))
        self.optimizers.append(self.optimizer_G)

        # D
        wd_D = self.configO['weight_decay_D'] if self.configO[
            'weight_decay_D'] else 0
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                            lr=self.configO['lr_D'],
                                            weight_decay=wd_D,
                                            betas=(self.configO['beta1_D'],
                                                   self.configO['beta2_D']))
        self.optimizers.append(self.optimizer_D)

        # schedulers: one per optimizer, shared settings from configS['args']
        if self.configS['type'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(
                        optimizer,
                        self.configS['args']['lr_steps'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights'],
                        gamma=self.configS['args']['lr_gamma'],
                        clear_state=False))
        elif self.configS['type'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer,
                        self.configS['args']['T_period'],
                        eta_min=self.configS['args']['eta_min'],
                        restarts=self.configS['args']['restarts'],
                        weights=self.configS['args']['restart_weights']))
        else:
            # NOTE(review): the message is misleading — this branch raises
            # for ANY scheduler type other than the two handled above.
            raise NotImplementedError(
                'MultiStepLR learning rate scheme is enough.')
        # Debug print of the restart schedule.
        print(self.configS['args']['restarts'])
        self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
def main():
    """Train (or only test) a global+local triplet re-id model.

    Reads every setting from ``Config()``, builds the train/test dataset(s),
    the model, the triplet / identity / SIFT losses and the Adam optimizer,
    optionally resumes from a checkpoint, trains for ``cfg.total_epochs``
    epochs with per-step and per-epoch logging (console + TensorBoard),
    and finally evaluates on the test set(s).
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter (only if/when TensorBoard logging happens).
    writer = None

    # TVT: transfer variables/tensors to device; TMO: transfer modules/optims.
    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    # The 'combined' dataset is evaluated on each constituent benchmark.
    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=len(train_set.ids2labels))
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################
    #id_criterion = nn.CrossEntropyLoss()
    id_criterion = SoftmaxEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        # Evaluate on every test set; optionally load weights first, either
        # from an explicit weight file or from the training checkpoint.
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    # NOTE(review): resume_ep only exists when cfg.resume is True — the
    # conditional below relies on that.
    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # Per-epoch running averages: global-triplet metrics ...
        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        # ... local-triplet metrics ...
        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        # ... identity loss ...
        id_loss_meter = AverageMeter()

        # ... SIFT loss ...
        sift_loss_meter = AverageMeter()

        # ... and the weighted total.
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, cam_lables, mirrored, epoch_done = train_set.next_batch(
            )

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            feat, global_feat, local_feat, logits = model_w(ims_var)
            # NOTE(review): ExtractSift is re-instantiated on every step —
            # likely hoistable out of the loop; verify it is stateless.
            sift_func = ExtractSift()
            sift = torch.from_numpy(sift_func(ims_var)).cuda()

            # Global triplet loss also yields the hard positive/negative
            # indices (p_inds / n_inds) reused by the local loss below.
            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss,
                global_feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss = 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, _ = local_loss(
                    l_tri_loss,
                    local_feat,
                    None,
                    None,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                # Reuse the hard samples mined by the global loss.
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss,
                    local_feat,
                    p_inds,
                    n_inds,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            # SIFT loss: distance between softmax-normalized deep global
            # features and softmax-normalized SIFT features.
            sift_loss = 0
            if cfg.sift_loss_weight > 0:
                sift_loss = torch.norm(
                    F.softmax(global_feat, dim=1) - F.softmax(sift, dim=1))

            # Weighted sum of all enabled loss terms.
            loss = g_loss * cfg.g_loss_weight \
                   + l_loss * cfg.l_loss_weight \
                   + id_loss * cfg.id_loss_weight \
                   + sift_loss * cfg.sift_loss_weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            # precision
            g_prec = (g_dist_an > g_dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            g_m = (g_dist_an >
                   g_dist_ap + cfg.global_margin).data.float().mean()
            g_d_ap = g_dist_ap.data.mean()
            g_d_an = g_dist_an.data.mean()

            g_prec_meter.update(g_prec)
            g_m_meter.update(g_m)
            g_dist_ap_meter.update(g_d_ap)
            g_dist_an_meter.update(g_d_an)
            g_loss_meter.update(to_scalar(g_loss))

            if cfg.l_loss_weight > 0:
                # precision
                l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                l_m = (l_dist_an >
                       l_dist_ap + cfg.local_margin).data.float().mean()
                l_d_ap = l_dist_ap.data.mean()
                l_d_an = l_dist_an.data.mean()

                l_prec_meter.update(l_prec)
                l_m_meter.update(l_m)
                l_dist_ap_meter.update(l_d_ap)
                l_dist_an_meter.update(l_d_an)
                l_loss_meter.update(to_scalar(l_loss))

            if cfg.id_loss_weight > 0:
                id_loss_meter.update(to_scalar(id_loss))

            if cfg.sift_loss_weight > 0:
                sift_loss_meter.update(to_scalar(sift_loss))

            loss_meter.update(to_scalar(loss))

            # Console logging every cfg.log_steps steps; each loss term only
            # appears in the line when its weight is > 0.
            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                if cfg.g_loss_weight > 0:
                    g_log = (', gp {:.2%}, gm {:.2%}, '
                             'gd_ap {:.4f}, gd_an {:.4f}, '
                             'gL {:.4f}'.format(
                                 g_prec_meter.val,
                                 g_m_meter.val,
                                 g_dist_ap_meter.val,
                                 g_dist_an_meter.val,
                                 g_loss_meter.val,
                             ))
                else:
                    g_log = ''

                if cfg.l_loss_weight > 0:
                    l_log = (', lp {:.2%}, lm {:.2%}, '
                             'ld_ap {:.4f}, ld_an {:.4f}, '
                             'lL {:.4f}'.format(
                                 l_prec_meter.val,
                                 l_m_meter.val,
                                 l_dist_ap_meter.val,
                                 l_dist_an_meter.val,
                                 l_loss_meter.val,
                             ))
                else:
                    l_log = ''

                if cfg.id_loss_weight > 0:
                    id_log = (', idL {:.4f}'.format(id_loss_meter.val))
                else:
                    id_log = ''

                if cfg.sift_loss_weight > 0:
                    sift_log = (', sL {:.4f}'.format(sift_loss_meter.val))
                else:
                    sift_log = ''

                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

                log = time_log + \
                      g_log + l_log + id_log + \
                      sift_log + total_loss_log
                print(log)

        #############
        # Epoch Log #
        #############

        # Same layout as the step log, but using epoch averages (.avg).
        time_log = 'Ep {}, {:.2f}s'.format(
            ep + 1,
            time.time() - ep_st,
        )

        if cfg.g_loss_weight > 0:
            g_log = (', gp {:.2%}, gm {:.2%}, '
                     'gd_ap {:.4f}, gd_an {:.4f}, '
                     'gL {:.4f}'.format(
                         g_prec_meter.avg,
                         g_m_meter.avg,
                         g_dist_ap_meter.avg,
                         g_dist_an_meter.avg,
                         g_loss_meter.avg,
                     ))
        else:
            g_log = ''

        if cfg.l_loss_weight > 0:
            l_log = (', lp {:.2%}, lm {:.2%}, '
                     'ld_ap {:.4f}, ld_an {:.4f}, '
                     'lL {:.4f}'.format(
                         l_prec_meter.avg,
                         l_m_meter.avg,
                         l_dist_ap_meter.avg,
                         l_dist_an_meter.avg,
                         l_loss_meter.avg,
                     ))
        else:
            l_log = ''

        if cfg.id_loss_weight > 0:
            id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
        else:
            id_log = ''

        if cfg.sift_loss_weight > 0:
            sift_log = (', sL {:.4f}'.format(sift_loss_meter.avg))
        else:
            sift_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

        log = time_log + \
              g_log + l_log + id_log + \
              sift_log + total_loss_log
        print(log)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    id_loss=id_loss_meter.avg,
                    sift_loss=sift_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=g_m_meter.avg,
                    local_satisfy_margin=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    # Final evaluation with the weights currently in memory.
    test(load_model_weight=False)
Exemple #3
0
def main():
    """Train (or only test) a part-based (stripe) re-id classifier.

    Reads every setting from ``Config()``, builds the train/val/test
    datasets and the stripe model, sets up a cross-entropy loss over the
    per-stripe logits and an SGD optimizer with separate LRs for the
    ImageNet-pretrained backbone and the new layers, optionally resumes
    from a checkpoint, trains for ``cfg.total_epochs`` epochs with periodic
    validation, and finally evaluates on the test set(s).
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter (only if/when TensorBoard logging happens).
    writer = None

    # TVT: transfer variables/tensors to device; TMO: transfer modules/optims.
    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)
    num_classes = len(train_set.ids2labels)
    # The combined dataset does not provide val set currently.
    val_set = None if cfg.dataset == 'combined' else create_dataset(
        **cfg.val_set_kwargs)

    # The 'combined' dataset is evaluated on each constituent benchmark.
    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(last_conv_stride=cfg.last_conv_stride,
                  num_stripes=cfg.num_stripes,
                  local_conv_out_channels=cfg.local_conv_out_channels,
                  num_classes=num_classes)
    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    criterion = torch.nn.CrossEntropyLoss()

    # To finetune from ImageNet weights
    finetuned_params = list(model.base.parameters())
    # To train from scratch
    new_params = [
        p for n, p in model.named_parameters() if not n.startswith('base.')
    ]
    # Two parameter groups: a lower LR for the pretrained backbone, a higher
    # one for the freshly initialized layers.
    param_groups = [{
        'params': finetuned_params,
        'lr': cfg.finetuned_params_lr
    }, {
        'params': new_params,
        'lr': cfg.new_params_lr
    }]
    optimizer = optim.SGD(param_groups,
                          momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        # Evaluate on every test set; optionally load weights first, either
        # from an explicit weight file or from the training checkpoint.
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=True, verbose=True)

    def validate():
        # Quick evaluation on the validation set; returns (mAP, Rank-1).
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n===== Test on validation set =====\n')
        mAP, cmc_scores, _, _ = val_set.eval(normalize_feat=True,
                                             to_re_rank=False,
                                             verbose=True)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    # NOTE(review): resume_ep only exists when cfg.resume is True — the
    # conditional below relies on that.
    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate (staircase decay applied per param group)
        adjust_lr_staircase(optimizer.param_groups,
                            [cfg.finetuned_params_lr, cfg.new_params_lr],
                            ep + 1, cfg.staircase_decay_at_epochs,
                            cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # For recording loss
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_var = Variable(TVT(torch.from_numpy(labels).long()))

            # One cross-entropy term per stripe, summed into a single loss.
            # NOTE(review): torch.cat over 0-dim loss tensors requires
            # stacking in newer PyTorch — confirm the torch version in use.
            _, logits_list = model_w(ims_var)
            loss = torch.sum(
                torch.cat(
                    [criterion(logits, labels_var) for logits in logits_list]))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                log = '\tStep {}/Ep {}, {:.2f}s, loss {:.4f}'.format(
                    step, ep + 1,
                    time.time() - step_st, loss_meter.val)
                print(log)

        #############
        # Epoch Log #
        #############

        log = 'Ep {}, {:.2f}s, loss {:.4f}'.format(ep + 1,
                                                   time.time() - ep_st,
                                                   loss_meter.avg)
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    # Final evaluation with the weights currently in memory.
    test(load_model_weight=False)
Exemple #4
0
    def __init__(self, opt):
        """Set up the video super-resolution base model.

        Builds the generator network (wrapped in DistributedDataParallel
        when ``opt['dist']`` is set, otherwise in DataParallel) and, in
        training mode, configures the pixel loss, the Adam optimizer and
        the learning-rate schedulers from ``opt['train']``.

        Args:
            opt: full options dict; must contain 'dist', 'train', and the
                keys consumed by ``networks.define_G``.
        """
        super(VideoSRBaseModel, self).__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt["train"]

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss: pixel criterion selected by 'pixel_criterion'
            # ('l1' / 'l2' / 'cb' = Charbonnier), summed (not averaged).
            loss_type = train_opt["pixel_criterion"]
            if loss_type == "l1":
                self.cri_pix = nn.L1Loss(reduction="sum").to(self.device)
            elif loss_type == "l2":
                self.cri_pix = nn.MSELoss(reduction="sum").to(self.device)
            elif loss_type == "cb":
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type))
            # Optional L1 loss on aligned frames, enabled by
            # 'aligned_criterion'.
            self.cri_aligned = (nn.L1Loss(reduction="sum").to(self.device)
                                if train_opt["aligned_criterion"] else None)
            self.l_pix_w = train_opt["pixel_weight"]

            #### optimizers
            wd_G = train_opt["weight_decay_G"] if train_opt[
                "weight_decay_G"] else 0
            if train_opt["ft_tsa_only"]:
                # Fine-tune only the TSA fusion module: split params into two
                # groups so their LRs can be adjusted independently later.
                # NOTE(review): both groups start at the same lr_G here —
                # confirm the training loop rescales them separately.
                normal_params = []
                tsa_fusion_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        if "tsa_fusion" in k:
                            tsa_fusion_params.append(v)
                        else:
                            normal_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))
                optim_params = [
                    {  # add normal params first
                        "params": normal_params,
                        "lr": train_opt["lr_G"],
                    },
                    {"params": tsa_fusion_params, "lr": train_opt["lr_G"]},
                ]
            else:
                # Optimize every trainable parameter as a single group.
                optim_params = []
                for k, v in self.netG.named_parameters():
                    if v.requires_grad:
                        optim_params.append(v)
                    else:
                        if self.rank <= 0:
                            logger.warning(
                                "Params [{:s}] will not optimize.".format(k))

            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            self.optimizers.append(self.optimizer_G)

            #### schedulers: one per optimizer
            if train_opt["lr_scheme"] == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        ))
            elif train_opt["lr_scheme"] == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        ))
            else:
                # Any other lr_scheme is rejected.
                raise NotImplementedError()

            self.log_dict = OrderedDict()
Exemple #5
0
                          train_cfg=cfg.train_cfg,
                          test_cfg=cfg.test_cfg)
    logger.info('-' * 20 + 'finish build model' + '-' * 20)
    logger.info('Total Parameters: %d,   Trainable Parameters: %s',
                model.net_parameters['Total'],
                str(model.net_parameters['Trainable']))
    # build dataset
    datasets = build_dataset(cfg.data.train)
    logger.info('-' * 20 + 'finish build dataset' + '-' * 20)
    # put model on gpu
    if torch.cuda.is_available():
        if len(cfg.gpu_ids) == 1:
            model = model.cuda()
            logger.info('-' * 20 + 'model to one gpu' + '-' * 20)
        else:
            model = DataParallel(model.cuda(), device_ids=cfg.gpu_ids)
            logger.info('-' * 20 + 'model to multi gpus' + '-' * 20)
    # create data_loader
    data_loader = build_dataloader(datasets, cfg.data.samples_per_gpu,
                                   cfg.data.workers_per_gpu, len(cfg.gpu_ids))
    logger.info('-' * 20 + 'finish build dataloader' + '-' * 20)
    # create optimizer
    optimizer = build_optimizer(model, cfg.optimizer)
    Scheduler = build_scheduler(cfg.lr_config)
    logger.info('-' * 20 + 'finish build optimizer' + '-' * 20)

    visualizer = Visualizer()
    vis = visdom.Visdom()
    criterion_ssim_loss = build_loss(cfg.loss_ssim)
    criterion_l1_loss = build_loss(cfg.loss_l1)
    ite_num = 0
Exemple #6
0
 def on_train_begin(self, **kwargs):
     """Wrap the learner's model in ``DataParallel`` when training starts."""
     wrapped = DataParallel(self.learn.model)
     self.learn.model = wrapped
Exemple #7
0
 def before_fit(self):
     """Replace the learner's model with a ``DataParallel`` wrapper over
     ``self.device_ids`` before fitting begins."""
     self.learn.model = DataParallel(self.learn.model, device_ids=self.device_ids)
 def after_fit(self):
     """Unwrap ``DataParallel`` after fitting, restoring the bare model."""
     self.learn.model = self.learn.model.module
def main():
    """Deep-mutual-learning training entry point.

    Builds ``cfg.num_models`` identical re-id models and trains them jointly:
    besides each model's own triplet/identity losses, every model is
    regularized toward the others' softmax probabilities (pm_loss) and
    pairwise distance matrices (gdm_loss / ldm_loss).  Each model runs in its
    own daemon thread; the main thread feeds batches and drives the two
    training phases with threading events.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    # One (TVT, TMO) pair per model: TVT is applied to tensors below, TMO to
    # [model, optimizer] pairs.  NOTE(review): inferred from usage -- confirm
    # against set_devices_for_ml.
    TVTs, TMOs, relative_device_ids = set_devices_for_ml(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    models = [
        Model(local_conv_out_channels=cfg.local_conv_out_channels,
              num_classes=len(train_set.ids2labels))
        for _ in range(cfg.num_models)
    ]
    # Model wrappers
    model_ws = [
        DataParallel(models[i], device_ids=relative_device_ids[i])
        for i in range(cfg.num_models)
    ]

    #############################
    # Criteria and Optimizers   #
    #############################

    id_criterion = nn.CrossEntropyLoss()
    g_tri_loss = TripletLoss(margin=cfg.global_margin)
    l_tri_loss = TripletLoss(margin=cfg.local_margin)

    optimizers = [
        optim.Adam(m.parameters(),
                   lr=cfg.base_lr,
                   weight_decay=cfg.weight_decay) for m in models
    ]

    # Bind them together just to save some codes in the following usage.
    modules_optims = models + optimizers

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizers
    # is to cope with the case when you load the checkpoint to a new device.
    for TMO, model, optimizer in zip(TMOs, models, optimizers):
        TMO([model, optimizer])

    ########
    # Test #
    ########

    # Test each model using different distance settings.
    def test(load_model_weight=False):
        """Evaluate every model on every test set; optionally reload weights."""
        if load_model_weight:
            load_ckpt(modules_optims, cfg.ckpt_file)

        # Local distance only matters when the local loss is active and finds
        # its own hard samples.
        use_local_distance = (cfg.l_loss_weight > 0) \
                             and cfg.local_dist_own_hard_sample

        for i, (model_w, TVT) in enumerate(zip(model_ws, TVTs)):
            for test_set, name in zip(test_sets, test_set_names):
                test_set.set_feat_func(ExtractFeature(model_w, TVT))
                print(
                    '\n=========> Test Model #{} on dataset: {} <=========\n'.
                    format(i + 1, name))
                test_set.eval(normalize_feat=cfg.normalize_feature,
                              use_local_distance=use_local_distance)

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    # Storing things that can be accessed cross threads.

    ims_list = [None for _ in range(cfg.num_models)]
    labels_list = [None for _ in range(cfg.num_models)]

    done_list1 = [False for _ in range(cfg.num_models)]
    done_list2 = [False for _ in range(cfg.num_models)]

    probs_list = [None for _ in range(cfg.num_models)]
    g_dist_mat_list = [None for _ in range(cfg.num_models)]
    l_dist_mat_list = [None for _ in range(cfg.num_models)]

    # Two phases for each model:
    # 1) forward and single-model loss;
    # 2) further add mutual loss and backward.
    # The 2nd phase is only ready to start when the 1st is finished for
    # all models.
    run_event1 = threading.Event()
    run_event2 = threading.Event()

    # This event is meant to be set to stop threads. However, as I found, with
    # `daemon` set to true when creating threads, manually stopping is
    # unnecessary. I guess some main-thread variables required by sub-threads
    # are destroyed when the main thread ends, thus the sub-threads throw errors
    # and exit too.
    # Real reason should be further explored.
    exit_event = threading.Event()

    # The function to be called by threads.
    def thread_target(i):
        """Per-model worker loop: phase-1 forward/losses, then phase-2
        mutual losses and backward, synchronized via run_event1/run_event2."""
        # NOTE(review): Event.isSet() is the legacy alias of is_set().
        while not exit_event.isSet():
            # If the run event is not set, the thread just waits.
            if not run_event1.wait(0.001): continue

            ######################################
            # Phase 1: Forward and Separate Loss #
            ######################################

            TVT = TVTs[i]
            model_w = model_ws[i]
            ims = ims_list[i]
            labels = labels_list[i]
            optimizer = optimizers[i]

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            labels_var = Variable(labels_t)

            global_feat, local_feat, logits = model_w(ims_var)
            # NOTE(review): softmax/log_softmax called without dim= -- relies
            # on the deprecated implicit-dim behavior; confirm intended axis.
            probs = F.softmax(logits)
            log_probs = F.log_softmax(logits)

            g_loss, p_inds, n_inds, g_dist_ap, g_dist_an, g_dist_mat = global_loss(
                g_tri_loss,
                global_feat,
                labels_t,
                normalize_feature=cfg.normalize_feature)

            if cfg.l_loss_weight == 0:
                l_loss, l_dist_mat = 0, 0
            elif cfg.local_dist_own_hard_sample:
                # Let local distance find its own hard samples.
                l_loss, l_dist_ap, l_dist_an, l_dist_mat = local_loss(
                    l_tri_loss,
                    local_feat,
                    None,
                    None,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
            else:
                l_loss, l_dist_ap, l_dist_an = local_loss(
                    l_tri_loss,
                    local_feat,
                    p_inds,
                    n_inds,
                    labels_t,
                    normalize_feature=cfg.normalize_feature)
                l_dist_mat = 0

            id_loss = 0
            if cfg.id_loss_weight > 0:
                id_loss = id_criterion(logits, labels_var)

            # Publish this model's phase-1 outputs for the other threads.
            probs_list[i] = probs
            g_dist_mat_list[i] = g_dist_mat
            l_dist_mat_list[i] = l_dist_mat

            done_list1[i] = True

            # Wait for event to be set, meanwhile checking if need to exit.
            while True:
                phase2_ready = run_event2.wait(0.001)
                if exit_event.isSet():
                    return
                if phase2_ready:
                    break

            #####################################
            # Phase 2: Mutual Loss and Backward #
            #####################################

            # Probability Mutual Loss (KL Loss)
            pm_loss = 0
            if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                for j in range(cfg.num_models):
                    if j != i:
                        # NOTE(review): third positional arg False is the old
                        # size_average flag (sum, not mean) -- confirm against
                        # the torch version in use.
                        pm_loss += F.kl_div(log_probs,
                                            TVT(probs_list[j]).detach(), False)
                pm_loss /= 1. * (cfg.num_models - 1) * len(ims)

            # Global Distance Mutual Loss (L2 Loss)
            gdm_loss = 0
            if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                for j in range(cfg.num_models):
                    if j != i:
                        gdm_loss += torch.sum(
                            torch.pow(
                                g_dist_mat - TVT(g_dist_mat_list[j]).detach(),
                                2))
                gdm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

            # Local Distance Mutual Loss (L2 Loss)
            ldm_loss = 0
            if (cfg.num_models > 1) \
                and cfg.local_dist_own_hard_sample \
                and (cfg.ldm_loss_weight > 0):
                for j in range(cfg.num_models):
                    if j != i:
                        ldm_loss += torch.sum(
                            torch.pow(
                                l_dist_mat - TVT(l_dist_mat_list[j]).detach(),
                                2))
                ldm_loss /= 1. * (cfg.num_models - 1) * len(ims) * len(ims)

            loss = g_loss * cfg.g_loss_weight \
                   + l_loss * cfg.l_loss_weight \
                   + id_loss * cfg.id_loss_weight \
                   + pm_loss * cfg.pm_loss_weight \
                   + gdm_loss * cfg.gdm_loss_weight \
                   + ldm_loss * cfg.ldm_loss_weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ##################################
            # Step Log For One of the Models #
            ##################################

            # These meters are outer-scope variables

            # Just record for the first model
            if i == 0:

                # precision
                g_prec = (g_dist_an > g_dist_ap).data.float().mean()
                # the proportion of triplets that satisfy margin
                g_m = (g_dist_an >
                       g_dist_ap + cfg.global_margin).data.float().mean()
                g_d_ap = g_dist_ap.data.mean()
                g_d_an = g_dist_an.data.mean()

                g_prec_meter.update(g_prec)
                g_m_meter.update(g_m)
                g_dist_ap_meter.update(g_d_ap)
                g_dist_an_meter.update(g_d_an)
                g_loss_meter.update(to_scalar(g_loss))

                if cfg.l_loss_weight > 0:
                    # precision
                    l_prec = (l_dist_an > l_dist_ap).data.float().mean()
                    # the proportion of triplets that satisfy margin
                    l_m = (l_dist_an >
                           l_dist_ap + cfg.local_margin).data.float().mean()
                    l_d_ap = l_dist_ap.data.mean()
                    l_d_an = l_dist_an.data.mean()

                    l_prec_meter.update(l_prec)
                    l_m_meter.update(l_m)
                    l_dist_ap_meter.update(l_d_ap)
                    l_dist_an_meter.update(l_d_an)
                    l_loss_meter.update(to_scalar(l_loss))

                if cfg.id_loss_weight > 0:
                    id_loss_meter.update(to_scalar(id_loss))

                if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                    pm_loss_meter.update(to_scalar(pm_loss))

                if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                    gdm_loss_meter.update(to_scalar(gdm_loss))

                if (cfg.num_models > 1) \
                    and cfg.local_dist_own_hard_sample \
                    and (cfg.ldm_loss_weight > 0):
                    ldm_loss_meter.update(to_scalar(ldm_loss))

                loss_meter.update(to_scalar(loss))

            ###################
            # End Up One Step #
            ###################

            run_event1.clear()
            run_event2.clear()

            done_list2[i] = True

    threads = []
    for i in range(cfg.num_models):
        thread = threading.Thread(target=thread_target, args=(i, ))
        # Set the thread in daemon mode, so that the main program ends normally.
        thread.daemon = True
        thread.start()
        threads.append(thread)

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        for optimizer in optimizers:
            if cfg.lr_decay_type == 'exp':
                adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                              cfg.exp_decay_at_epoch)
            else:
                adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                    cfg.staircase_decay_at_epochs,
                                    cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        epoch_done = False

        g_prec_meter = AverageMeter()
        g_m_meter = AverageMeter()
        g_dist_ap_meter = AverageMeter()
        g_dist_an_meter = AverageMeter()
        g_loss_meter = AverageMeter()

        l_prec_meter = AverageMeter()
        l_m_meter = AverageMeter()
        l_dist_ap_meter = AverageMeter()
        l_dist_an_meter = AverageMeter()
        l_loss_meter = AverageMeter()

        id_loss_meter = AverageMeter()

        # Global Distance Mutual Loss
        gdm_loss_meter = AverageMeter()
        # Local Distance Mutual Loss
        ldm_loss_meter = AverageMeter()
        # Probability Mutual Loss
        pm_loss_meter = AverageMeter()

        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        while not epoch_done:

            step += 1
            step_st = time.time()

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            # Hand the batch to every worker thread and reset the done flags.
            for i in range(cfg.num_models):
                ims_list[i] = ims
                labels_list[i] = labels
                done_list1[i] = False
                done_list2[i] = False

            run_event1.set()
            # Waiting for phase 1 done
            # NOTE(review): busy-wait spin; burns a CPU core while workers run.
            while not all(done_list1):
                continue

            run_event2.set()
            # Waiting for phase 2 done
            while not all(done_list2):
                continue

            ############
            # Step Log #
            ############

            if step % cfg.log_steps == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                if cfg.g_loss_weight > 0:
                    g_log = (', gp {:.2%}, gm {:.2%}, '
                             'gd_ap {:.4f}, gd_an {:.4f}, '
                             'gL {:.4f}'.format(
                                 g_prec_meter.val,
                                 g_m_meter.val,
                                 g_dist_ap_meter.val,
                                 g_dist_an_meter.val,
                                 g_loss_meter.val,
                             ))
                else:
                    g_log = ''

                if cfg.l_loss_weight > 0:
                    l_log = (', lp {:.2%}, lm {:.2%}, '
                             'ld_ap {:.4f}, ld_an {:.4f}, '
                             'lL {:.4f}'.format(
                                 l_prec_meter.val,
                                 l_m_meter.val,
                                 l_dist_ap_meter.val,
                                 l_dist_an_meter.val,
                                 l_loss_meter.val,
                             ))
                else:
                    l_log = ''

                if cfg.id_loss_weight > 0:
                    id_log = (', idL {:.4f}'.format(id_loss_meter.val))
                else:
                    id_log = ''

                if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
                    pm_log = (', pmL {:.4f}'.format(pm_loss_meter.val))
                else:
                    pm_log = ''

                if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
                    gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.val))
                else:
                    gdm_log = ''

                if (cfg.num_models > 1) \
                    and cfg.local_dist_own_hard_sample \
                    and (cfg.ldm_loss_weight > 0):
                    ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.val))
                else:
                    ldm_log = ''

                total_loss_log = ', loss {:.4f}'.format(loss_meter.val)

                log = time_log + \
                      g_log + l_log + id_log + \
                      pm_log + gdm_log + ldm_log + \
                      total_loss_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(
            ep + 1,
            time.time() - ep_st,
        )

        if cfg.g_loss_weight > 0:
            g_log = (', gp {:.2%}, gm {:.2%}, '
                     'gd_ap {:.4f}, gd_an {:.4f}, '
                     'gL {:.4f}'.format(
                         g_prec_meter.avg,
                         g_m_meter.avg,
                         g_dist_ap_meter.avg,
                         g_dist_an_meter.avg,
                         g_loss_meter.avg,
                     ))
        else:
            g_log = ''

        if cfg.l_loss_weight > 0:
            l_log = (', lp {:.2%}, lm {:.2%}, '
                     'ld_ap {:.4f}, ld_an {:.4f}, '
                     'lL {:.4f}'.format(
                         l_prec_meter.avg,
                         l_m_meter.avg,
                         l_dist_ap_meter.avg,
                         l_dist_an_meter.avg,
                         l_loss_meter.avg,
                     ))
        else:
            l_log = ''

        if cfg.id_loss_weight > 0:
            id_log = (', idL {:.4f}'.format(id_loss_meter.avg))
        else:
            id_log = ''

        if (cfg.num_models > 1) and (cfg.pm_loss_weight > 0):
            pm_log = (', pmL {:.4f}'.format(pm_loss_meter.avg))
        else:
            pm_log = ''

        if (cfg.num_models > 1) and (cfg.gdm_loss_weight > 0):
            gdm_log = (', gdmL {:.4f}'.format(gdm_loss_meter.avg))
        else:
            gdm_log = ''

        if (cfg.num_models > 1) \
            and cfg.local_dist_own_hard_sample \
            and (cfg.ldm_loss_weight > 0):
            ldm_log = (', ldmL {:.4f}'.format(ldm_loss_meter.avg))
        else:
            ldm_log = ''

        total_loss_log = ', loss {:.4f}'.format(loss_meter.avg)

        log = time_log + \
              g_log + l_log + id_log + \
              pm_log + gdm_log + ldm_log + \
              total_loss_log
        print(log)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars(
                'loss',
                dict(
                    global_loss=g_loss_meter.avg,
                    local_loss=l_loss_meter.avg,
                    id_loss=id_loss_meter.avg,
                    pm_loss=pm_loss_meter.avg,
                    gdm_loss=gdm_loss_meter.avg,
                    ldm_loss=ldm_loss_meter.avg,
                    loss=loss_meter.avg,
                ), ep)
            writer.add_scalars(
                'tri_precision',
                dict(
                    global_precision=g_prec_meter.avg,
                    local_precision=l_prec_meter.avg,
                ), ep)
            writer.add_scalars(
                'satisfy_margin',
                dict(
                    global_satisfy_margin=g_m_meter.avg,
                    local_satisfy_margin=l_m_meter.avg,
                ), ep)
            writer.add_scalars(
                'global_dist',
                dict(
                    global_dist_ap=g_dist_ap_meter.avg,
                    global_dist_an=g_dist_an_meter.avg,
                ), ep)
            writer.add_scalars(
                'local_dist',
                dict(
                    local_dist_ap=l_dist_ap_meter.avg,
                    local_dist_an=l_dist_an_meter.avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
Exemple #9
0
    def __init__(self, opt):
        """Build the generator, pixel loss, optimizer and LR schedulers.

        Args:
            opt: configuration dict; keys read here are 'dist', 'train'
                (with criterion/optimizer/scheduler settings), plus whatever
                networks.define_G consumes.
        """
        super(B_Model, self).__init__(opt)

        # Rank is only meaningful under distributed training.
        self.rank = torch.distributed.get_rank() if opt["dist"] else -1

        # Build the generator and wrap it for (distributed) data parallelism.
        self.netG = networks.define_G(opt).to(self.device)
        if opt["dist"]:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()]
            )
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            train_opt = opt["train"]
            # self.init_model() # Not use init is OK, since Pytorch has its owen init (by default)
            self.netG.train()

            # Pixel loss: map the configured criterion name to its class.
            loss_type = train_opt["pixel_criterion"]
            criterion_classes = {
                "l1": nn.L1Loss,
                "l2": nn.MSELoss,
                "cb": CharbonnierLoss,
            }
            if loss_type not in criterion_classes:
                raise NotImplementedError(
                    "Loss type [{:s}] is not recognized.".format(loss_type)
                )
            self.cri_pix = criterion_classes[loss_type]().to(self.device)
            self.l_pix_w = train_opt["pixel_weight"]

            # Optimizer over all trainable generator parameters.
            wd_G = train_opt["weight_decay_G"] or 0
            optim_params = []
            # can optimize for a part of the model
            for param_name, param in self.netG.named_parameters():
                if param.requires_grad:
                    optim_params.append(param)
                elif self.rank <= 0:
                    logger.warning("Params [{:s}] will not optimize.".format(param_name))
            self.optimizer_G = torch.optim.Adam(
                optim_params,
                lr=train_opt["lr_G"],
                weight_decay=wd_G,
                betas=(train_opt["beta1"], train_opt["beta2"]),
            )
            # self.optimizer_G = torch.optim.SGD(optim_params, lr=train_opt['lr_G'], momentum=0.9)
            self.optimizers.append(self.optimizer_G)

            # Learning-rate schedulers, one per optimizer.
            scheme = train_opt["lr_scheme"]
            if scheme == "MultiStepLR":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt["lr_steps"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                            gamma=train_opt["lr_gamma"],
                            clear_state=train_opt["clear_state"],
                        )
                    )
            elif scheme == "CosineAnnealingLR_Restart":
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt["T_period"],
                            eta_min=train_opt["eta_min"],
                            restarts=train_opt["restarts"],
                            weights=train_opt["restart_weights"],
                        )
                    )
            else:
                print("MultiStepLR learning rate scheme is enough.")

            self.log_dict = OrderedDict()
Exemple #10
0
    test_data = data[split_idx:]
    print('train samples:', len(training_data))
    print('test samples:', len(test_data))

    model = GatedCNN(seq_len, vocab_size, embd_size, n_layers, kernel, out_chs,
                     res_block_count, vocab_size)
    cuda = None
    if torch.cuda.is_available():
        print("cuda")
        model.cuda()
        cuda = True
    else:
        cuda = False

    distributed_mode = False
    if distributed_mode:
        # sampler = DistributedSampler(training_data, num_replicas=world_size, rank=rank)
        if cuda:
            model = DataParallel(model)
        else:
            model = DistributedDataParallelCPU(model)

    #non-distributed. set the model to DataParallel to increase training speed
    elif not distributed_mode and cuda:
        model = torch.nn.DataParallel(model)

# NOTE(review): these statements reference `model`, `training_data` and
# `test_data` from the (truncated) function above but sit at module level;
# the dedent looks like a paste artifact -- confirm the intended scope.
optimizer = torch.optim.Adadelta(model.parameters())
# NLLLoss expects log-probabilities (e.g. log_softmax output) from the model.
loss_fn = nn.NLLLoss()
train(model, training_data, test_data, optimizer, loss_fn)
# test(model, test_data)
Exemple #11
0
    def __init__(self,
                 model,
                 regime=None,
                 criterion=None,
                 label_smoothing=0,
                 print_freq=10,
                 eval_freq=1000,
                 save_freq=1000,
                 grad_clip=None,
                 embedding_grad_clip=None,
                 max_tokens=None,
                 chunk_batch=1,
                 duplicates=1,
                 save_info={},
                 save_path='.',
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 avg_loss_time=True,
                 distributed=False,
                 local_rank=0,
                 dtype=torch.float,
                 loss_scale=1,
                 device_ids=None,
                 device="cuda"):
        """Configure a seq2seq training loop: criterion, optimizer regime,
        (distributed) data-parallel wrapping, and checkpoint/result logging.

        NOTE(review): ``save_info={}`` is a mutable default argument shared
        across calls -- flag for a future fix.
        """
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        # Default criterion: summed, label-smoothed cross entropy that
        # ignores padding tokens.
        self.criterion = criterion or CrossEntropyLoss(
            ignore_index=PAD,
            smooth_eps=label_smoothing,
            reduction='sum',
            from_logits=False)

        # A float32 master copy of the weights is only needed for fp16 training.
        self.optimizer = OptimRegime(self.model,
                                     regime=regime,
                                     use_float_copy=dtype == torch.float16)
        self.grad_clip = grad_clip
        self.embedding_grad_clip = embedding_grad_clip
        self.epoch = 0
        self.training_steps = 0
        self.save_info = save_info
        self.device = device
        self.dtype = dtype
        self.loss_scale = loss_scale
        self.max_tokens = max_tokens
        self.chunk_batch = chunk_batch
        self.duplicates = duplicates
        self.print_freq = print_freq
        self.eval_freq = eval_freq
        self.perplexity = float('inf')
        self.device_ids = device_ids
        self.avg_loss_time = avg_loss_time
        # Fuse model and criterion so the wrapped module returns the loss
        # directly (keeps loss computation on each replica when parallelized).
        self.model_with_loss = AddLossModule(self.model, self.criterion)
        self.distributed = distributed
        self.local_rank = local_rank
        if distributed:
            self.model_with_loss = DistributedDataParallel(
                self.model_with_loss,
                device_ids=[local_rank],
                output_device=local_rank)
        else:
            # NOTE(review): self.batch_first is read here but never set in
            # this __init__; presumably a class attribute or set elsewhere --
            # confirm.
            if isinstance(self.device_ids, tuple):
                self.model_with_loss = DataParallel(
                    self.model_with_loss,
                    self.device_ids,
                    dim=0 if self.batch_first else 1)
        self.save_path = save_path
        self.save_freq = save_freq
        self.checkpoint_filename = checkpoint_filename
        # One extra slot beyond keep_checkpoints -- presumably room for the
        # newest checkpoint before the oldest is pruned; confirm in save logic.
        self.keep_checkpoints = keep_checkpoints + 1
        results_file = os.path.join(save_path, 'results')
        self.results = ResultsLog(results_file,
                                  params=save_info.get('config', None))
def main():
    """Train a (partial) re-ID model with triplet loss, then evaluate.

    All settings come from ``Config``.  With ``cfg.only_test`` set, the
    function only loads weights and runs evaluation on the test set(s).
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None
    print(cfg.sys_device_ids)
    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    if not cfg.only_test:
        train_set = create_dataset(**cfg.train_set_kwargs)

    test_sets = []
    test_set_names = []

    # 'combined' evaluates on all three benchmarks in one run.
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########
    if cfg.partial_dataset == 'holistic':
        model = Model(last_conv_stride=cfg.last_conv_stride)
    else:
        model = PartialModel(last_conv_stride=cfg.last_conv_stride)
    # Model wrapper (was duplicated in both branches above).
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################
    tri_loss = TripletLoss(margin=cfg.margin)
    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on every test set; optionally load weights first."""
        if load_model_weight:
            if cfg.model_weight_file != '':
                # Weights-only file: load the raw state dict onto CPU first.
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)

                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            test_set.eval(normalize_feat=cfg.normalize_feature, verbose=True)

    def validate():
        # NOTE(review): `val_set` is not defined anywhere in this function, so
        # calling validate() would raise NameError.  Kept as-is because it is
        # never invoked here — confirm before wiring it up.
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test on validation set <=========\n')
        mAP, cmc_scores, _, _ = val_set.eval(
            normalize_feat=cfg.normalize_feature,
            to_re_rank=False,
            verbose=False)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # For recording precision, satisfying margin, etc
        prec_meter = AverageMeter()
        sm_meter = AverageMeter()
        dist_ap_meter = AverageMeter()
        dist_an_meter = AverageMeter()
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False

        while not epoch_done:
            step += 1
            step_st = time.time()
            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            feat, spatialFeature = model_w(ims_var)

            # BUGFIX: the original call passed `cfg.spatial_train` as a
            # positional argument AFTER a keyword argument, which is a
            # SyntaxError.  It is now passed by keyword; the parameter name
            # `spatial_train` is inferred from the config field — confirm
            # against the `global_loss` signature.
            loss, p_inds, n_inds, dist_ap, dist_an, dist_mat = global_loss(
                tri_loss,
                feat,
                spatialFeature,
                labels_t,
                normalize_feature=cfg.normalize_feature,
                spatial_train=cfg.spatial_train)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############
            # precision
            prec = (dist_an > dist_ap).data.float().mean()
            # the proportion of triplets that satisfy margin
            sm = (dist_an > dist_ap + cfg.margin).data.float().mean()
            # average (anchor, positive) distance
            d_ap = dist_ap.data.mean()
            # average (anchor, negative) distance
            d_an = dist_an.data.mean()

            prec_meter.update(prec)
            sm_meter.update(sm)
            dist_ap_meter.update(d_ap)
            dist_an_meter.update(d_an)
            loss_meter.update(to_scalar(loss))
            if step % cfg.steps_per_log == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                tri_log = (', prec {:.2%}, sm {:.2%}, '
                           'd_ap {:.4f}, d_an {:.4f}, '
                           'loss {:.4f}'.format(
                               prec_meter.val,
                               sm_meter.val,
                               dist_ap_meter.val,
                               dist_an_meter.val,
                               loss_meter.val,
                           ))

                log = time_log + tri_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

        tri_log = (', prec {:.2%}, sm {:.2%}, '
                   'd_ap {:.4f}, d_an {:.4f}, '
                   'loss {:.4f}'.format(
                       prec_meter.avg,
                       sm_meter.avg,
                       dist_ap_meter.avg,
                       dist_an_meter.avg,
                       loss_meter.avg,
                   ))

        log = time_log + tri_log
        print(log)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########
    test(load_model_weight=False)
Exemple #13
0
    def __init__(self, opt):
        """Build the SRGAN model: generator, discriminator and (depending on
        mode) the pixel / perceptual / GAN losses, optimizers and schedulers.

        Args:
            opt: option dict; ``opt['train']`` carries the training-time loss
                and optimizer settings, while in attack/eval mode the same
                loss keys live at the top level of ``opt``.
        """
        super(SRGANModel, self).__init__(opt)

        train_opt = opt['train']

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.netG = DataParallel(self.netG)

        self.netD = networks.define_D(opt).to(self.device)
        self.netD = DataParallel(self.netD)
        if self.is_train:
            self.netG.train()
            self.netD.train()

        # Adversarial-attack evaluation also needs the losses; in that mode
        # the loss options are read from the top-level opt dict.
        if not self.is_train and 'attack' in self.opt:
            self._init_losses(opt)

        self.delta = 0

        # define losses, optimizer and scheduler
        if self.is_train:
            self._init_losses(train_opt)

            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    logger.warning(
                        'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed

    def _init_losses(self, loss_opt):
        """Create the generator losses from *loss_opt*.

        *loss_opt* is either the full ``opt`` (attack/eval mode) or
        ``opt['train']`` (training mode); both carry the same loss keys.
        This was duplicated verbatim in both ``__init__`` branches.

        Sets ``self.cri_pix``/``self.l_pix_w``, ``self.cri_fea``/
        ``self.l_fea_w`` (plus ``self.netF`` when a feature loss is used) and
        ``self.cri_gan``/``self.l_gan_w``.

        Raises:
            NotImplementedError: on an unknown pixel/feature criterion name.
        """
        # G pixel loss
        if loss_opt['pixel_weight'] > 0:
            l_pix_type = loss_opt['pixel_criterion']
            if l_pix_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif l_pix_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_pix_type))
            self.l_pix_w = loss_opt['pixel_weight']
        else:
            logger.info('Remove pixel loss.')
            self.cri_pix = None

        # G feature loss
        if loss_opt['feature_weight'] > 0:
            l_fea_type = loss_opt['feature_criterion']
            if l_fea_type == 'l1':
                self.cri_fea = nn.L1Loss().to(self.device)
            elif l_fea_type == 'l2':
                self.cri_fea = nn.MSELoss().to(self.device)
            else:
                raise NotImplementedError(
                    'Loss type [{:s}] not recognized.'.format(l_fea_type))
            self.l_fea_w = loss_opt['feature_weight']
        else:
            logger.info('Remove feature loss.')
            self.cri_fea = None
        if self.cri_fea:  # load VGG perceptual loss
            # define_F always takes the full option dict (self.opt), as in
            # the original code.
            self.netF = networks.define_F(self.opt,
                                          use_bn=False).to(self.device)
            self.netF = DataParallel(self.netF)

        # GD gan loss
        self.cri_gan = GANLoss(loss_opt['gan_type'], 1.0, 0.0).to(self.device)
        self.l_gan_w = loss_opt['gan_weight']
Exemple #14
0
    def __init__(self, opt):
        """Build the CLSGAN model.

        Creates the RCAN generator (plus, in training mode, a VGG-style
        discriminator) and, when training, the pixel / feature /
        classification / GAN losses together with their optimizers and
        learning-rate schedulers.

        Args:
            opt: option dict; ``opt['train']`` holds loss/optimizer settings,
                ``opt['network_G']`` the generator settings, ``opt['dist']``
                toggles distributed training.
        """
        super(CLSGAN_Model, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        G_opt = opt['network_G']

        # define networks and load pretrained models
        self.netG = RCAN(G_opt).to(self.device)
        self.netG = DataParallel(self.netG)
        if self.is_train:
            # Discriminator is only needed for training (3-channel input).
            self.netD = Discriminator_VGG_256(3, G_opt['nf']).to(self.device)
            self.netD = DataParallel(self.netD)
            self.netG.train()
            self.netD.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # G feature loss
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = VGGFeatureExtractor(feature_layer=34,
                                                use_bn=False,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                self.netF = DataParallel(self.netF)

            # G classification loss (original header said "feature loss" —
            # copy/paste; this branch builds the cls criterion)
            if train_opt['cls_weight'] > 0:
                l_cls_type = train_opt['cls_criterion']
                # 'CE' maps to NLLLoss, so the classifier output is
                # presumably already log-softmax — confirm against netC.
                if l_cls_type == 'CE':
                    self.cri_cls = nn.NLLLoss().to(self.device)
                elif l_cls_type == 'l1':
                    self.cri_cls = nn.L1Loss().to(self.device)
                elif l_cls_type == 'l2':
                    self.cri_cls = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_cls_type))
                self.l_cls_w = train_opt['cls_weight']
            else:
                logger.info('Remove classification loss.')
                self.cri_cls = None
            if self.cri_cls:  # load pretrained VGG(bn) features as the classification prior
                self.netC = VGGFeatureExtractor(feature_layer=49,
                                                use_bn=True,
                                                use_input_norm=True,
                                                device=self.device).to(
                                                    self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classfication prior."
                self.netC.load_model(load_path_C)
                self.netC = DataParallel(self.netC)

            # Branch classifier loss: reuses the same pretrained weights.
            if train_opt['brc_weight'] > 0:
                self.l_brc_w = train_opt['brc_weight']
                self.netR = VGG_Classifier().to(self.device)
                load_path_C = self.opt['path']['pretrain_model_C']
                assert load_path_C is not None, "Must get Pretrained Classfication prior."
                self.netR.load_model(load_path_C)
                self.netR = DataParallel(self.netR)

            # GD gan loss
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    # Only rank 0 (or non-distributed) logs skipped params.
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
Exemple #15
0
def main():
    """SpCL-style training entry point.

    Parses config, builds the (possibly distributed) model, initializes the
    hybrid memory from extracted features, runs the training runner and
    finally evaluates the best checkpoint on the test sets.
    """
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parge_config()
    dist = init_dist(cfg)
    set_random_seed(cfg.TRAIN.seed, cfg.TRAIN.deterministic)
    synchronize()

    # init logging file
    logger = Logger(cfg.work_dir / "log.txt", debug=False)
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build train loader
    train_loader, train_sets = build_train_dataloader(cfg, joint=False)

    # build model
    model = build_model(cfg, 0, init=cfg.MODEL.source_pretrained)
    model.cuda()
    if dist:
        ddp_cfg = {
            "device_ids": [cfg.gpu],
            "output_device": cfg.gpu,
            "find_unused_parameters": True,
        }
        model = DistributedDataParallel(model, **ddp_cfg)
    elif cfg.total_gpus > 1:
        model = DataParallel(model)

    # build optimizer
    optimizer = build_optimizer([model], **cfg.TRAIN.OPTIM)

    # build lr_scheduler
    if cfg.TRAIN.SCHEDULER.lr_scheduler is not None:
        lr_scheduler = build_lr_scheduler(optimizer, **cfg.TRAIN.SCHEDULER)
    else:
        lr_scheduler = None

    # build loss functions
    # (loop variable renamed from `set`, which shadowed the builtin)
    num_memory = 0
    for idx, train_set in enumerate(train_sets):
        if idx in cfg.TRAIN.unsup_dataset_indexes:
            # instance-level memory for unlabeled data
            num_memory += len(train_set)
        else:
            # class-level memory for labeled data
            num_memory += train_set.num_pids

    if isinstance(model, (DataParallel, DistributedDataParallel)):
        num_features = model.module.num_features
    else:
        num_features = model.num_features

    criterions = build_loss(
        cfg.TRAIN.LOSS,
        num_features=num_features,
        num_memory=num_memory,
        cuda=True,
    )

    # init memory
    loaders, datasets = build_val_dataloader(
        cfg, for_clustering=True, all_datasets=True
    )
    memory_features = []
    for idx, (loader, dataset) in enumerate(zip(loaders, datasets)):
        features = extract_features(
            model, loader, dataset, with_path=False, prefix="Extract: ",
        )
        assert features.size(0) == len(dataset)
        if idx in cfg.TRAIN.unsup_dataset_indexes:
            # init memory for unlabeled data with instance features
            memory_features.append(features)
        else:
            # init memory for labeled data with class centers
            centers_dict = collections.defaultdict(list)
            for i, (_, pid, _) in enumerate(dataset):
                centers_dict[pid].append(features[i].unsqueeze(0))
            centers = [
                torch.cat(centers_dict[pid], 0).mean(0)
                for pid in sorted(centers_dict.keys())
            ]
            memory_features.append(torch.stack(centers, 0))
    del loaders, datasets

    memory_features = torch.cat(memory_features)
    criterions["hybrid_memory"]._update_feature(memory_features)

    # build runner
    runner = SpCLRunner(
        cfg,
        model,
        optimizer,
        criterions,
        train_loader,
        train_sets=train_sets,
        lr_scheduler=lr_scheduler,
        meter_formats={"Time": ":.3f",},
        reset_optim=False,
    )

    # resume
    if args.resume_from:
        runner.resume(args.resume_from)

    # start training
    runner.run()

    # load the best model
    runner.resume(cfg.work_dir / "model_best.pth")

    # final testing
    test_loaders, queries, galleries = build_test_dataloader(cfg)
    for i, (loader, query, gallery) in enumerate(zip(test_loaders, queries, galleries)):
        cmc, mAP = test_reid(
            cfg, model, loader, query, gallery, dataset_name=cfg.TEST.datasets[i]
        )

    # print time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))
Exemple #16
0
    def test_ema_hook_cuda(self):
        """Exercise ExponentialMovingAverageHook on CUDA tensors.

        Covers three scenarios: a plain runner, a DataParallel-wrapped model
        whose EMA submodule is created by ``before_run``, and a hook with a
        warm-up (``start_iter``) before averaging kicks in.
        """
        ema = ExponentialMovingAverageHook(**self.default_config)
        cuda_runner = SimpleRunner()
        cuda_runner.model = cuda_runner.model.cuda()
        # First hook call: EMA buffers match the source params exactly.
        ema.after_train_iter(cuda_runner)

        module_a = cuda_runner.model.module_a
        module_a_ema = cuda_runner.model.module_a_ema

        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]).cuda())

        # Halve the tracked tensors; next hook call should blend old/new.
        module_a.b /= 2.
        module_a.a.data /= 2.
        module_a.c /= 2.

        cuda_runner.iter += 1
        ema.after_train_iter(cuda_runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(cuda_runner.model.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        # 0.75 = (1.0 + 0.5) / 2 — presumably momentum 0.5 from
        # self.default_config; confirm against the config definition.
        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]).cuda())
        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]).cuda())
        # 'c' never appears in the EMA state dict (not a tracked entry).
        assert 'c' not in ema_states

        # test before run
        ema = ExponentialMovingAverageHook(**self.default_config)
        self.runner.model = SimpleModelNoEMA().cuda()
        self.runner.model = DataParallel(self.runner.model)
        self.runner.iter = 0
        # before_run must create the missing EMA submodule on the wrapped
        # module.
        ema.before_run(self.runner)
        assert hasattr(self.runner.model.module, 'module_a_ema')

        module_a = self.runner.model.module.module_a
        module_a_ema = self.runner.model.module.module_a_ema

        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(ema_states['a'], torch.tensor([1., 2.]).cuda())

        module_a.b /= 2.
        module_a.a.data /= 2.
        module_a.c /= 2.

        self.runner.iter += 1
        ema.after_train_iter(self.runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(self.runner.model.module.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.75, 1.5]).cuda())
        assert torch.equal(ema_states['b'], torch.tensor([1., 1.5]).cuda())
        assert 'c' not in ema_states

        # test ema with simple warm up
        runner = SimpleRunner()
        runner.model = runner.model.cuda()
        cfg_ = deepcopy(self.default_config)
        cfg_.update(dict(start_iter=3, interval=1))
        ema = ExponentialMovingAverageHook(**cfg_)
        ema.before_run(runner)

        module_a = runner.model.module_a
        module_a_ema = runner.model.module_a_ema

        module_a.a.data /= 2.

        runner.iter += 1
        # Before start_iter the EMA just tracks the source value verbatim.
        ema.after_train_iter(runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(runner.model.module_a.a,
                           torch.tensor([0.5, 1.]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.5, 1.]).cuda())

        module_a.a.data /= 2
        runner.iter += 2
        # Past start_iter (iter == 3): averaging is active again.
        ema.after_train_iter(runner)
        ema_states = module_a_ema.state_dict()
        assert torch.equal(runner.model.module_a.a,
                           torch.tensor([0.25, 0.5]).cuda())
        assert torch.equal(ema_states['a'], torch.tensor([0.375, 0.75]).cuda())
Exemple #17
0
def main(args):
    """Train a semantic-segmentation model (ESPNet/MobileNet variants) on
    PASCAL VOC or Cityscapes.

    ``args`` is an argparse namespace carrying dataset, model, optimizer and
    checkpointing settings.  Side effects: creates ``args.savedir`` (and a
    ``logs/`` subdir), writes ``arguments.json``, checkpoints and TensorBoard
    scalars, and trains in place.
    """
    logdir = args.savedir + '/logs/'
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    # NOTE(review): 60066 looks like a fixed logger port/run id — confirm
    # against the Logger implementation.
    my_logger = Logger(60066, logdir)

    # Resolve dataset-specific crop size / scale range first and fail fast on
    # an unknown dataset. Previously an unsupported name crashed later with a
    # NameError on `crop_size` before the friendly error message was reached.
    if args.dataset == 'pascal':
        crop_size = (512, 512)
        args.scale = (0.5, 2.0)
    elif args.dataset == 'city':
        crop_size = (768, 768)
        args.scale = (0.5, 2.0)
    else:
        print_error_message('Dataset: {} not yet supported'.format(
            args.dataset))
        exit(-1)

    print_info_message(
        'Running Model at image resolution {}x{} with batch size {}'.format(
            crop_size[1], crop_size[0], args.batch_size))
    if not os.path.isdir(args.savedir):
        os.makedirs(args.savedir)

    if args.dataset == 'pascal':
        from data_loader.segmentation.voc import VOCSegmentation, VOC_CLASS_LIST
        train_dataset = VOCSegmentation(root=args.data_path,
                                        train=True,
                                        crop_size=crop_size,
                                        scale=args.scale,
                                        coco_root_dir=args.coco_path)
        val_dataset = VOCSegmentation(root=args.data_path,
                                      train=False,
                                      crop_size=crop_size,
                                      scale=args.scale)
        seg_classes = len(VOC_CLASS_LIST)
        class_wts = torch.ones(seg_classes)
    else:  # 'city' -- the only other value accepted by the guard above
        from data_loader.segmentation.cityscapes import CityscapesSegmentation, CITYSCAPE_CLASS_LIST
        train_dataset = CityscapesSegmentation(root=args.data_path,
                                               train=True,
                                               size=crop_size,
                                               scale=args.scale,
                                               coarse=args.coarse)
        val_dataset = CityscapesSegmentation(root=args.data_path,
                                             train=False,
                                             size=crop_size,
                                             scale=args.scale,
                                             coarse=False)
        seg_classes = len(CITYSCAPE_CLASS_LIST)
        # Pre-computed class-balancing weights; the last entry (index 19,
        # presumably the ignore/void class) is weighted 0.
        city_class_wts = [
            2.8149201869965, 6.9850029945374, 3.7890393733978,
            9.9428062438965, 9.7702074050903, 9.5110931396484,
            10.311357498169, 10.026463508606, 4.6323022842407,
            9.5608062744141, 7.8698215484619, 9.5168733596802,
            10.373730659485, 6.6616044044495, 10.260489463806,
            10.287888526917, 10.289801597595, 10.405355453491,
            10.138095855713, 0.0
        ]
        class_wts = torch.ones(seg_classes)
        class_wts[:len(city_class_wts)] = torch.tensor(city_class_wts)

    print_info_message('Training samples: {}'.format(len(train_dataset)))
    print_info_message('Validation samples: {}'.format(len(val_dataset)))

    # Lazily import only the architecture that was requested.
    if args.model == 'espnetv2':
        from model.espnetv2 import espnetv2_seg
        args.classes = seg_classes
        model = espnetv2_seg(args)
    elif args.model == 'espnet':
        from model.espnet import espnet_seg
        args.classes = seg_classes
        model = espnet_seg(args)
    elif args.model == 'mobilenetv2_1_0':
        from model.mobilenetv2 import get_mobilenet_v2_1_0_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_1_0_seg(args)
    elif args.model == 'mobilenetv2_0_35':
        from model.mobilenetv2 import get_mobilenet_v2_0_35_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_35_seg(args)
    elif args.model == 'mobilenetv2_0_5':
        from model.mobilenetv2 import get_mobilenet_v2_0_5_seg
        args.classes = seg_classes
        model = get_mobilenet_v2_0_5_seg(args)
    elif args.model == 'mobilenetv3_small':
        from model.mobilenetv3 import get_mobilenet_v3_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_small_seg(args)
    elif args.model == 'mobilenetv3_large':
        from model.mobilenetv3 import get_mobilenet_v3_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_large_seg(args)
    elif args.model == 'mobilenetv3_RE_small':
        from model.mobilenetv3 import get_mobilenet_v3_RE_small_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_small_seg(args)
    elif args.model == 'mobilenetv3_RE_large':
        from model.mobilenetv3 import get_mobilenet_v3_RE_large_seg
        args.classes = seg_classes
        model = get_mobilenet_v3_RE_large_seg(args)
    else:
        print_error_message('Arch: {} not yet supported'.format(args.model))
        exit(-1)

    num_gpus = torch.cuda.device_count()
    device = 'cuda' if num_gpus > 0 else 'cpu'

    # Build per-parameter groups: 1-input-channel conv weights (presumably
    # depthwise convolutions) get no weight decay, other conv weights get the
    # full decay, and everything else (biases, BN params, ...) gets 1% of it.
    train_params = []
    params_dict = dict(model.named_parameters())
    others = args.weight_decay * 0.01
    for key, value in params_dict.items():
        if len(value.data.shape) == 4:
            if value.data.shape[1] == 1:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': 0.0
                }]
            else:
                train_params += [{
                    'params': [value],
                    'lr': args.lr,
                    'weight_decay': args.weight_decay
                }]
        else:
            train_params += [{
                'params': [value],
                'lr': args.lr,
                'weight_decay': others
            }]

    args.learning_rate = args.lr
    optimizer = get_optimizer(args.optimizer, train_params, args)
    num_params = model_parameters(model)
    flops = compute_flops(model,
                          input=torch.Tensor(1, 3, crop_size[1], crop_size[0]))
    print_info_message(
        'FLOPs for an input of size {}x{}: {:.2f} million'.format(
            crop_size[1], crop_size[0], flops))
    print_info_message('Network Parameters: {:.2f} million'.format(num_params))

    start_epoch = 0
    best_miou = 0.0

    criterion = SegmentationLoss(n_classes=seg_classes,
                                 loss_type=args.loss_type,
                                 device=device,
                                 ignore_idx=args.ignore_idx,
                                 class_wts=class_wts.to(device))

    if num_gpus >= 1:
        if num_gpus == 1:
            # for a single GPU, we do not need DataParallel wrapper for Criteria.
            # So, falling back to its internal wrapper
            from torch.nn.parallel import DataParallel
            model = DataParallel(model)
            model = model.cuda()
            criterion = criterion.cuda()
        else:
            from utilities.parallel_wrapper import DataParallelModel, DataParallelCriteria
            model = DataParallelModel(model)
            model = model.cuda()
            criterion = DataParallelCriteria(criterion)
            criterion = criterion.cuda()

        if torch.backends.cudnn.is_available():
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            cudnn.deterministic = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=args.workers,
                                               drop_last=True)
    # Cityscapes validates at batch size 1 (full-resolution images).
    if args.dataset == 'city':
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=1,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)
    else:
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 num_workers=args.workers,
                                                 drop_last=True)

    lr_scheduler = get_lr_scheduler(args)

    print_info_message(lr_scheduler)

    # Persist the effective configuration next to the checkpoints.
    with open(args.savedir + os.sep + 'arguments.json', 'w') as outfile:
        import json
        arg_dict = vars(args)
        arg_dict['model_params'] = '{} '.format(num_params)
        arg_dict['flops'] = '{} '.format(flops)
        json.dump(arg_dict, outfile)

    extra_info_ckpt = '{}_{}_{}'.format(args.model, args.s, crop_size[0])

    # Optional full-precision warm-up epochs before quantization-aware
    # training; validation and checkpointing are skipped here on purpose.
    if args.fp_epochs > 0:
        print_info_message("========== MODEL FP WARMUP ===========")

        for epoch in range(args.fp_epochs):
            lr = lr_scheduler.step(epoch)

            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            print_info_message(
                'Running epoch {} with learning rates: {:.6f}'.format(
                    epoch, lr))
            train(model,
                  train_loader,
                  optimizer,
                  criterion,
                  seg_classes,
                  epoch,
                  device=device)
    if args.optimizer.startswith('Q'):
        optimizer.is_warmup = False
        print('exp_sensitivity calibration fin.')

    # Prepare the quantized sub-network for quantization-aware training
    # unless pure full-precision training was requested.
    if not args.fp_train:
        model.module.quantized.fuse_model()
        model.module.quantized.qconfig = torch.quantization.get_default_qat_qconfig(
            'qnnpack')
        torch.quantization.prepare_qat(model.module.quantized, inplace=True)

    if args.resume:
        start_epoch = args.start_epoch
        if os.path.isfile(args.resume):
            print_info_message('Loading weights from {}'.format(args.resume))
            weight_dict = torch.load(args.resume, device)
            model.module.load_state_dict(weight_dict)
            print_info_message('Done')
        else:
            print_warning_message('No file for resume. Please check.')

    for epoch in range(start_epoch, args.epochs):
        lr = lr_scheduler.step(epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        print_info_message(
            'Running epoch {} with learning rates: {:.6f}'.format(epoch, lr))
        miou_train, train_loss = train(model,
                                       train_loader,
                                       optimizer,
                                       criterion,
                                       seg_classes,
                                       epoch,
                                       device=device)
        miou_val, val_loss = val(model,
                                 val_loader,
                                 criterion,
                                 seg_classes,
                                 device=device)

        # remember best miou and save checkpoint
        is_best = miou_val > best_miou
        best_miou = max(miou_val, best_miou)

        # On GPU the model is wrapped in (Data)Parallel, so unwrap `.module`.
        weights_dict = model.module.state_dict(
        ) if device == 'cuda' else model.state_dict()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': weights_dict,
                'best_miou': best_miou,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.savedir, extra_info_ckpt)
        if is_best:
            model_file_name = args.savedir + '/model_' + str(epoch +
                                                             1) + '.pth'
            torch.save(weights_dict, model_file_name)
            print('weights saved in {}'.format(model_file_name))
        # NOTE(review): the Flops/Params entries log best_miou as the value
        # with flops/params as the step, presumably to plot accuracy vs.
        # complexity — confirm this is intentional.
        info = {
            'Segmentation/LR': round(lr, 6),
            'Segmentation/Loss/train': train_loss,
            'Segmentation/Loss/val': val_loss,
            'Segmentation/mIOU/train': miou_train,
            'Segmentation/mIOU/val': miou_val,
            'Segmentation/Complexity/Flops': best_miou,
            'Segmentation/Complexity/Params': best_miou,
        }

        for tag, value in info.items():
            if tag == 'Segmentation/Complexity/Flops':
                my_logger.scalar_summary(tag, value, math.ceil(flops))
            elif tag == 'Segmentation/Complexity/Params':
                my_logger.scalar_summary(tag, value, math.ceil(num_params))
            else:
                my_logger.scalar_summary(tag, value, epoch + 1)

    print_info_message("========== TRAINING FINISHED ===========")
Exemple #18
0
    def __init__(self, opt):
        """Build the FIRN model from the config dict ``opt``.

        Creates the generator ``self.netG`` (wrapped for distributed or
        single-node data parallelism), loads pretrained weights, and — in
        training mode — sets up reconstruction losses, the Adam optimizer
        and the learning-rate schedulers.
        """
        super(FIRNModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non-distributed training; rank<=0 also gates logging below
        train_opt = opt['train']
        test_opt = opt['test']
        self.train_opt = train_opt
        self.test_opt = test_opt

        # Generator: move to device before wrapping in a parallel container.
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        # Quantization module applied between forward/backward passes.
        self.Quantization = Quantization()

        if self.is_train:
            self.netG.train()

            # Pixel reconstruction losses for the forward and backward passes;
            # loss type (e.g. l1/l2) comes from the training config.
            self.Reconstruction_forw = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_forw'])
            self.Reconstruction_back = ReconstructionLoss(
                self.device, losstype=self.train_opt['pixel_criterion_back'])

            # Optimizer: only parameters with requires_grad are optimized;
            # frozen ones are logged (rank 0 / non-dist only).
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'],
                                                       train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # Schedulers: one per optimizer; only the two restart-capable
            # schemes are supported.
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            # Per-iteration training log values accumulate here.
            self.log_dict = OrderedDict()
def main():
    """Extract features with a trained re-id model and save ranked-retrieval
    visualizations for a fixed, reproducible set of query images."""
    cfg = Config()

    # Mirror stdout/stderr into log files when configured to do so.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    TVT, TMO = set_devices(cfg.sys_device_ids)

    # Echo the full configuration into the log.
    import pprint
    sep = '-' * 60
    print(sep)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print(sep)

    # Dataset to extract features from.
    test_set = create_dataset(**cfg.test_set_kwargs)

    # Model (no classifier head: num_classes=0) plus its parallel wrapper.
    net = Model(last_conv_stride=cfg.last_conv_stride,
                num_stripes=cfg.num_stripes,
                local_conv_out_channels=cfg.local_conv_out_channels,
                num_classes=0)
    net_w = DataParallel(net)

    # May Transfer Model to Specified Device.
    TMO([net])

    # Load weights to CPU first; prefer an explicit weight file over the
    # checkpoint, and unwrap the checkpoint's state-dict list if used.
    to_cpu = (lambda storage, loc: storage)
    weight_path = cfg.model_weight_file or cfg.ckpt_file
    state = torch.load(weight_path, map_location=to_cpu)
    if cfg.model_weight_file == '':
        state = state['state_dicts'][0]
    load_state_dict(net, state)
    print('Loaded model weights from {}'.format(weight_path))

    # Extract features for the whole test set.
    test_set.set_feat_func(ExtractFeature(net_w, TVT))

    with measure_time('Extracting feature...', verbose=True):
        feat, ids, cams, im_names, marks = test_set.extract_feat(True,
                                                                 verbose=True)

    # Sort everything by image name so the query selection (and thus the
    # visualization) is comparable across different models.
    order = np.argsort(im_names)
    feat = feat[order]
    ids = ids[order]
    cams = cams[order]
    im_names = im_names[order]
    marks = marks[order]

    # Query/gallery membership masks.
    query_mask = marks == 0
    gallery_mask = marks == 1

    # Fixed-seed RNG keeps the selected queries identical between runs.
    rng = np.random.RandomState(1)
    picked = rng.permutation(range(np.sum(query_mask)))[:cfg.num_queries]

    q_ids = ids[query_mask][picked]
    q_cams = cams[query_mask][picked]
    q_feat = feat[query_mask][picked]
    q_im_names = im_names[query_mask][picked]

    # Pairwise query-to-gallery distance matrix.
    q_g_dist = compute_dist(q_feat, feat[gallery_mask], type='euclidean')

    # Resolve image paths and per-query output paths.
    q_im_paths = [ospj(test_set.im_dir, n) for n in q_im_names]
    save_paths = [ospj(cfg.exp_dir, 'rank_lists', n) for n in q_im_names]
    g_im_paths = [ospj(test_set.im_dir, n) for n in im_names[gallery_mask]]

    # Render one rank-list image per query.
    for dists, qid, qcam, qpath, out_path in zip(q_g_dist, q_ids, q_cams,
                                                 q_im_paths, save_paths):

        rank_list, same_id = get_rank_list(dists, qid, qcam,
                                           ids[gallery_mask],
                                           cams[gallery_mask],
                                           cfg.rank_list_size)

        save_rank_list_to_im(rank_list, same_id, qpath, g_im_paths, out_path)
Exemple #20
0
    def __init__(self, opt):
        """Build the LR-image estimator model from the config dict ``opt``.

        Creates the estimator network ``self.netE`` (distributed or
        single-node data-parallel), loads pretrained weights, selects the
        pixel loss, and — in training mode — sets up the Adam optimizer and
        MultiStepLR schedulers.
        """
        super(LRimgestimator_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non-distributed training
        train_opt = opt['train']
        self.train_opt = train_opt
        # Training-data geometry pulled from the dataset config.
        self.kernel_size = opt['datasets']['train']['kernel_size']
        self.patch_size = opt['datasets']['train']['patch_size']
        self.batch_size = opt['datasets']['train']['batch_size']

        # define networks and load pretrained models
        self.scale = opt['scale']
        self.model_name = opt['network_E']['which_model_E']
        self.mode = opt['network_E']['mode']

        self.netE = networks.define_E(opt).to(self.device)
        if opt['dist']:
            self.netE = DistributedDataParallel(
                self.netE, device_ids=[torch.cuda.current_device()])
        else:
            self.netE = DataParallel(self.netE)
        self.load()

        # Pixel loss: l1 or l2; any other value disables the loss entirely.
        if train_opt['loss_ftn'] == 'l1':
            self.MyLoss = nn.L1Loss(reduction='mean').to(self.device)
        elif train_opt['loss_ftn'] == 'l2':
            self.MyLoss = nn.MSELoss(reduction='mean').to(self.device)
        else:
            self.MyLoss = None

        if self.is_train:
            self.netE.train()

            # Optimizer: only parameters with requires_grad are optimized.
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt[
                'weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netE.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not optimize.' % k)
            self.optimizer_E = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_C'],
                                                weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_E)

            # Schedulers: only the MultiStepLR scheme is supported.
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
                                                                    train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            # Per-iteration training log values accumulate here.
            self.log_dict = OrderedDict()
Exemple #21
0
def main():
    """Train (or only test) an SCP re-id model end to end.

    Driven entirely by ``Config()``: builds datasets, loads ImageNet
    pretrained ResNet-50 weights, trains with a hand-rolled LR schedule,
    periodically validates, logs to TensorBoard and checkpoints, then runs a
    final test.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    # NOTE(review): train_set/val_set exist only when not only_test; the
    # training loop below relies on that invariant.
    if not cfg.only_test:
        train_set = create_dataset(**cfg.train_set_kwargs)
        # The combined dataset does not provide val set currently.
        val_set = None if cfg.dataset == 'combined' else create_dataset(
            **cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    # Number of identity classes per dataset.
    # NOTE(review): an unknown cfg.dataset leaves nr_class unbound and would
    # raise NameError below — confirm Config restricts the allowed values.
    # model = Model(last_conv_stride=cfg.last_conv_stride)
    if cfg.dataset == 'market1501':
        nr_class = 751
    elif cfg.dataset == 'duke':
        nr_class = 702
    elif cfg.dataset == 'cuhk03':
        nr_class = 767
    elif cfg.dataset == 'combined':
        nr_class = 2220
    model = get_scp_model(nr_class)

    # load pretrained ImageNet weights; keep only tensors whose name and
    # shape both match the target model.
    model_dict = model.state_dict()  # original
    pretrained_dict = torch.load('models/resnet50-19c8e357.pth')  # pretrained
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    # Model wrapper
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    criterion_cls = torch.nn.CrossEntropyLoss()
    criterion_feature = torch.nn.MSELoss()

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    # May Transfer Models and Optims to Specified Device. Transferring optimizer
    # is to cope with the case when you load the checkpoint to a new device.
    TMO(modules_optims)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Evaluate on the configured test set(s), optionally reloading
        weights first.  NOTE(review): the return inside the loop means only
        the FIRST test set's scores are returned — confirm intended."""
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            mAP, cmc_scores, mq_mAP, mq_cmc_scores = test_set.eval(
                normalize_feat=cfg.normalize_feature, verbose=True)
            return mAP, cmc_scores, mq_mAP, mq_cmc_scores

    def validate():
        """Evaluate on the validation set; returns (mAP, Rank-1).
        Requires val_set to be non-None (guarded at the call site)."""
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test on validation set <=========\n')
        mAP, cmc_scores, _, _ = val_set.eval(
            normalize_feat=cfg.normalize_feature,
            to_re_rank=False,
            verbose=False)
        print()
        return mAP, cmc_scores[0]

    if cfg.only_test:
        mAP, cmc_scores, mq_mAP, mq_cmc_scores = test(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        # if cfg.lr_decay_type == 'exp':
        #     adjust_lr_exp(
        #         optimizer,
        #         cfg.base_lr,
        #         ep + 1,
        #         cfg.total_epochs,
        #         cfg.exp_decay_at_epoch)
        # else:
        #     adjust_lr_staircase(
        #         optimizer,
        #         cfg.base_lr,
        #         ep + 1,
        #         cfg.staircase_decay_at_epochs,
        #         cfg.staircase_decay_multiply_factor)

        # Hand-rolled schedule: linear warm-up for 20 epochs, then staircase
        # decay 1e-3 -> 1e-4 -> 1e-5.
        if ep < 20:
            lr = 1e-4 * (ep + 1) / 2
        elif ep < 80:
            lr = 1e-3
        elif 80 <= ep <= 180:
            lr = 1e-4
        else:
            lr = 1e-5
        for g in optimizer.param_groups:
            g['lr'] = lr

        may_set_mode(modules_optims, 'train')

        # For recording precision, satisfying margin, etc
        # NOTE(review): sm_meter, dist_ap_meter and dist_an_meter are never
        # updated in this loop, so their logged values stay 0 — confirm.
        prec_meter = AverageMeter()
        sm_meter = AverageMeter()
        dist_ap_meter = AverageMeter()
        dist_an_meter = AverageMeter()
        loss_meter = AverageMeter()

        ep_st = time.time()
        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1
            step_st = time.time()

            # Dataset signals end-of-epoch via epoch_done.
            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = Variable(TVT(torch.from_numpy(labels).long()))

            feat = model_w(ims_var)

            # Combined classification + feature loss over the batch.
            loss, prec = scp_loss(feat, labels_t, criterion_cls,
                                  criterion_feature, cfg.ids_per_batch,
                                  cfg.ims_per_id)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ############
            # Step Log #
            ############

            prec_meter.update(prec)
            loss_meter.update(to_scalar(loss))

            if step % cfg.steps_per_log == 0:
                time_log = '\tStep {}/Ep {}, {:.2f}s'.format(
                    step,
                    ep + 1,
                    time.time() - step_st,
                )

                tri_log = (', prec {:.2%}, sm {:.2%}, '
                           'd_ap {:.4f}, d_an {:.4f}, '
                           'loss {:.4f}'.format(
                               prec_meter.val,
                               sm_meter.val,
                               dist_ap_meter.val,
                               dist_an_meter.val,
                               loss_meter.val,
                           ))

                log = time_log + tri_log
                print(log)

        #############
        # Epoch Log #
        #############

        time_log = 'Ep {}, {:.2f}s'.format(ep + 1, time.time() - ep_st)

        tri_log = (', prec {:.2%}, sm {:.2%}, '
                   'd_ap {:.4f}, d_an {:.4f}, '
                   'loss {:.4f}'.format(
                       prec_meter.avg,
                       sm_meter.avg,
                       dist_ap_meter.avg,
                       dist_an_meter.avg,
                       loss_meter.avg,
                   ))

        log = time_log + tri_log
        print(log)

        ##########################
        # Test on Validation Set #
        ##########################

        mAP, Rank1 = 0, 0
        if ((ep + 1) % cfg.epochs_per_val == 0) and (val_set is not None):
            mAP, Rank1 = validate()

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('val scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)
            writer.add_scalars('precision', dict(precision=prec_meter.avg, ),
                               ep)
            writer.add_scalars('satisfy_margin',
                               dict(satisfy_margin=sm_meter.avg, ), ep)
            writer.add_scalars(
                'average_distance',
                dict(
                    dist_ap=dist_ap_meter.avg,
                    dist_an=dist_an_meter.avg,
                ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test(load_model_weight=False)
Exemple #22
0
def main():
    """Entry point: build CIFAR-10 loaders, a resnet_cifar model and an
    mmcv-style Runner, then train according to the parsed config file.

    Supports both single-node DataParallel and launcher-based distributed
    training (selected via ``--launcher``).
    """
    cli_args = parse_args()
    cfg = Config.fromfile(cli_args.config)
    logger = get_logger(cfg.log_level)

    # Set up the distributed environment when a launcher is requested.
    use_dist = cli_args.launcher != 'none'
    if use_dist:
        init_dist(**cfg.dist_params)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        # Keep verbose logging on rank 0 only.
        if rank != 0:
            logger.setLevel('ERROR')
        logger.info('Enabled distributed training.')
    else:
        logger.info('Disabled distributed training.')

    # Datasets: standard CIFAR-10 augmentation for train, plain for val.
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = datasets.CIFAR10(root=cfg.data_root,
                                     train=True,
                                     transform=train_transform)
    val_dataset = datasets.CIFAR10(root=cfg.data_root,
                                   train=False,
                                   transform=val_transform)

    # Per-process batch size and samplers differ between the two modes.
    if use_dist:
        num_workers = cfg.data_workers
        assert cfg.batch_size % world_size == 0
        batch_size = cfg.batch_size // world_size
        train_sampler = DistributedSampler(train_dataset, world_size, rank)
        val_sampler = DistributedSampler(val_dataset, world_size, rank)
        shuffle = False  # the sampler already shuffles per epoch
    else:
        num_workers = cfg.data_workers * len(cfg.gpus)
        batch_size = cfg.batch_size
        train_sampler = None
        val_sampler = None
        shuffle = True

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            sampler=val_sampler,
                            num_workers=num_workers)

    # Build the model named in the config and wrap it for (multi-)GPU use.
    model = getattr(resnet_cifar, cfg.model)()
    if use_dist:
        model = DistributedDataParallel(
            model.cuda(), device_ids=[torch.cuda.current_device()])
    else:
        model = DataParallel(model, device_ids=cfg.gpus).cuda()

    # Build the runner and register the standard training hooks.
    runner = Runner(model,
                    batch_processor,
                    cfg.optimizer,
                    cfg.work_dir,
                    log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config,
        custom_hooks_config=cfg.get('custom_train_hooks', None))
    if use_dist:
        # Re-seed the DistributedSampler each epoch for proper shuffling.
        runner.register_hook(DistSamplerSeedHook())

    # Resuming (weights + optimizer state) takes precedence over plain
    # weight loading.
    if cfg.get('resume_from') is not None:
        runner.resume(cfg.resume_from)
    elif cfg.get('load_from') is not None:
        runner.load_checkpoint(cfg.load_from)

    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)
Exemple #23
0
    def __init__(self, opt):
        """Build the video-SR generator, its pixel loss, optimizer and LR
        schedulers from the ``opt`` config dict.

        Reads ``opt['dist']`` to decide between DistributedDataParallel and
        DataParallel, and ``opt['train']`` for criterion / optimizer /
        scheduler settings.  Losses and optimizers are only created when
        ``self.is_train`` is set (by the base class).
        """
        super(VideoSRBaseModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)

        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            #### loss
            # Pixel criterion: l1/l2 use sum reduction; 'cb' = Charbonnier,
            # 'lp' = Laplacian-pyramid loss.
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            elif loss_type == 'lp':
                self.cri_pix = LapLoss(max_levels=5).to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            #### optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            # Only parameters with requires_grad are optimized; frozen ones
            # are reported once (rank 0 / non-dist only).
            optim_params = []
            for k, v in self.netG.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))

            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)
            #### schedulers
            # One scheduler per registered optimizer; both variants support
            # warm restarts with per-restart weights.
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError()

            # Per-iteration scalar log values are collected here.
            self.log_dict = OrderedDict()
Exemple #24
0
def main():
    """Train a person re-id model with the SN loss.

    Builds datasets/model/optimizer from ``Config``, optionally resumes
    from a checkpoint, then runs the epoch loop with periodic testing,
    TensorBoard logging and checkpointing.  With ``cfg.only_test`` set it
    only evaluates pre-trained weights and returns.
    """
    cfg = Config()

    # Redirect logs to both console and file.
    if cfg.log_to_file:
        ReDirectSTD(cfg.stdout_file, 'stdout', False)
        ReDirectSTD(cfg.stderr_file, 'stderr', False)

    # Lazily create SummaryWriter (so no event dir appears unless used).
    writer = None

    TVT, TMO = set_devices(cfg.sys_device_ids)

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Dump the configurations to log.
    import pprint
    print('-' * 60)
    print('cfg.__dict__')
    pprint.pprint(cfg.__dict__)
    print('-' * 60)

    ###########
    # Dataset #
    ###########

    if not cfg.only_test:
        train_set = create_dataset(**cfg.train_set_kwargs)
        val_set = create_dataset(**cfg.val_set_kwargs)

    test_sets = []
    test_set_names = []
    if cfg.dataset == 'combined':
        for name in ['market1501', 'cuhk03', 'duke']:
            cfg.test_set_kwargs['name'] = name
            test_sets.append(create_dataset(**cfg.test_set_kwargs))
            test_set_names.append(name)
    else:
        test_sets.append(create_dataset(**cfg.test_set_kwargs))
        test_set_names.append(cfg.dataset)

    ###########
    # Models  #
    ###########

    model = Model(last_conv_stride=cfg.last_conv_stride)
    # Model wrapper for (multi-)GPU execution.
    model_w = DataParallel(model)

    #############################
    # Criteria and Optimizers   #
    #############################

    optimizer = optim.Adam(model.parameters(),
                           lr=cfg.base_lr,
                           weight_decay=cfg.weight_decay)

    # Bind them together just to save some codes in the following usage.
    modules_optims = [model, optimizer]

    ################################
    # May Resume Models and Optims #
    ################################

    if cfg.resume:
        resume_ep, scores = load_ckpt(modules_optims, cfg.ckpt_file)

    ########
    # Test #
    ########

    def test(load_model_weight=False):
        """Quick eval; returns (mAP, rank-1) of the FIRST test set only.

        NOTE(review): the `return` sits inside the loop, so for the
        'combined' dataset only the first set is evaluated — presumably
        intentional for the periodic in-training check; confirm.
        """
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            mAP, cmc_scores = test_set.eval_simple(
                normalize_feat=cfg.normalize_feature, verbose=True)
            return mAP, cmc_scores[0]

    def test_full(load_model_weight=False):
        """Full evaluation on every configured test set (no return value)."""
        if load_model_weight:
            if cfg.model_weight_file != '':
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                load_ckpt(modules_optims, cfg.ckpt_file)

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature, verbose=True)

    def validate():
        """Evaluate on the validation set; returns (mAP, rank-1)."""
        if val_set.extract_feat_func is None:
            val_set.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test on validation set <=========\n')
        mAP, cmc_scores, _, _ = val_set.eval(
            normalize_feat=cfg.normalize_feature,
            to_re_rank=False,
            verbose=False)
        print()
        return mAP, cmc_scores[0]

    def normalize(x, axis=-1):
        """L2-normalize `x` along `axis` (epsilon guards zero vectors)."""
        x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) +
                      1e-12)
        return x

    if cfg.only_test:
        test_full(load_model_weight=True)
        return

    ############
    # Training #
    ############

    start_ep = resume_ep if cfg.resume else 0
    for ep in range(start_ep, cfg.total_epochs):

        # Adjust Learning Rate
        if cfg.lr_decay_type == 'exp':
            adjust_lr_exp(optimizer, cfg.base_lr, ep + 1, cfg.total_epochs,
                          cfg.exp_decay_at_epoch)
        else:
            adjust_lr_staircase(optimizer, cfg.base_lr, ep + 1,
                                cfg.staircase_decay_at_epochs,
                                cfg.staircase_decay_multiply_factor)

        may_set_mode(modules_optims, 'train')

        # Only the loss is tracked per epoch (the precision / margin /
        # distance meters were never updated and have been removed).
        loss_meter = AverageMeter()

        criterion = SN_LOSS(alpha=cfg.SN_alpha, k=cfg.SN_k,
                            weight=cfg.SN_w).cuda()

        step = 0
        epoch_done = False
        while not epoch_done:

            step += 1

            ims, im_names, labels, mirrored, epoch_done = train_set.next_batch(
            )

            # NOTE(review): batch size 100 is hard-coded here — incomplete
            # batches are silently skipped.  TODO: read from cfg instead.
            if epoch_done or ims.shape[0] != 100:
                continue

            ims_var = Variable(TVT(torch.from_numpy(ims).float()))
            labels_t = TVT(torch.from_numpy(labels).long())
            feat = model_w(ims_var)

            if cfg.normalize_feature:
                feat = normalize(feat, axis=-1)

            loss, inter_, dist_ap, dist_an = criterion(
                feat.cuda(), Variable(labels_t.cuda()))

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            if (step + 1) % 10 == 0:
                # BUGFIX: `loss.data[0]` raises IndexError on 0-dim loss
                # tensors in PyTorch >= 0.4; use `loss.item()` instead.
                print(
                    '[Epoch %05d: step %05d]\t Loss: %.6f \t Accuracy: %.3f \t Pos-Dist: %.3f \t Neg-Dist: %.3f'
                    %
                    (ep + 1, step + 1, loss.item(), inter_, dist_ap, dist_an))

            loss_meter.update(to_scalar(loss))

        #############
        # Epoch Log #
        #############

        mAP, Rank1 = 0, 0
        if (ep + 1) % 50 == 0:
            mAP, Rank1 = test(load_model_weight=False)

        # Log to TensorBoard

        if cfg.log_to_file:
            if writer is None:
                writer = SummaryWriter(
                    log_dir=osp.join(cfg.exp_dir, 'tensorboard'))
            writer.add_scalars('test scores', dict(mAP=mAP, Rank1=Rank1), ep)
            writer.add_scalars('loss', dict(loss=loss_meter.avg, ), ep)

        # save ckpt
        if cfg.log_to_file:
            save_ckpt(modules_optims, ep + 1, 0, cfg.ckpt_file)

    ########
    # Test #
    ########

    test_full(load_model_weight=False)
Exemple #25
0
 def begin_fit(self):
     # Before fitting starts, wrap the learner's model in DataParallel so
     # the forward pass is replicated across `self.device_ids` GPUs.
     self.learn.model = DataParallel(self.learn.model,
                                     device_ids=self.device_ids)
Exemple #26
0
    def __init__(self, opt):
        """Build the SRGAN generator/discriminator pair with all losses,
        optimizers and LR schedulers described by the ``opt`` config dict.

        Reads ``opt['dist']`` (DistributedDataParallel vs DataParallel),
        ``opt['train']`` (loss weights/criteria, optimizer and scheduler
        settings) and optionally creates a video discriminator when
        ``train_opt['gan_video_weight'] > 0``.  Training-only members
        (netD, losses, optimizers, schedulers) exist only if
        ``self.is_train`` is set by the base class.
        """
        super(SRGANModel, self).__init__(opt)
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']
        self.train_opt = train_opt
        self.opt = opt

        self.segmentor = None

        # define networks and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(
                self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        if self.is_train:
            # Discriminator (and, optionally, a second video discriminator).
            self.netD = networks.define_D(opt).to(self.device)
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D = networks.define_video_D(opt).to(self.device)
            if opt['dist']:
                self.netD = DistributedDataParallel(
                    self.netD, device_ids=[torch.cuda.current_device()])
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DistributedDataParallel(
                        self.net_video_D,
                        device_ids=[torch.cuda.current_device()])
            else:
                self.netD = DataParallel(self.netD)
                if train_opt.get("gan_video_weight", 0) > 0:
                    self.net_video_D = DataParallel(self.net_video_D)

            self.netG.train()
            self.netD.train()
            if train_opt.get("gan_video_weight", 0) > 0:
                self.net_video_D.train()

        # define losses, optimizer and scheduler
        if self.is_train:
            # G pixel loss (weight 0 disables it).
            if train_opt['pixel_weight'] > 0:
                l_pix_type = train_opt['pixel_criterion']
                if l_pix_type == 'l1':
                    self.cri_pix = nn.L1Loss().to(self.device)
                elif l_pix_type == 'l2':
                    self.cri_pix = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_pix_type))
                self.l_pix_w = train_opt['pixel_weight']
            else:
                logger.info('Remove pixel loss.')
                self.cri_pix = None

            # Pixel mask loss (segmentation-mask-weighted pixel loss).
            if train_opt.get("pixel_mask_weight", 0) > 0:
                l_pix_type = train_opt['pixel_mask_criterion']
                self.cri_pix_mask = LMaskLoss(
                    l_pix_type=l_pix_type,
                    segm_mask=train_opt['segm_mask']).to(self.device)
                self.l_pix_mask_w = train_opt['pixel_mask_weight']
            else:
                logger.info('Remove pixel mask loss.')
                self.cri_pix_mask = None

            # G feature (perceptual) loss on VGG features.
            if train_opt['feature_weight'] > 0:
                l_fea_type = train_opt['feature_criterion']
                if l_fea_type == 'l1':
                    self.cri_fea = nn.L1Loss().to(self.device)
                elif l_fea_type == 'l2':
                    self.cri_fea = nn.MSELoss().to(self.device)
                else:
                    raise NotImplementedError(
                        'Loss type [{:s}] not recognized.'.format(l_fea_type))
                self.l_fea_w = train_opt['feature_weight']
            else:
                logger.info('Remove feature loss.')
                self.cri_fea = None
            if self.cri_fea:  # load VGG perceptual loss
                self.netF = networks.define_F(opt,
                                              use_bn=False).to(self.device)
                if opt['dist']:
                    self.netF = DistributedDataParallel(
                        self.netF, device_ids=[torch.cuda.current_device()])
                else:
                    self.netF = DataParallel(self.netF)

            # GD gan loss (real label 1.0, fake label 0.0).
            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0,
                                   0.0).to(self.device)
            self.l_gan_w = train_opt['gan_weight']
            # Video gan weight
            if train_opt.get("gan_video_weight", 0) > 0:
                self.cri_video_gan = GANLoss(train_opt['gan_video_type'], 1.0,
                                             0.0).to(self.device)
                self.l_gan_video_w = train_opt['gan_video_weight']

                # can't use optical flow with i and i+1 because we need i+2 lr to calculate i+1 oflow
                if 'train' in self.opt['datasets'].keys():
                    key = "train"
                else:
                    key = 'test_1'
                assert self.opt['datasets'][key][
                    'optical_flow_with_ref'] == True, f"Current value = {self.opt['datasets'][key]['optical_flow_with_ref']}"
            # D_update_ratio and D_init_iters
            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt[
                'D_update_ratio'] else 1
            self.D_init_iters = train_opt['D_init_iters'] if train_opt[
                'D_init_iters'] else 0

            # optimizers
            # G
            wd_G = train_opt['weight_decay_G'] if train_opt[
                'weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters(
            ):  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    # Frozen params are reported once (rank 0 / non-dist).
                    if self.rank <= 0:
                        logger.warning(
                            'Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params,
                                                lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1_G'],
                                                       train_opt['beta2_G']))
            self.optimizers.append(self.optimizer_G)
            # D
            wd_D = train_opt['weight_decay_D'] if train_opt[
                'weight_decay_D'] else 0
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=train_opt['lr_D'],
                                                weight_decay=wd_D,
                                                betas=(train_opt['beta1_D'],
                                                       train_opt['beta2_D']))
            self.optimizers.append(self.optimizer_D)

            # Video D (shares the D hyper-parameters).
            if train_opt.get("gan_video_weight", 0) > 0:
                self.optimizer_video_D = torch.optim.Adam(
                    self.net_video_D.parameters(),
                    lr=train_opt['lr_D'],
                    weight_decay=wd_D,
                    betas=(train_opt['beta1_D'], train_opt['beta2_D']))
                self.optimizers.append(self.optimizer_video_D)

            # schedulers: one per registered optimizer.
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(
                            optimizer,
                            train_opt['lr_steps'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights'],
                            gamma=train_opt['lr_gamma'],
                            clear_state=train_opt['clear_state']))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer,
                            train_opt['T_period'],
                            eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'],
                            weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError(
                    'MultiStepLR learning rate scheme is enough.')

            # Per-iteration scalar log values are collected here.
            self.log_dict = OrderedDict()

        self.print_network()  # print network
        self.load()  # load G and D if needed
Exemple #27
0
def get_model_wrapper(model, multi_gpu):
    """Return ``model`` wrapped in ``DataParallel`` when ``multi_gpu`` is
    truthy; otherwise return the model itself, unchanged."""
    from torch.nn.parallel import DataParallel
    return DataParallel(model) if multi_gpu else model
Exemple #28
0
def run_experiment(config: ExperimentConfig, save_model=True):  # pylint: disable=too-many-statements, too-many-branches, too-many-locals
    """Run one full train/validate/test experiment described by ``config``.

    Builds the dataset splits and loaders, instantiates and (optionally
    xavier-)initializes the model, trains with checkpointing on the best
    validation metric, evaluates on the test split, and optionally uploads
    the best model + split data as a wandb artifact.

    Args:
        config: experiment description (dataset source, model class and
            args, optimizer/scheduler, logging destinations, ...).
        save_model: when True, save the best model as a wandb artifact.

    Returns:
        ``(training_history, test_report)`` on success, or ``None`` when
        no model class was configured.
    """
    # Check Pytorch Version Before Running
    logger.info('Torch Version: %s', torch.__version__)  # type: ignore
    logger.info('Cuda Version: %s', torch.version.cuda)  # type: ignore

    if config.random_seed is not None:
        setup_random_seed(config.random_seed)

    # Initialize Writer
    writer_dir = f"{config.tensorboard_log_root}/{config.cur_time}/"
    writer = SummaryWriter(log_dir=writer_dir)

    # Initialize Device: restrict visible GPUs via the environment so that
    # device_count() below reflects the configured devices.
    if isinstance(config.gpu_device_id, list):
        os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
            map(str, config.gpu_device_id))
    elif config.gpu_device_id is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_device_id)

    # Effective batch size scales with the number of visible GPUs.
    replica = torch.cuda.device_count()
    logger.info('Device Counts: %s', replica)
    wandb.log({"Model batch size": config.batch_size * replica}, 0)
    # device = torch.device(f"cuda:{config.gpu_device_id}")

    # Initialize Dataset and Split into train/valid/test DataSets
    dataset_dict = {
        "output_type": config.output_type,
        "frames_per_clip": config.frames_per_clip,
        "step_between_clips": config.step_between_clips,
        "frame_rate": config.frame_rate,
        "num_sample_per_clip": config.num_sample_per_clip,
    }

    # Three dataset sources, in priority order: a wandb artifact, a known
    # annotation-list root (breakfast/mpii), or a plain dataset folder.
    if config.dataset_artifact:
        dataset = MouseClipDataset.from_wandb_artifact(
            config.dataset_artifact,
            split_by=config.split_by,
            mix_clip=config.mix_clip,
            no_valid=config.no_valid,
            extract_groom=config.extract_groom,
            exclude_5min=config.exclude_5min,
            exclude_2_mouse=config.exclude_2_mouse,
            exclude_fpvid=config.exclude_fpvid,
            exclude_2_mouse_valid=config.exclude_2_mouse_valid,
            **dataset_dict)
    elif config.dataset_root in ['./data/breakfast', './data/mpii']:
        dataset = MouseClipDataset.from_annotation_list(
            dataset_root=config.dataset_root, **dataset_dict)
    else:
        metadata_path = (config.metadata_path
                         if config.metadata_path is not None else os.path.join(
                             config.dataset_root, "metadata.pth"))
        dataset = MouseClipDataset.from_ds_folder(
            dataset_root=config.dataset_root,
            metadata_path=metadata_path,
            extract_groom=config.extract_groom,
            **dataset_dict)

    train_set = dataset.get_split("train", config.split_by,
                                  config.transform_size, {})
    valid_set = dataset.get_split("valid", config.split_by,
                                  config.transform_size, config.valid_set_args)
    test_set = dataset.get_split("test", config.split_by,
                                 config.transform_size, config.test_set_args)

    logger.info('Train Transform:\n%s', train_set.transform)
    logger.info('Valid Transform:\n%s', valid_set.transform)
    logger.info('Test Transform:\n%s', test_set.transform)

    # Only the train loader drops the last (incomplete) batch.
    dataloaders = {
        "train":
        DataLoader(train_set,
                   config.batch_size * replica,
                   sampler=train_set.get_sampler("train",
                                                 config.train_sampler_config,
                                                 config.samples_per_epoch),
                   num_workers=config.num_worker,
                   pin_memory=True,
                   drop_last=True),
        "valid":
        DataLoader(valid_set,
                   config.batch_size * replica,
                   sampler=valid_set.get_sampler("valid",
                                                 config.valid_sampler_config),
                   num_workers=config.num_worker,
                   pin_memory=True,
                   drop_last=False),
        "test":
        DataLoader(test_set,
                   config.batch_size * replica,
                   sampler=test_set.get_sampler("test",
                                                config.test_sampler_config),
                   num_workers=config.num_worker,
                   pin_memory=True,
                   drop_last=False),
    }

    # initialize model
    if config.model is not None:
        model = config.model(**config.model_args)
        # ResNets get their final layer resized to the label count.
        if isinstance(model, ResNet):
            model.fc = torch.nn.Linear(model.fc.in_features,
                                       len(set(dataset.labels)))

        if config.xavier_init:
            model = init_xavier_weights(model)
        # Make wandb Track the model
        wandb.watch(model, "parameters")

        logger.info('Model: %s', model.__class__.__name__)
        # Log total parameters in the model
        pytorch_total_params = sum(p.numel() for p in model.parameters())
        logger.info('Model params: %s', pytorch_total_params)
        pytorch_total_params_trainable = sum(p.numel()
                                             for p in model.parameters()
                                             if p.requires_grad)
        logger.info('Model params trainable: %s',
                    pytorch_total_params_trainable)

        model_structure_str = "Model Structue:\n"
        for name, param in model.named_parameters():
            model_structure_str += f"\t{name}: {param.requires_grad}, {param.numel()}\n"
        # logger.info(model_structure_str)

        model = model.cuda()
        if replica > 1:
            model = DataParallel(model)
    else:
        logger.critical("Model not chosen in config!")
        return None

    if isinstance(config.loss_function, torch.nn.Module):
        config.loss_function = config.loss_function.cuda()

    optimizer = config.optimizer(params=model.parameters(),
                                 **config.optimizer_args)
    logger.info("Optimizer: %s\n%s", config.optimizer.__name__,
                config.optimizer_args)

    if config.lr_scheduler is not None:
        lr_scheduler = config.lr_scheduler(optimizer,
                                           **config.lr_scheduler_args)
        logger.info("LR Scheduler: %s\n%s", config.lr_scheduler.__name__,
                    config.lr_scheduler_args)
    else:
        lr_scheduler = None
        logger.info("No LR Scheduler")

    logger.info("Training Started!")
    # Checkpointer keeps the best model according to config.best_metric.
    ckpter = Checkpointer(config.best_metric, save_path=config.save_path)

    training_history, total_steps = train_model(
        model=model,
        optimizer=optimizer,
        dataloaders=dataloaders,
        writer=writer,
        num_epochs=config.num_epochs,
        loss_function=config.loss_function,
        lr_scheduler=lr_scheduler,
        valid_every_epoch=config.valid_every_epoch,
        ckpter=ckpter,
    )
    logger.info("Training Complete!")

    # Restore the best validation weights before the final test.
    if ckpter is not None:
        ckpter.load_best_model(model)

    logger.info("Testing Started!")
    test_report = evaluate_model(model, dataloaders['test'], "Testing",
                                 total_steps, writer, config.loss_function)
    logger.info("Testing Complete!")

    if save_model:
        # Upload the unwrapped module plus the split assignment as one
        # wandb artifact so the run is reproducible.
        train_artifact = wandb.Artifact(f'run_{wandb.run.id}_model', 'model')
        model_tmp_path = os.path.join(wandb.run.dir,
                                      f'best_valid_{ckpter.name}_model.pth')
        torch.save(model.module if isinstance(model, DataParallel) else model,
                   model_tmp_path)
        train_artifact.add_file(model_tmp_path)
        with train_artifact.new_file('split_data.json') as f:
            json.dump(dataset.split_data, f)
        wandb.run.log_artifact(train_artifact)

    return training_history, test_report
Exemple #29
0
def _check_state_dict(state_dict, root, expected_keys, inner_wrapped=False):
    """Assert `state_dict` has exactly `expected_keys` and that every entry
    matches the tensor reached by walking the dotted key from `root`.

    When `inner_wrapped` is True the direct children of `root` are wrapped
    (e.g. in DataParallel), so resolution descends through the child's
    `.module` attribute before following the rest of the key.
    """
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == expected_keys
    for key in expected_keys:
        first, *rest = key.split('.')
        target = getattr(root, first)
        if inner_wrapped:
            # Key names are wrapper-free, so skip the DataParallel layer.
            target = target.module
        for attr in rest:
            target = getattr(target, attr)
        assert_tensor_equal(state_dict[key], target)


def test_get_state_dict():
    """get_state_dict must produce wrapper-free keys whether the model is
    bare, DDP-wrapped, or has its inner submodules wrapped in DataParallel.
    """
    state_dict_keys = {
        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
        'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
        'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias',
    }

    # Bare model: keys map 1:1 onto attribute paths.
    model = Model()
    _check_state_dict(get_state_dict(model), model, state_dict_keys)

    # DDP-wrapped model: the outer 'module.' prefix must be stripped.
    wrapped_model = DDPWrapper(model)
    _check_state_dict(get_state_dict(wrapped_model), wrapped_model.module,
                      state_dict_keys)

    # Wrap each inner child in DataParallel; keys must still be clean.
    for name, module in wrapped_model.module._modules.items():
        wrapped_model.module._modules[name] = DataParallel(module)
    _check_state_dict(get_state_dict(wrapped_model), wrapped_model.module,
                      state_dict_keys, inner_wrapped=True)
 def parallelize(self):
     """Return this module wrapped in ``DataParallel``.

     NOTE(review): the loss criterion is intentionally left unwrapped here
     (a DataParallelCriterion wrap was considered and dropped).
     """
     return DataParallel(self)