Example #1
def main():
    global args, logger, writer, dataset_configs
    global best_top1_epoch, best_top5_epoch, best_top1, best_top5, best_top1_top5, best_top5_top1
    # initialize best-result statistics before they are first compared against
    best_top1, best_top5 = 0, 0
    best_top1_top5, best_top5_top1 = 0, 0
    best_top1_epoch, best_top5_epoch = 0, 0
    dataset_configs = get_and_save_args(parser)
    parser.set_defaults(**dataset_configs)
    args = parser.parse_args()

    # ================== GPU setting ===============
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    """copy codes and creat dir for saving models and logs"""
    if not os.path.isdir(args.snapshot_pref):
        os.makedirs(args.snapshot_pref)

    logger = Prepare_logger(args)
    logger.info('\ncreating folder: ' + args.snapshot_pref)

    if not args.evaluate:
        writer = SummaryWriter(args.snapshot_pref)
        recorder = Recorder(args.snapshot_pref)
        recorder.writeopt(args)

    logger.info('\nruntime args\n\n{}\n'.format(json.dumps(vars(args), indent=4)))

    """prepare dataset and model"""
    # word2idx = json.load(open('./data/dataset/TACoS/TACoS_word2id_glove_lower.json', 'r'))
    # train_dataset = TACoS(args, split='train')
    # test_dataset = TACoS(args, split='test')
    word2idx = json.load(open('./data/dataset/Charades/Charades_word2id.json', 'r'))
    train_dataset = CharadesSTA(args, split='train')
    test_dataset = CharadesSTA(args, split='test')
    train_dataloader = DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=True, collate_fn=collate_data, num_workers=8, pin_memory=True
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=args.test_batch_size,
        shuffle=False, collate_fn=collate_data, num_workers=8, pin_memory=True
    )
    vocab_size = len(word2idx)

    lr = args.lr
    n_epoch = args.n_epoch

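    # build the main model: a bidirectional query encoder (300-d embeddings, 512-d hidden state)
    # over the Charades vocabulary, with 1024-d graph node features for the localization head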
    main_model = mainModel(vocab_size, args, hidden_dim=512, embed_dim=300,
                           bidirection=True, graph_node_features=1024)

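    # initialize the query embedding layer with GloVe vectors:
    # reuse cached weights if they exist, otherwise build them from word2idx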
    if os.path.exists(args.glove_weights):
        logger.info("Loading glove weights")
        main_model.query_encoder.embedding.weight.data.copy_(torch.load(args.glove_weights))
    else:
        logger.info("Generating glove weights")
        main_model.query_encoder.embedding.weight.data.copy_(glove_init(word2idx))

    main_model = nn.DataParallel(main_model).cuda()

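    # optionally resume from a checkpoint, keeping only the parameters whose names match the current model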
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            pretrained_dict = checkpoint['state_dict']
            # only resume the model parameters present in both the checkpoint and the current model
            model_dict = main_model.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            main_model.load_state_dict(model_dict)
            # main_model.load_state_dict(checkpoint['state_dict'])
            logger.info(("=> loaded checkpoint '{}' (epoch {})"
                         .format(args.resume, checkpoint['epoch'])))
        else:
            logger.info(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.evaluate:
        topks, accuracy_topks = evaluate(main_model, test_dataloader, word2idx, False)
        for ind, topk in enumerate(topks):
            print("R@{}: {:.1f}\n".format(topk, accuracy_topks[ind] * 100))
        return

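    # stage-wise training: stage 1 freezes the IoU/mix-fc heads and runs 10 epochs,
    # stage 2 trains only those heads at lr/100, stage 3 fine-tunes the whole model at lr/10000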
    learned_params = None
    if args.is_first_stage:
        for name, value in main_model.named_parameters():
            if 'iou_scores' in name or 'mix_fc' in name:
                value.requires_grad = False
        learned_params = filter(lambda p: p.requires_grad, main_model.parameters())
        n_epoch = 10
    elif args.is_second_stage:
        head_params = main_model.module.fcos.head.iou_scores.parameters()
        fc_params = main_model.module.fcos.head.mix_fc.parameters()
        learned_params = list(head_params) + list(fc_params)
        lr /= 100
    elif args.is_third_stage:
        learned_params = main_model.parameters()
        lr /= 10000

    optimizer = torch.optim.Adam(learned_params, lr)

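    # training loop: train each epoch, evaluate every eval_freq epochs, and track the best R@1 / R@5 results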
    for epoch in range(args.start_epoch, n_epoch):

        train_loss = train_epoch(main_model, train_dataloader, optimizer, epoch)

        if (epoch + 1) % args.eval_freq == 0 or epoch == n_epoch - 1:

            val_loss, topks, accuracy_topks = validate_epoch(
                main_model, test_dataloader, epoch, word2idx, False
            )

            for ind, topk in enumerate(topks):
                writer.add_scalar('test_result/Recall@top{}'.format(topk), accuracy_topks[ind]*100, epoch)

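            # update the best R@1 / R@5 statistics and save a checkpoint whenever either improves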
            is_best_top1 = (accuracy_topks[0]*100) > best_top1
            best_top1 = max((accuracy_topks[0]*100), best_top1)
            if is_best_top1:
                best_top1_epoch = epoch
                best_top1_top5 = accuracy_topks[1]*100
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': main_model.state_dict(),
                'loss': val_loss,
                'top1': accuracy_topks[0]*100,
                'top5': accuracy_topks[1]*100,
            }, is_best_top1, epoch=epoch, top1=accuracy_topks[0]*100, top5=accuracy_topks[1]*100)

            is_best_top5 = (accuracy_topks[1]*100) > best_top5
            best_top5 = max((accuracy_topks[1]*100), best_top5)
            if is_best_top5:
                best_top5_epoch = epoch
                best_top5_top1 = accuracy_topks[0] * 100
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': main_model.state_dict(),
                'loss': val_loss,
                'top1': accuracy_topks[0]*100,
                'top5': accuracy_topks[1]*100,
            }, is_best_top5, epoch=epoch, top1=accuracy_topks[0]*100, top5=accuracy_topks[1]*100)

            writer.add_scalar('test_result/Best_Recall@top1', best_top1, epoch)
            writer.add_scalar('test_result/Best_Recall@top5', best_top5, epoch)

            logger.info(
                "R@1: {:.2f}, R@5: {:.2f}, epoch: {}\n".format(
                    accuracy_topks[0] * 100, accuracy_topks[1] * 100, epoch)
            )
            logger.info(
                "Current best top1: R@1: {:.2f}, R@5: {:.2f}, epoch: {} \n".format(
                    best_top1, best_top1_top5, best_top1_epoch)
            )
            logger.info(
                "Current best top5: R@1: {:.2f}, R@5: {:.2f}, epoch: {} \n".format(
                    best_top5_top1, best_top5, best_top5_epoch)
            )
Example #2
def main():
    # utility variables
    global args, logger, writer, dataset_configs
    # statistics variables
    global best_accuracy, best_accuracy_epoch
    best_accuracy, best_accuracy_epoch = 0, 0
    # configs
    dataset_configs = get_and_save_args(parser)
    parser.set_defaults(**dataset_configs)
    args = parser.parse_args()
    # select GPUs
    os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    '''Create snapshot_pref dir for copying code and saving models'''
    if not os.path.exists(args.snapshot_pref):
        os.makedirs(args.snapshot_pref)

    if os.path.isfile(args.resume):
        args.snapshot_pref = os.path.dirname(args.resume)

    logger = Prepare_logger(args, eval=args.evaluate)

    if not args.evaluate:
        logger.info(f'\nCreating folder: {args.snapshot_pref}')
        logger.info('\nRuntime args\n\n{}\n'.format(
            json.dumps(vars(args), indent=4)))
    else:
        logger.info(
            f'\nLog file will be saved to {args.snapshot_pref}/Eval.log.')
    '''Dataset'''
    train_dataloader = DataLoader(AVEDataset('./data/', split='train'),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)

    test_dataloader = DataLoader(AVEDataset('./data/', split='test'),
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
    '''model setting'''
    mainModel = main_model()
    mainModel = nn.DataParallel(mainModel).cuda()
    learned_parameters = mainModel.parameters()
    optimizer = torch.optim.Adam(learned_parameters, lr=args.lr)
    # scheduler = StepLR(optimizer, step_size=40, gamma=0.2)
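    # halve the learning rate at epochs 10, 20 and 30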
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.5)
    criterion = nn.BCEWithLogitsLoss().cuda()
    criterion_event = nn.CrossEntropyLoss().cuda()
    '''Resume from a checkpoint'''
    if os.path.isfile(args.resume):
        logger.info(f"\nLoading Checkpoint: {args.resume}\n")
        mainModel.load_state_dict(torch.load(args.resume))
    elif args.resume != "" and (not os.path.isfile(args.resume)):
        raise FileNotFoundError(f"No checkpoint found at '{args.resume}'")
    '''Only Evaluate'''
    if args.evaluate:
        logger.info("\nStart Evaluation...")
        validate_epoch(mainModel,
                       test_dataloader,
                       criterion,
                       criterion_event,
                       epoch=0,
                       eval_only=True)
        return
    '''Tensorboard and Code backup'''
    writer = SummaryWriter(args.snapshot_pref)
    recorder = Recorder(args.snapshot_pref, ignore_folder="Exps/")
    recorder.writeopt(args)
    '''Training and Testing'''
    for epoch in range(args.n_epoch):
        loss = train_epoch(mainModel, train_dataloader, criterion,
                           criterion_event, optimizer, epoch)

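        # evaluate every eval_freq epochs (and on the final epoch); checkpoint when accuracy improves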
        if ((epoch + 1) % args.eval_freq == 0) or (epoch == args.n_epoch - 1):
            acc = validate_epoch(mainModel, test_dataloader, criterion,
                                 criterion_event, epoch)
            if acc > best_accuracy:
                best_accuracy = acc
                best_accuracy_epoch = epoch
                save_checkpoint(
                    mainModel.state_dict(),
                    top1=best_accuracy,
                    task='Supervised',
                    epoch=epoch + 1,
                )
        scheduler.step()