def get_base_optimizer(model):
    """Create the default Adam optimizer for ``model`` plus its LR scheduler.

    Every parameter of ``model`` is optimised with Adam using lr=3e-4,
    betas=(0.9, 0.999) and weight decay 5e-4. The scheduler is built by
    ``init_lr_scheduler`` with step boundaries [20, 40] and gamma 0.1
    (presumably a multi-step decay — depends on that helper's default
    scheduler type; TODO confirm).

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    adam = torch.optim.Adam(
        model.parameters(),
        lr=0.0003,
        betas=(0.9, 0.999),
        weight_decay=5e-4,
    )
    lr_schedule = init_lr_scheduler(adam, stepsize=[20, 40], gamma=0.1)
    return adam, lr_schedule
def get_base_sgd_optimizer(model):
    """Create the default SGD optimizer for ``model`` plus its LR scheduler.

    Every parameter of ``model`` is optimised with SGD using lr=1e-3,
    momentum 0.9 and weight decay 5e-4. The scheduler is built by
    ``init_lr_scheduler`` with step boundaries [25, 50] and gamma 0.1.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    sgd = torch.optim.SGD(
        model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4
    )
    return sgd, init_lr_scheduler(sgd, stepsize=[25, 50], gamma=0.1)
def get_RRI_optimizer(model, lr):
    """Create an SGD optimizer with a caller-chosen learning rate.

    Every parameter of ``model`` is optimised with SGD using the given
    ``lr``, momentum 0.9 and weight decay 5e-4. The scheduler is built by
    ``init_lr_scheduler`` with a single step boundary at 12 and gamma 0.1.

    Args:
        model: module whose ``parameters()`` are optimised.
        lr: initial learning rate.

    Returns:
        tuple: ``(optimizer, scheduler)``.
    """
    hyperparams = dict(lr=lr, momentum=0.9, weight_decay=5e-4)
    sgd = torch.optim.SGD(model.parameters(), **hyperparams)
    decay = init_lr_scheduler(sgd, stepsize=[12], gamma=0.1)
    return sgd, decay
