Example #1
    def test(self):
        data_loader = self.dev_loader
        data_loader.batch_size = 1
        prediction_dict_list = []
        cnt = 1
        last_paragraph_id = -1
        last_turn_id = -1
        answer_filename = 'data/answers.json'
        timer1 = Timer()
        while data_loader.batch_state < len(data_loader):
            # periodically report an estimate of the remaining time
            if cnt % 2000 == 0:
                print(timer1.remains(len(data_loader), cnt))
            input_batch = data_loader.get()
            prediction = self.gen_prediction(input_batch)
            turn_id = gen_turn_id(input_batch)
            paragraph_id = gen_paragraph_id(input_batch)
            prediction_dict = {
                "id": paragraph_id[0],
                "turn_id": turn_id[0],
                "answer": prediction[0]
            }

            # keep only the first prediction for each (paragraph, turn) pair
            is_exist, last_paragraph_id, last_turn_id = check_exist_status(
                paragraph_id, turn_id, last_paragraph_id, last_turn_id)
            if not is_exist:
                prediction_dict_list.append(prediction_dict)
                cnt += 1

        with open(answer_filename, 'w') as outfile:
            json.dump(prediction_dict_list, outfile)
        test_evaluator.test('data/coqa-dev-v1.0.json', answer_filename)
        print("generated {} answers".format(cnt - 1))
Example #2
def main():
    run_started = datetime.today().strftime(
        '%d-%m-%y_%H%M')  # start time, used to build a unique experiment name
    parser = argparse.ArgumentParser(description='UPS Training')
    parser.add_argument('--out',
                        default='outputs',
                        help='directory to output the result')
    parser.add_argument('--gpu-id',
                        default='0',
                        type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers',
                        type=int,
                        default=8,
                        help='number of workers')
    parser.add_argument('--dataset',
                        default='cifar10',
                        type=str,
                        choices=['cifar10', 'cifar100'],
                        help='dataset names')
    parser.add_argument('--n-lbl',
                        type=int,
                        default=4000,
                        help='number of labeled data')
    parser.add_argument('--arch',
                        default='cnn13',
                        type=str,
                        choices=['wideresnet', 'cnn13', 'shakeshake'],
                        help='architecture name')
    parser.add_argument(
        '--iterations',
        default=20,
        type=int,
        help='number of total pseudo-labeling iterations to run')
    parser.add_argument('--epchs',
                        default=1024,
                        type=int,
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batchsize',
                        default=128,
                        type=int,
                        help='train batchsize')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=0.03,
                        type=float,
                        help='initial learning rate, default 0.03')
    parser.add_argument('--warmup',
                        default=0,
                        type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('--nesterov',
                        action='store_true',
                        default=True,  # note: with default=True this flag is effectively always on
                        help='use nesterov momentum')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed',
                        type=int,
                        default=-1,
                        help="random seed (-1: don't use random seed)")
    parser.add_argument('--no-progress',
                        action='store_true',
                        help="don't use progress bar")
    parser.add_argument('--dropout',
                        default=0.3,
                        type=float,
                        help='dropout probs')
    parser.add_argument('--num-classes',
                        default=10,
                        type=int,
                        help='total classes')
    parser.add_argument('--class-blnc',
                        default=10,
                        type=int,
                        help='total number of class balanced iterations')
    parser.add_argument(
        '--tau-p',
        default=0.70,
        type=float,
        help='confidence threshold for positive pseudo-labels, default 0.70')
    parser.add_argument(
        '--tau-n',
        default=0.05,
        type=float,
        help='confidence threshold for negative pseudo-labels, default 0.05')
    parser.add_argument(
        '--kappa-p',
        default=0.05,
        type=float,
        help='uncertainty threshold for positive pseudo-labels, default 0.05')
    parser.add_argument(
        '--kappa-n',
        default=0.005,
        type=float,
        help='uncertainty threshold for negative pseudo-labels, default 0.005')
    parser.add_argument(
        '--temp-nl',
        default=2.0,
        type=float,
        help='temperature for generating negative pseudo-labels, default 2.0')
    parser.add_argument(
        '--no-uncertainty',
        action='store_true',
        help='disable uncertainty in the pseudo-label selection (enabled by default)')
    parser.add_argument(
        '--split-txt',
        default='run1',
        type=str,
        help='extra text to differentiate experiments; also creates a new labeled/unlabeled split')
    parser.add_argument('--model-width',
                        default=2,
                        type=int,
                        help='model width for WRN-28')
    parser.add_argument('--model-depth',
                        default=28,
                        type=int,
                        help='model depth for WRN')
    parser.add_argument('--test-freq',
                        default=10,
                        type=int,
                        help='frequency of evaluations')

    args = parser.parse_args()
    # print key configurations
    print(
        '########################################################################'
    )
    print(
        '########################################################################'
    )
    print(f'dataset:                                  {args.dataset}')
    print(f'number of labeled samples:                {args.n_lbl}')
    print(f'architecture:                             {args.arch}')
    print(f'number of pseudo-labeling iterations:     {args.iterations}')
    print(f'number of epochs:                         {args.epchs}')
    print(f'batch size:                               {args.batchsize}')
    print(f'lr:                                       {args.lr}')
    print(f'value of tau_p:                           {args.tau_p}')
    print(f'value of tau_n:                           {args.tau_n}')
    print(f'value of kappa_p:                         {args.kappa_p}')
    print(f'value of kappa_n:                         {args.kappa_n}')
    print(
        '########################################################################'
    )
    print(
        '########################################################################'
    )

    DATASET_GETTERS = {'cifar10': get_cifar10, 'cifar100': get_cifar100}
    exp_name = f'exp_{args.dataset}_{args.n_lbl}_{args.arch}_{args.split_txt}_{args.epchs}_{args.class_blnc}_{args.tau_p}_{args.tau_n}_{args.kappa_p}_{args.kappa_n}_{run_started}'
    device = torch.device('cuda', args.gpu_id)
    args.device = device
    args.exp_name = exp_name
    args.dtype = torch.float32
    if args.seed != -1:
        set_seed(args)
    args.out = os.path.join(args.out, args.exp_name)
    start_itr = 0

    if args.resume and os.path.isdir(args.resume):
        resume_files = os.listdir(args.resume)
        resume_itrs = [
            int(item.replace('.pkl', '').split("_")[-1])
            for item in resume_files if 'pseudo_labeling_iteration' in item
        ]
        if len(resume_itrs) > 0:
            start_itr = max(resume_itrs)
        args.out = args.resume
    os.makedirs(args.out, exist_ok=True)
    writer = SummaryWriter(args.out)

    if args.dataset == 'cifar10':
        args.num_classes = 10
    elif args.dataset == 'cifar100':
        args.num_classes = 100

    for itr in range(start_itr, args.iterations):
        if itr == 0 and args.n_lbl < 4000:  # use a smaller batch size to increase the number of update steps
            args.batch_size = 64
            args.epochs = 1024
        else:
            args.batch_size = args.batchsize
            args.epochs = args.epchs

        if os.path.exists(
                f'data/splits/{args.dataset}_basesplit_{args.n_lbl}_{args.split_txt}.pkl'
        ):
            lbl_unlbl_split = f'data/splits/{args.dataset}_basesplit_{args.n_lbl}_{args.split_txt}.pkl'
        else:
            lbl_unlbl_split = None

        # load the pseudo-labels saved by the previous iteration
        if itr > 0:
            pseudo_lbl_dict = f'{args.out}/pseudo_labeling_iteration_{str(itr)}.pkl'
        else:
            pseudo_lbl_dict = None

        lbl_dataset, nl_dataset, unlbl_dataset, test_dataset = DATASET_GETTERS[
            args.dataset]('data/datasets', args.n_lbl, lbl_unlbl_split,
                          pseudo_lbl_dict, itr, args.split_txt)

        model = create_model(args)
        model.to(args.device)

        # split the batch between labeled and negative-pseudo-label data
        # in proportion to the two dataset sizes
        nl_batchsize = int((float(args.batch_size) * len(nl_dataset)) /
                           (len(lbl_dataset) + len(nl_dataset)))

        if itr == 0:
            lbl_batchsize = args.batch_size
            args.iteration = len(lbl_dataset) // args.batch_size
        else:
            lbl_batchsize = args.batch_size - nl_batchsize
            args.iteration = (len(lbl_dataset) +
                              len(nl_dataset)) // args.batch_size

        lbl_loader = DataLoader(lbl_dataset,
                                sampler=RandomSampler(lbl_dataset),
                                batch_size=lbl_batchsize,
                                num_workers=args.num_workers,
                                drop_last=True)

        nl_loader = DataLoader(nl_dataset,
                               sampler=RandomSampler(nl_dataset),
                               batch_size=nl_batchsize,
                               num_workers=args.num_workers,
                               drop_last=True)

        test_loader = DataLoader(test_dataset,
                                 sampler=SequentialSampler(test_dataset),
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers)

        unlbl_loader = DataLoader(unlbl_dataset,
                                  sampler=SequentialSampler(unlbl_dataset),
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              nesterov=args.nesterov)
        args.total_steps = args.epochs * args.iteration
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, args.warmup * args.iteration, args.total_steps)
        start_epoch = 0
        best_acc = 0

        if args.resume and itr == start_itr and os.path.isdir(args.resume):
            resume_itrs = [
                int(item.replace('.pth.tar', '').split("_")[-1])
                for item in resume_files if 'checkpoint_iteration_' in item
            ]
            if len(resume_itrs) > 0:
                checkpoint_itr = max(resume_itrs)
                resume_model = os.path.join(
                    args.resume,
                    f'checkpoint_iteration_{checkpoint_itr}.pth.tar')
                if os.path.isfile(resume_model) and checkpoint_itr == itr:
                    checkpoint = torch.load(resume_model)
                    best_acc = checkpoint['best_acc']
                    start_epoch = checkpoint['epoch']
                    model.load_state_dict(checkpoint['state_dict'])
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    scheduler.load_state_dict(checkpoint['scheduler'])

        model.zero_grad()
        for epoch in range(start_epoch, args.epochs):
            if itr == 0:
                train_loss = train_initial(args, lbl_loader, model, optimizer,
                                           scheduler, epoch, itr)
            else:
                train_loss = train_regular(args, lbl_loader, nl_loader, model,
                                           optimizer, scheduler, epoch, itr)

            test_loss = 0.0
            test_acc = 0.0
            test_model = model
            if epoch > (args.epochs + 1) / 2 and epoch % args.test_freq == 0:
                test_loss, test_acc = test(args, test_loader, test_model)
            elif epoch == (args.epochs - 1):
                test_loss, test_acc = test(args, test_loader, test_model)

            writer.add_scalar('train/1.train_loss', train_loss,
                              (itr * args.epochs) + epoch)
            writer.add_scalar('test/1.test_acc', test_acc,
                              (itr * args.epochs) + epoch)
            writer.add_scalar('test/2.test_loss', test_loss,
                              (itr * args.epochs) + epoch)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)
            model_to_save = model.module if hasattr(model, "module") else model
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_to_save.state_dict(),
                    'acc': test_acc,
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                }, is_best, args.out, f'iteration_{str(itr)}')

        checkpoint = torch.load(
            f'{args.out}/checkpoint_iteration_{str(itr)}.pth.tar')
        model.load_state_dict(checkpoint['state_dict'])
        model.zero_grad()

        # pseudo-label generation and selection
        pl_loss, pl_acc, pl_acc_pos, total_sel_pos, pl_acc_neg, total_sel_neg, unique_sel_neg, pseudo_label_dict = pseudo_labeling(
            args, unlbl_loader, model, itr)

        writer.add_scalar('pseudo_labeling/1.regular_loss', pl_loss, itr)
        writer.add_scalar('pseudo_labeling/2.regular_acc', pl_acc, itr)
        writer.add_scalar('pseudo_labeling/3.pseudo_acc_positive', pl_acc_pos,
                          itr)
        writer.add_scalar('pseudo_labeling/4.total_sel_positive',
                          total_sel_pos, itr)
        writer.add_scalar('pseudo_labeling/5.pseudo_acc_negative', pl_acc_neg,
                          itr)
        writer.add_scalar('pseudo_labeling/6.total_sel_negative',
                          total_sel_neg, itr)
        writer.add_scalar('pseudo_labeling/7.unique_samples_negative',
                          unique_sel_neg, itr)

        with open(
                os.path.join(args.out,
                             f'pseudo_labeling_iteration_{str(itr+1)}.pkl'),
                "wb") as f:
            pickle.dump(pseudo_label_dict, f)

        with open(os.path.join(args.out, 'log.txt'), 'a+') as ofile:
            ofile.write(
                f'############################# PL Iteration: {itr+1} #############################\n'
            )
            ofile.write(
                f'Last Test Acc: {test_acc}, Best Test Acc: {best_acc}\n')
            ofile.write(
                f'PL Acc (Positive): {pl_acc_pos}, Total Selected (Positive): {total_sel_pos}\n'
            )
            ofile.write(
                f'PL Acc (Negative): {pl_acc_neg}, Total Selected (Negative): {total_sel_neg}, Unique Negative Samples: {unique_sel_neg}\n\n'
            )

    writer.close()
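
The loop above depends on get_cosine_schedule_with_warmup, which is defined elsewhere in the repository. A minimal sketch of the usual FixMatch-style implementation, offered as an assumption about what that helper does rather than its exact code:

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7. / 16., last_epoch=-1):
    # Linear warmup for num_warmup_steps, then cosine decay over the
    # remaining training steps.
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * progress))
    return LambdaLR(optimizer, _lr_lambda, last_epoch)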
Example #3
def train_net(cfg):
    data_source = cfg['data_source']
    data_params = cfg['data_params']
    teacher_params = cfg['teacher_params']
    student_params = cfg['student_params']
    loss_params = cfg['loss_params']
    test_params = cfg['test_params']
    opt_params = cfg['opt_params']

    # Set data iterators
    train_iter, _ = get_rec_data_iterators(data_source['params']['train_db'], '',
                                           cfg['input_shape'], cfg['batch_size'],
                                           data_params, cfg['devices_id'])

    devices = [mx.gpu(device_id) for device_id in cfg['devices_id']]
    batch_size = cfg['batch_size'] * len(devices)
    num_batches = data_source['train_samples_num'] // batch_size

    # Set teacher extractor
    if teacher_params['type'] == 'insightface-resnet':
        teacher_net = get_model(teacher_params, True)
    else:
        sys.exit('Unsupported teacher net architecture: %s !' % teacher_params['type'])
    # Init teacher extractor
    teacher_net.load_parameters('%s-%04d.params' % teacher_params['init'], ctx=devices,
                                allow_missing=False, ignore_extra=False)
    logging.info("Teacher extractor parameters were successfully loaded")
    teacher_net.hybridize(static_alloc=True, static_shape=True)

    # Set student extractor
    if student_params['type'] == 'insightface-resnet':
        student_net = get_model(student_params, False, 'student_')
    else:
        sys.exit('Unsupported student net architecture: %s !' % student_params['type'])
    # Init student extractor
    if student_params['init']:
        student_net.load_parameters('%s-%04d.params' % student_params['init'], ctx=devices,
                                    allow_missing=False, ignore_extra=False)
        logging.info("Student extractor parameters were successfully loaded")
    else:
        init_params = [
            ('.*gamma|.*alpha|.*running_mean|.*running_var', mx.init.Constant(1)),
            ('.*beta|.*bias', mx.init.Constant(0.0)),
            ('.*weight', mx.init.Xavier())
        ]
        for mask, initializer in init_params:
            student_net.collect_params(mask).initialize(initializer, ctx=devices)
    student_net.hybridize(static_alloc=True, static_shape=True)
    
    params = student_net.collect_params()

    # Set teacher classifier
    teacher_clf = None
    if loss_params['HKD']['weight'] > 0.0:
        teacher_clf = get_angular_classifier(data_source['num_classes'],
                                             teacher_params['embedding_dim'],
                                             loss_params['classification'])
        # init teacher classifier
        filename = '%s-%04d.params' % loss_params['HKD']['teacher_init']
        teacher_clf.load_parameters(filename, ctx=devices,
                                    allow_missing=False,
                                    ignore_extra=False)
        logging.info("Teacher classifier parameters "
                     "were successfully loaded")
        teacher_clf.hybridize(static_alloc=True, static_shape=True)
        
    # Set student classifier
    student_clf = None
    if loss_params['HKD']['weight'] > 0.0 or loss_params['classification']['weight'] > 0.0:
        student_clf = get_angular_classifier(data_source['num_classes'],
                                             student_params['embedding_dim'],
                                             loss_params['classification'],
                                             'student_')
        # init student classifier
        if loss_params['classification']['student_init']:
            filename = '%s-%04d.params' % loss_params['classification']['student_init']
            student_clf.load_parameters(filename, ctx=devices,
                                        allow_missing=False,
                                        ignore_extra=False)
            logging.info("Student classifier parameters "
                         "were successfully loaded")
        else:
            student_clf.initialize(mx.init.Normal(0.01), ctx=devices)
        student_clf.hybridize(static_alloc=True, static_shape=True)
        params.update(student_clf.collect_params())

    # Set losses
    L_clf, L_hkd, L_mld = get_losses(loss_params)

    # Set train evaluation metrics
    eval_metrics_train = init_eval_metrics(loss_params)

    # Set optimizer
    optimizer = 'sgd'
    optimizer_params = {'wd': opt_params['wd'], 'momentum': opt_params['momentum']}

    # Set trainer
    trainer = gluon.Trainer(params, optimizer, optimizer_params, kvstore='local')

    # Initialize test results
    test_best_result = {db_name : [0.0, 0] for db_name in test_params['dbs']}

    # TRAINING LOOP
    iteration = 0
    for epoch in range(opt_params['num_epoch']):
        tic_epoch = time.time()

        # reset metrics
        for metric in eval_metrics_train.values():
            metric['metric'].reset()
            metric['losses'] = []

        # update learning rate: step decay
        if epoch == 0:
            trainer.set_learning_rate(opt_params['lr_base'])
        elif epoch > 0 and not epoch % opt_params['lr_epoch_step']:
            trainer.set_learning_rate(trainer.learning_rate * opt_params['lr_factor'])
            logging.info("Learning rate has been changed to %f" % trainer.learning_rate)

        tic_batch = time.time()
        for i, batch in enumerate(train_iter):
            iteration += 1
            # process batch
            data, label = unpack_batch(batch)
            loss = []
            for X, y_gt in zip(data, label):
                # get teacher predictions
                with autograd.predict_mode():
                    embeddings_teacher = teacher_net(X)
                    if teacher_clf:
                        logits_teacher = teacher_clf(embeddings_teacher, y_gt)
                # get student predictions and compute loss
                with autograd.record():
                    embeddings_student = student_net(X)
                    if student_clf:
                        logits_student = student_clf(embeddings_student, y_gt)
                    device_losses = []
                    # classification loss
                    if L_clf:
                        loss_clf = loss_params['classification']['weight'] * \
                                   L_clf(logits_student, y_gt)
                        device_losses.append(loss_clf)
                        eval_metrics_train['classification']['losses'].append(loss_clf)
                    # Hinton's knowledge distillation loss
                    if L_hkd:
                        loss_hkd = loss_params['HKD']['weight'] * \
                                   L_hkd(logits_student, logits_teacher)
                        device_losses.append(loss_hkd)
                        eval_metrics_train['HKD']['losses'].append(loss_hkd)
                    # metric learning distillation losses
                    for name, L, weight in L_mld:
                        loss_mld = weight * L(embeddings_student, embeddings_teacher)
                        device_losses.append(loss_mld)
                        eval_metrics_train[name]['losses'].append(loss_mld)
                    # aggregate all losses
                    device_losses = [loss_term.mean() for loss_term in device_losses]
                    loss.append(mx.nd.add_n(*device_losses))
            eval_metrics_train['total']['losses'] = loss

            # Backpropagate errors
            for l in loss:
                l.backward()
            trainer.step(batch_size)

            # update loss metrics (loss-style metrics ignore the label argument)
            for metric in eval_metrics_train.values():
                metric['metric'].update(None, metric['losses'])
                metric['losses'] = []

            # display training statistics
            if not (i+1) % cfg['display_period']:
                disp_template = 'Epoch[%d/%d] Batch[%d/%d]\tSpeed: %f samples/sec\tlr=%f'
                disp_params = [epoch, opt_params['num_epoch'], i+1, num_batches,
                               batch_size * cfg['display_period'] / (time.time() - tic_batch),
                               trainer.learning_rate]
                for metric in eval_metrics_train.values():
                    metric_name, metric_score = metric['metric'].get()
                    disp_template += '\t%s=%f'
                    disp_params.append(metric_name)
                    disp_params.append(metric_score)
                logging.info(disp_template % tuple(disp_params))
                tic_batch = time.time()

            if not iteration % cfg['test_period']:
                period_idx = iteration // cfg['test_period']
                # save model
                logging.info("[Epoch %d][Batch %d] "
                             "Saving network params [%d] at %s" % 
                            (epoch, i, period_idx, cfg['experiment_dir']))
                student_net.export('%s/student' % cfg['experiment_dir'], period_idx)
                if student_clf:
                    student_clf.export('%s/student-clf' % cfg['experiment_dir'], period_idx)
                # test model using outside data
                if test_params['dbs']:
                    logging.info('[Epoch %d] Testing student network ...' % epoch)
                    # emore bin-files testing
                    for db_name in test_params['dbs']:
                        db_path = '%s/%s.bin' % (test_params['dbs_root'], db_name)
                        data_set = load_bin(db_path, [cfg['input_shape'][1], cfg['input_shape'][2]])
                        _, _, acc, std, _, _ = test(data_set, student_net, cfg['batch_size'], 10)
                        if acc > test_best_result[db_name][0]:
                            test_best_result[db_name] = [acc, period_idx]
                        logging.info("Epoch[%d] Batch[%d] %s: "
                                     "Accuracy-Flip = %1.5f+-%1.5f "
                                     "(best: %f, at snapshot %04d)" %
                                    (epoch, i+1, db_name, acc, std,
                                     test_best_result[db_name][0],
                                     test_best_result[db_name][1]))

        # estimate epoch training speed
        throughput = int(batch_size * (i+1) / (time.time() - tic_epoch))
        logging.info("[Epoch %d] Speed: %d samples/sec\t"
                     "Time cost: %f seconds" %
                     (epoch, throughput, time.time() - tic_epoch))
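
The batch loop calls unpack_batch to split each DataBatch across the training GPUs. The helper is not shown; a plausible sketch follows (the body and the explicit devices argument are assumptions, since the snippet captures the device list elsewhere):

from mxnet import gluon

def unpack_batch(batch, devices):
    # Split the batch's data and labels evenly across the GPU contexts.
    data = gluon.utils.split_and_load(batch.data[0], ctx_list=devices, batch_axis=0)
    label = gluon.utils.split_and_load(batch.label[0], ctx_list=devices, batch_axis=0)
    return data, label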
Example #4
def main():
    args = parse_args()

    working_dir_name = get_working_dir_name(args.results_root, args)
    working_dir = os.path.join(args.results_root, working_dir_name)
    check_dir(working_dir)
    print(f'Working Dir : {working_dir}')
    make_results_folders(working_dir) # 'weights' / 'test_img'

    # update args 
    args.working_dir = working_dir
    args.weights_dir = os.path.join(args.working_dir, 'weights')
    args.test_img_dir = os.path.join(args.working_dir, 'test_img')

    # init writer for tensorboard 
    writer = SummaryWriter(working_dir)
    print(f'Tensorboard info will be saved in \'{working_dir}\'')

    # save args in run folder 
    with open(os.path.join(working_dir, 'args.txt'), 'w') as f: 
        json.dump(args.__dict__, f, indent=4)

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(args)

    # Dataset 
    if args.dataset == 'sixray':
        train_dataset = sixrayDataset('../SIXray', mode='train')
        eval_dataset = sixrayDataset('../SIXray', mode='eval')
    elif args.dataset == 'coco':
        train_dataset = cocoDataset('../coco', mode='train')
        eval_dataset = cocoDataset('../coco', mode='eval')
    else:
        raise RuntimeError('Invalid dataset type')


    # Dataloader 
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, 
                            shuffle=True, num_workers=0, 
                            collate_fn=collate_sixray)

    eval_loader = DataLoader(dataset=eval_dataset, batch_size=1, 
                            shuffle=False, num_workers=0, 
                            collate_fn=collate_sixray)

    # Models 
    if args.model == 'res34': 
        backbone = torchvision.models.resnet34(pretrained=True) # res 34
        backbone = get_resent_features(backbone)
        out_ch = 512 # resnet18,34 : 512 
    elif args.model == 'res50':
        backbone = torchvision.models.resnet50(pretrained=True) # res 50 
        backbone = get_resent_features(backbone)
        out_ch = 2048 # resnet50~152 : 2048
    elif args.model == 'res34AAA':
        backbone = AAA('res34', True, args)
        out_ch = 512 # resnet18,34 : 512 
    else:
        raise ValueError(f'Unknown model type: {args.model}')

    # Anchor size : ((size, size*2, size*4, size*8, size*16), )
    anchor_size = (tuple(int(args.anchor_init_size * math.pow(2, i)) for i in range(5)), )

    backbone.out_channels = out_ch

    anchor_generator = AnchorGenerator(sizes=anchor_size,
                                        aspect_ratios=((0.5, 1.0, 2.0),))
    
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], 
                                                    output_size=7, 
                                                    sampling_ratio=2)
    model = FasterRCNN(backbone=backbone, 
                    num_classes=7, # 6 class + 1 background 
                    rpn_anchor_generator=anchor_generator,
                    box_roi_pool=roi_pooler,
                    min_size=args.img_min_size, 
                    max_size=args.img_max_size).to(args.device)

    # if args.model == 'res50fpn':
    #     model = fasterrcnn_resnet50_fpn(pretrained=True).to(args.device)
    #     model.rpn.anchor_generator.sizes = ((8,), (16,), (32,), (64,), (128,))

    # Optimizer 
    optimizer = optim.SGD(model.parameters(), 
                          lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[90, 120], gamma=0.1)
    
    # train 
    global_step = 0
    accuracies = {}
    for epoch in range(1, args.epochs+1):
        progress = tqdm.tqdm(train_loader)
        for images, targets, _ in progress:
            model.train() 
            
            images = list(image.to(args.device) for image in images)
            targets = [{k: v.to(args.device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            loss_cls = loss_dict['loss_classifier'].item()
            loss_reg = loss_dict['loss_box_reg'].item()
            loss_obj = loss_dict['loss_objectness'].item()
            loss_rpn_reg = loss_dict['loss_rpn_box_reg'].item()

            progress.set_description(f'Train {epoch} / {args.epochs}, lr : {optimizer.param_groups[0]["lr"]} ' +
                                    f'Loss : [TT]{losses:.3f}, [HC]{loss_cls:.3f}, [HR]{loss_reg:.3f}, ' +
                                    f'[RO]{loss_obj:.3f}, [RR]{loss_rpn_reg:.3f} ')
            
        if epoch % args.save_epoch == 0 : 
            torch.save(model.state_dict(), 
                       os.path.join(args.weights_dir, f'{args.model}_{epoch}.ckpt'))
                       
        if epoch % args.eval_epoch == 0:
            accuracies = evaluate(model, eval_loader, args, epoch, accs=accuracies, update_acc=True)
            if args.test_img_name == '':
                # pick a random test image once and reuse it for later evaluations
                args.test_img_name = os.path.join(
                    args.test_img_folder,
                    random.sample(os.listdir(args.test_img_folder), 1)[0])
            test(model, args.test_img_name, args, epoch)
        
        ## Tensor Board 
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step)
        writer.add_scalar('loss_classifier', loss_cls, global_step)
        writer.add_scalar('loss_box_reg', loss_reg, global_step)
        writer.add_scalar('loss_objectness', loss_obj, global_step)
        writer.add_scalar('loss_rpn_box_reg', loss_rpn_reg, global_step)
        global_step += 1

        scheduler.step()

    ## Evaluate rankings 
    accuracies = sorted(accuracies.items(), key=lambda x: x[1], reverse=True)
    print('##### TOP 3 models by iou 0.5 value #####')
    for i in range(3):
        print(f'TOP {i+1} : epoch {accuracies[i][0]}, accuracy {accuracies[i][1]}')
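
Detection DataLoaders need a custom collate_fn because images and per-image target dicts vary in size. collate_sixray is not shown; the standard torchvision-style implementation is a one-liner, given here as an assumption about what it does:

def collate_sixray(batch):
    # Keep variable-sized images, target dicts, and extras as parallel
    # tuples instead of stacking them into tensors.
    return tuple(zip(*batch))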
Example #5

    test_cf_pairs = torch.LongTensor(
        np.array([[cf[0], cf[1]] for cf in test_cf], np.int32))
    """define model"""
    model = Recommender(n_params, args, graph, mean_mat_list[0]).to(device)
    """define optimizer"""
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    load_dir = './result/epoch-0.model'
    last_model = torch.load(load_dir)
    model.load_state_dict(last_model['net'])
    optimizer.load_state_dict(last_model['optimizer'])

    cur_best_pre_0 = 0
    stopping_step = 0
    should_stop = False

    ret = test(model, user_dict, n_params)

    train_res = PrettyTable()
    train_res.field_names = [
        "Epoch", "training time", "testing time", "Loss", "recall", "ndcg",
        "precision", "hit_ratio"
    ]
    # train_res.add_row(
    #     [9, train_e_t - train_s_t, test_e_t - test_s_t, loss.item(), ret['recall'], ret['ndcg'], ret['precision'], ret['hit_ratio']]
    # )
    print(train_res)

    # *********************************************************
    # early stopping when cur_best_pre_0 is decreasing for ten successive steps.
    cur_best_pre_0, stopping_step, should_stop = early_stopping(
        ret['recall'][0],
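
The listing cuts off inside the early_stopping call. In NGCF-style recommender codebases this helper usually looks like the sketch below; the exact signature and defaults here are assumptions:

def early_stopping(log_value, best_value, stopping_step,
                   expected_order='acc', flag_step=10):
    # Track the best metric value and stop once it has failed to improve
    # for flag_step consecutive evaluations.
    assert expected_order in ('acc', 'dec')
    improved = (expected_order == 'acc' and log_value >= best_value) or \
               (expected_order == 'dec' and log_value <= best_value)
    if improved:
        stopping_step = 0
        best_value = log_value
    else:
        stopping_step += 1
    should_stop = stopping_step >= flag_step
    return best_value, stopping_step, should_stop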