Example #1
def main(args):
    # args = parser.parse_args()
    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    print(args)

    if args.seed is not None:
        # random.seed(args.seed)
        # torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

        # fix the seed for reproducibility
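        # offsetting the base seed by the process rank gives each distributed worker a distinct seed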
        seed = args.seed + utils.get_rank()
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    ##################################
    # Logging setting
    ##################################
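    # only the main (rank 0) process writes the log file, so workers do not clobber it in distributed runs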
    if args.output_dir and utils.is_main_process():
        logging.basicConfig(
            filename=os.path.join(args.output_dir, args.log_name),
            filemode='w',
            format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
            level=logging.INFO)
    warnings.filterwarnings("ignore")

    ##################################
    # Save to logging
    ##################################
    if utils.is_main_process():
        logging.info(str(args))

    ##################################
    # Initialize dataset
    ##################################

    if not args.evaluate:
        # build_vocab_flag=True, # Takes a long time to build a vocab
        train_dataset = GQATorchDataset(split='train_unbiased',
                                        build_vocab_flag=False,
                                        load_vocab_flag=False)

        if args.distributed:
            sampler_train = torch.utils.data.DistributedSampler(train_dataset)
        else:
            sampler_train = torch.utils.data.RandomSampler(train_dataset)

        batch_sampler_train = torch.utils.data.BatchSampler(sampler_train,
                                                            args.batch_size,
                                                            drop_last=True)
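        # note: drop_last=True above discards the final incomplete batch, so every batch holds exactly args.batch_size samples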

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_sampler=batch_sampler_train,
            collate_fn=GQATorchDataset_collate_fn,
            num_workers=args.workers)

        # Old version
        # train_loader = torch.utils.data.DataLoader(
        #     train_dataset, batch_size=args.batch_size, shuffle=True,
        #     collate_fn=GQATorchDataset_collate_fn,
        #     num_workers=args.workers, pin_memory=True)

    val_dataset_list = []
    for eval_split in args.evaluate_sets:
        val_dataset_list.append(
            GQATorchDataset(split=eval_split,
                            build_vocab_flag=False,
                            load_vocab_flag=args.evaluate))
    val_dataset = torch.utils.data.ConcatDataset(val_dataset_list)

    if args.distributed:
        sampler_val = torch.utils.data.DistributedSampler(val_dataset,
                                                          shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(val_dataset)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=GQATorchDataset_collate_fn,
        num_workers=args.workers)

    # Old version
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,
    #     batch_size=args.batch_size, shuffle=False,
    #     collate_fn=GQATorchDataset_collate_fn,
    #     num_workers=args.workers, pin_memory=True)

    ##################################
    # Initialize model
    # - note: the dataset must be initialized first, since the model uses the vocab built by the dataset
    ##################################
    model = PipelineModel()

    ##################################
    # Deploy model on GPU
    ##################################
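    # note: `cuda` is assumed to be a module-level torch.device('cuda') handle defined elsewhere in this file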
    model = model.to(device=cuda)

    model_without_ddp = model
    if args.distributed:
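        # find_unused_parameters=True lets DDP tolerate parameters that receive no gradient in a given forward pass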
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('number of params:', n_parameters)

    ##################################
    # define optimizer (and scheduler)
    ##################################

    # optimizer = torch.optim.SGD(model.parameters(), args.lr,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,  #  weight_decay=args.weight_decay
        amsgrad=False,
    )
    # optimizer = torch.optim.AdamW(
    #     params=model.parameters(),
    #     lr=args.lr,
    #     weight_decay=args.weight_decay)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
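    # StepLR multiplies the learning rate by its default gamma of 0.1 every args.lr_drop epochs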

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model_without_ddp.load_state_dict(checkpoint['model'])
            if not args.evaluate:
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if 'lr_scheduler' in checkpoint:
                    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
                if 'epoch' in checkpoint:
                    args.start_epoch = checkpoint['epoch'] + 1

            # checkpoint = torch.load(args.resume)
            # args.start_epoch = checkpoint['epoch']
            # model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # print("=> loaded checkpoint '{}' (epoch {})"
            #       .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cudnn.benchmark = True

    ##################################
    # Define loss functions (criterion)
    ##################################
    # criterion = torch.nn.CrossEntropyLoss().cuda()

    text_pad_idx = GQATorchDataset.TEXT.vocab.stoi[
        GQATorchDataset.TEXT.pad_token]
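    # ignore_index=text_pad_idx excludes padding positions from the program/full-answer losses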
    criterion = {
        "program": torch.nn.CrossEntropyLoss(ignore_index=text_pad_idx).to(device=cuda),
        "full_answer": torch.nn.CrossEntropyLoss(ignore_index=text_pad_idx).to(device=cuda),
        "short_answer": torch.nn.CrossEntropyLoss().to(device=cuda),
        # "short_answer": torch.nn.BCEWithLogitsLoss().to(device=cuda), # sigmoid
        "execution_bitmap": torch.nn.BCELoss().to(device=cuda),
    }

    ##################################
    # If Evaluate Only
    ##################################

    if args.evaluate:
        validate(val_loader, model, criterion, args, DUMP_RESULT=True)
        return

    ##################################
    # Main Training Loop
    ##################################

    # best_acc1 = 0
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            ##################################
            # In distributed mode, calling set_epoch(epoch) at the beginning of each epoch,
            # before creating the DataLoader iterator, is necessary to make shuffling work
            # properly across multiple epochs. Otherwise, the same ordering is always used.
            ##################################
            sampler_train.set_epoch(epoch)

        lr_scheduler.step()
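        # note: PyTorch >= 1.1 expects scheduler.step() after the epoch's optimizer updates; stepping at the start of the epoch shifts the LR schedule by one epoch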

        # adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)
        # evaluate on the validation set every 5 epochs
        if (epoch + 1) % 5 == 0:
            validate(val_loader,
                     model,
                     criterion,
                     args,
                     FAST_VALIDATE_FLAG=False)

        # # remember best acc@1 and save checkpoint
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     # 'arch': args.arch,
        #     'state_dict': model.state_dict(),
        #     # 'best_acc1': best_acc1,
        #     'optimizer' : optimizer.state_dict(),
        # }, is_best)

        if args.output_dir:
            output_dir = pathlib.Path(args.output_dir)
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            # extra checkpoint before LR drop and every 100 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir /
                                        f'checkpoint{epoch:04}.pth')
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master(
                    {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'epoch': epoch,
                        'args': args,
                    }, checkpoint_path)
Example #2
def validate(val_loader,
             model,
             criterion,
             args,
             FAST_VALIDATE_FLAG=False,
             DUMP_RESULT=False):
    batch_time = AverageMeter('Time', ':6.3f')

    program_acc = AverageMeter('Acc@Program', ':6.2f')
    program_group_acc = AverageMeter('Acc@ProgramGroup', ':4.2f')
    program_non_empty_acc = AverageMeter('Acc@ProgramNonEmpty', ':4.2f')

    # bitmap_precision = AverageMeter('Precision@Bitmap', ':4.2f')
    # bitmap_recall = AverageMeter('Recall@Bitmap', ':4.2f')

    # full_answer_acc = AverageMeter('Acc@Full', ':6.2f')
    short_answer_acc = AverageMeter('Acc@Short', ':6.2f')

    progress = ProgressMeter(len(val_loader), [
        batch_time, program_acc, program_group_acc, program_non_empty_acc,
        short_answer_acc
    ],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    if DUMP_RESULT:
        quesid2ans = {}

    with torch.no_grad():
        end = time.time()
        for i, data_batch in enumerate(val_loader):

            questionID, questions, gt_scene_graphs, programs, full_answers, short_answer_label, types = data_batch

            questions, gt_scene_graphs, programs, full_answers, short_answer_label = [
                datum.to(device=cuda, non_blocking=True) for datum in [
                    questions, gt_scene_graphs, programs, full_answers,
                    short_answer_label
                ]
            ]

            this_batch_size = questions.size(1)
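            # text tensors are sequence-first, i.e. (seq_len, batch), so dim 1 is the batch size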

            if FAST_VALIDATE_FLAG:
                raise NotImplementedError(
                    "Should not use fast validation. Only for short answer accuracy"
                )
                ##################################
                # Prepare training input and training target for text generation
                ##################################
                programs_input = programs[:-1]
                programs_target = programs[1:]
                full_answers_input = full_answers[:-1]
                full_answers_target = full_answers[1:]

                ##################################
                # Forward evaluate data
                ##################################
                output = model(questions, gt_scene_graphs, programs_input,
                               full_answers_input)
                programs_output, short_answer_logits = output

                ##################################
                # Convert output probability to top1 guess
                # So that we could measure accuracy
                ##################################
                programs_output_pred = programs_output.detach().topk(
                    k=1, dim=-1, largest=True, sorted=True)[1].squeeze(-1)
                # full_answers_output_pred = full_answers_output.detach().topk(
                #     k=1, dim=-1, largest=True, sorted=True
                # )[1].squeeze(-1)

            else:

                programs_target = programs
                full_answers_target = full_answers

                ##################################
                # Greedy decoding-based evaluation
                ##################################
                output = model(questions,
                               gt_scene_graphs,
                               None,
                               None,
                               SAMPLE_FLAG=True)
                programs_output_pred, short_answer_logits = output

            ##################################
            # Neural Execution Engine Bitmap loss
            # ground truth stored at gt_scene_graphs.y
            # using torch.nn.BCELoss - torch.nn.functional.binary_cross_entropy
            ##################################
            # precision, precision_div, recall, recall_div = bitmap_precision_recall(
            #     execution_bitmap, gt_scene_graphs.y, threshold=0.5
            # )

            # bitmap_precision.update(precision, precision_div)
            # bitmap_recall.update(recall, recall_div)

            ##################################
            # Calculate Fast Evaluation for each module
            ##################################
            this_short_answer_acc1 = accuracy(short_answer_logits.detach(),
                                              short_answer_label,
                                              topk=(1, ))
            short_answer_acc.update(this_short_answer_acc1[0].item(),
                                    this_batch_size)

            text_pad_idx = GQATorchDataset.TEXT.vocab.stoi[
                GQATorchDataset.TEXT.pad_token]
            this_program_acc, this_program_group_acc, this_program_non_empty_acc = program_string_exact_match_acc(
                programs_output_pred,
                programs_target,
                padding_idx=text_pad_idx,
                group_accuracy_WAY_NUM=GQATorchDataset.MAX_EXECUTION_STEP)
            program_acc.update(this_program_acc, this_batch_size)
            program_group_acc.update(
                this_program_group_acc,
                this_batch_size // GQATorchDataset.MAX_EXECUTION_STEP)
            program_non_empty_acc.update(this_program_non_empty_acc,
                                         this_batch_size)

            # this_full_answers_acc = string_exact_match_acc(
            #     full_answers_output_pred.detach(), full_answers_target, padding_idx=text_pad_idx
            # )
            # full_answer_acc.update(this_full_answers_acc, this_batch_size)

            ##################################
            # Example Visualization from the first batch
            ##################################

            if i == 0:
                for batch_idx in range(min(this_batch_size, 128)):

                    ##################################
                    # print Question and Question ID
                    ##################################
                    question = questions[:, batch_idx].cpu()
                    question_sent, _ = GQATorchDataset.indices_to_string(
                        question, True)
                    print(
                        "Question({}) QID({}):".format(batch_idx,
                                                       questionID[batch_idx]),
                        question_sent)
                    if utils.is_main_process():
                        logging.info("Question({}) QID({}): {}".format(
                            batch_idx, questionID[batch_idx], question_sent))

                    ##################################
                    # print program prediction
                    ##################################

                    for instr_idx in range(GQATorchDataset.MAX_EXECUTION_STEP):
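                        # programs are flattened over the batch: each question owns MAX_EXECUTION_STEP consecutive columns, one per execution step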
                        true_batch_idx = instr_idx + GQATorchDataset.MAX_EXECUTION_STEP * batch_idx
                        gt = programs[:, true_batch_idx].cpu()
                        pred = programs_output_pred[:, true_batch_idx]
                        pred_sent, _ = GQATorchDataset.indices_to_string(
                            pred, True)
                        gt_sent, _ = GQATorchDataset.indices_to_string(
                            gt, True)

                        if len(pred_sent) == 0 and len(gt_sent) == 0:
                            # skip if both target and prediction are empty
                            continue

                        # gt_caption
                        print(
                            "Generated Program ({}): ".format(true_batch_idx),
                            pred_sent, " Ground Truth Program ({}):".format(
                                true_batch_idx), gt_sent)
                        if utils.is_main_process():
                            # gt_caption
                            logging.info(
                                "Generated Program ({}): {}  Ground Truth Program ({}): {}"
                                .format(true_batch_idx, pred_sent,
                                        true_batch_idx, gt_sent))

                    ##################################
                    # print full answer prediction
                    ##################################
                    # gt = full_answers[:, batch_idx].cpu()
                    # pred = full_answers_output_pred[:, batch_idx]
                    # pred_sent, _ = GQATorchDataset.indices_to_string(pred, True)
                    # gt_sent, _ = GQATorchDataset.indices_to_string(gt, True)
                    # # gt_caption
                    # print(
                    #     "Generated Full Answer ({}): ".format(batch_idx), pred_sent,
                    #     "Ground Truth Full Answer ({}):".format(batch_idx), gt_sent
                    # )
                    # if utils.is_main_process():
                    #     # gt_caption
                    #     logging.info("Generated Full Answer ({}): {} Ground Truth Full Answer ({}): {}".format(
                    #         batch_idx, pred_sent, batch_idx, gt_sent
                    #     ))

            ##################################
            # Dump Results if enabled
            ##################################
            if DUMP_RESULT:

                short_answer_pred_score, short_answer_pred_label = short_answer_logits.max(1)
                short_answer_pred_score = short_answer_pred_score.cpu()
                short_answer_pred_label = short_answer_pred_label.cpu()
                for batch_idx in range(this_batch_size):
                    ##################################
                    # print Question and Question ID
                    ##################################
                    question = questions[:, batch_idx].cpu()
                    question_sent, _ = GQATorchDataset.indices_to_string(
                        question, True)

                    ##################################
                    # print program prediction
                    ##################################
                    ground_truth_program_list = []
                    predicted_program_list = []
                    for instr_idx in range(GQATorchDataset.MAX_EXECUTION_STEP):
                        true_batch_idx = instr_idx + GQATorchDataset.MAX_EXECUTION_STEP * batch_idx
                        gt = programs[:, true_batch_idx].cpu()
                        pred = programs_output_pred[:, true_batch_idx]
                        pred_sent, _ = GQATorchDataset.indices_to_string(
                            pred, True)
                        gt_sent, _ = GQATorchDataset.indices_to_string(
                            gt, True)

                        if len(pred_sent) == 0 and len(gt_sent) == 0:
                            # skip if both target and prediction are empty
                            continue

                        ground_truth_program_list.append(gt_sent)
                        predicted_program_list.append(pred_sent)

                    ##################################
                    # print full answer prediction
                    ##################################
                    # gt = full_answers[:, batch_idx].cpu()
                    # pred = full_answers_output_pred[:, batch_idx]
                    # pred_sent, _ = GQATorchDataset.indices_to_string(pred, True)
                    # gt_sent, _ = GQATorchDataset.indices_to_string(gt, True)
                    # gt_caption

                    ##################################
                    # get short answer prediction
                    ##################################
                    qid = questionID[batch_idx]
                    quesid2ans[qid] = {
                        "questionId": str(qid),
                        "question": question_sent,
                        "ground_truth_program_list": ground_truth_program_list,
                        "predicted_program_list": predicted_program_list,
                        "answer": GQATorchDataset.label2ans[
                            short_answer_label[batch_idx].cpu().item()],
                        # predicted short answer
                        "prediction": GQATorchDataset.label2ans[
                            short_answer_pred_label[batch_idx].cpu().item()],
                        "prediction_score": '{:.2f}'.format(
                            short_answer_pred_score[batch_idx].cpu().item()),
                        "types": types[batch_idx],
                    }

            ##################################
            # measure elapsed time
            ##################################
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0 or i == len(val_loader) - 1:
                progress.display(i)

            ##################################
            # Only for debugging: short-circuit the evaluation loop
            ##################################
            # break

    ##################################
    # Give final score
    ##################################
    progress.display(batch=len(val_loader))

    if DUMP_RESULT:
        result_dump_path = os.path.join(args.output_dir, "dump_results.json")
        with open(result_dump_path, 'w') as f:
            json.dump(quesid2ans, f, indent=4, sort_keys=True)
            print("Result Dumped!", str(result_dump_path))

    return
Example #3
            logging.info(('%s, ' * (len(not_loaded_keys) - 1) + '%s') %
                         tuple(not_loaded_keys))

        model_dict.update(pretrained_dict)
        super(PipelineModel, self).load_state_dict(model_dict)


if __name__ == "__main__":

    ##################################
    # Need to have the vocab first to debug
    ##################################
    from gqa_dataset_entry import GQATorchDataset, GQATorchDataset_collate_fn
    debug_dataset = GQATorchDataset(
        # split='train_unbiased',
        split='val_unbiased',
        # split='testdev',
        build_vocab_flag=False,
        load_vocab_flag=True)

    # debug_dataset = GQATorchDataset(
    #         split='train_unbiased',
    #         build_vocab_flag=True,
    #         load_vocab_flag=True
    #     )

    ##################################
    # Debugging: init model
    # Forwarding a tiny batch with CPU
    ##################################
    model = PipelineModel()
    model.train()