def main():
    """Train and/or evaluate a video re-identification model driven by ``args``.

    All settings come from the module-level ``args`` namespace.

    * With ``args.evaluate`` set, only evaluates on every target dataset
      (optionally writing ranked-result visualisations) and returns.
    * Otherwise, optionally trains only ``args.open_layers`` for
      ``args.fixbase_epoch`` epochs (restoring the optimizer state afterwards),
      then trains from ``args.start_epoch`` to ``args.max_epoch``. On every
      evaluation epoch (and always on the last epoch) it tests each target
      dataset and writes ``checkpoint_ep<N>.pth.tar`` into ``args.save_dir``.

    Checkpoints store the weights of the *unwrapped* model (no ``module.``
    key prefix), so ``--resume`` can load them back before the model is
    wrapped in ``nn.DataParallel``.
    """
    global args

    set_random_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    # Route all further prints through the project Logger (presumably also
    # mirrors them into save_dir/log_name — confirm against Logger).
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print('==========\nArgs:{}\n=========='.format(args))

    if use_gpu:
        print('Currently using GPU {}'.format(args.gpu_devices))
        # Let cuDNN benchmark and cache the fastest convolution algorithms.
        cudnn.benchmark = True
    else:
        print('Currently using CPU, however, GPU is highly recommended')

    print('Initializing video data manager')
    dm = VideoDataManager(use_gpu, **video_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    print('Initializing model: {}'.format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dm.num_train_pids,
                              loss={'xent', 'htri'},
                              pretrained=not args.no_pretrained,
                              use_gpu=use_gpu)
    print('Model size: {:.3f} M'.format(count_num_param(model)))

    if args.load_weights and check_isfile(args.load_weights):
        load_pretrained_weights(model, args.load_weights)

    if args.resume and check_isfile(args.resume):
        # The model is still on the CPU here (it is moved to the GPU below),
        # and mapping to CPU also lets GPU-written checkpoints resume on a
        # CPU-only machine.
        checkpoint = torch.load(args.resume, map_location='cpu')
        # Checkpoints written by earlier runs of this script came from the
        # nn.DataParallel wrapper and carry a 'module.' key prefix that the
        # bare model does not expect.
        model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))
        args.start_epoch = checkpoint['epoch'] + 1
        best_rank1 = checkpoint['rank1']
        print('Loaded checkpoint from "{}"'.format(args.resume))
        print('- start_epoch: {}\n- rank1: {}'.format(args.start_epoch,
                                                      best_rank1))

    model = nn.DataParallel(model).cuda() if use_gpu else model

    # Named criterion_xent because train() below is called with that name;
    # binding it as `criterion` raised NameError at the first training call.
    criterion_xent = CrossEntropyLoss(num_classes=dm.num_train_pids,
                                      use_gpu=use_gpu,
                                      label_smooth=args.label_smooth)
    criterion_htri = TripletLoss(margin=args.margin)
    optimizer = init_optimizer(model, **optimizer_kwargs(args))
    scheduler = init_lr_scheduler(optimizer, **lr_scheduler_kwargs(args))

    if args.evaluate:
        print('Evaluate only')

        for name in args.target_names:
            print('Evaluating {} ...'.format(name))
            queryloader = testloader_dict[name]['query']
            galleryloader = testloader_dict[name]['gallery']
            distmat = test(model,
                           queryloader,
                           galleryloader,
                           args.pool_tracklet_features,
                           use_gpu,
                           return_distmat=True)

            if args.visualize_ranks:
                visualize_ranked_results(distmat,
                                         dm.return_testdataset_by_name(name),
                                         save_dir=osp.join(
                                             args.save_dir, 'ranked_results',
                                             name),
                                         topk=20)
        return

    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0
    print('=> Start training')

    if args.fixbase_epoch > 0:
        print(
            'Train {} for {} epochs while keeping other layers frozen'.format(
                args.open_layers, args.fixbase_epoch))
        # Snapshot the optimizer so the frozen-base warm-up does not leak its
        # state into the main training phase.
        initial_optim_state = optimizer.state_dict()

        for epoch in range(args.fixbase_epoch):
            start_train_time = time.time()
            train(epoch,
                  model,
                  criterion_xent,
                  criterion_htri,
                  optimizer,
                  trainloader,
                  use_gpu,
                  fixbase=True)
            train_time += round(time.time() - start_train_time)

        print('Done. All layers are open to train for {} epochs'.format(
            args.max_epoch))
        optimizer.load_state_dict(initial_optim_state)

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, optimizer,
              trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate every eval_freq epochs once past start_eval, and always
        # after the final epoch (same logic as before, precedence now explicit).
        is_eval_epoch = ((epoch + 1) > args.start_eval and args.eval_freq > 0
                         and (epoch + 1) % args.eval_freq == 0)
        is_last_epoch = (epoch + 1) == args.max_epoch
        if is_eval_epoch or is_last_epoch:
            print('=> Test')

            for name in args.target_names:
                print('Evaluating {} ...'.format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader,
                             args.pool_tracklet_features, use_gpu)
                ranklogger.write(name, epoch + 1, rank1)

            # NOTE: 'rank1' is the score of the last target dataset evaluated.
            save_checkpoint(
                {
                    'state_dict': _unwrap_state_dict(model),
                    'rank1': rank1,
                    'epoch': epoch,
                }, False,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        'Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.'.
        format(elapsed, train_time))
    ranklogger.show_summary()


def _unwrap_state_dict(model):
    """Return the state dict of ``model``, unwrapping ``nn.DataParallel``.

    Saving the inner module's weights keeps checkpoint keys free of the
    ``module.`` prefix, so they load directly into an unwrapped model.
    """
    if isinstance(model, nn.DataParallel):
        return model.module.state_dict()
    return model.state_dict()


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``'module.'`` removed from its keys.

    ``nn.DataParallel`` registers the wrapped model as its ``module`` child,
    so its state-dict keys start with ``'module.'``. A state dict with no such
    key is returned unchanged (the same object).
    """
    prefix = 'module.'
    if not any(key.startswith(prefix) for key in state_dict):
        return state_dict
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}