Example #1
    def __init__(self):
        # Datasets

        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=512,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers, args.model)
        # Load pre-trained weights
        if args.load_pretrained is not None:
            self.model.encoder.load(args.load_pretrained)
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from src.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
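A note on the BertAdam branch above: warmup=0.1 together with t_total drives a linear warmup-then-decay multiplier on the learning rate inside BertAdam. A minimal sketch of that schedule, assuming the pytorch_pretrained_bert-style formulation (the exact shape varies slightly between library versions); this is also the warmup_linear function that the later examples call:

def warmup_linear(x, warmup=0.1):
    """LR multiplier at training progress x = global_step / t_total.

    Ramps linearly from 0 to 1 over the first `warmup` fraction of
    training, then decays linearly back to 0 at the end of training.
    """
    if x < warmup:
        return x / warmup
    return max((x - 1.0) / (warmup - 1.0), 0.0)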
Example #2
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .csv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Do not use CUDA even when it is available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of update steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_swag_examples(os.path.join(
            args.data_dir, 'train.csv'),
                                            is_training=True)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    # Prepare model
    model = BertForMultipleChoice.from_pretrained(
        args.bert_model,
        cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE,
                               'distributed_{}'.format(args.local_rank)),
        num_choices=4)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack to remove the pooler, which is not used here and thus produces
    # None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_length,
                                                      True)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features],
                                 dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                # NOTE: FP16_Optimizer applies loss scaling internally via
                # optimizer.backward(loss) below, so no manual rescaling is needed here.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForMultipleChoice(config, num_choices=4)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForMultipleChoice.from_pretrained(args.bert_model,
                                                      num_choices=4)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = read_swag_examples(os.path.join(
            args.data_dir, 'val.csv'),
                                           is_training=True)
        eval_features = convert_examples_to_features(eval_examples, tokenizer,
                                                     args.max_seq_length, True)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(eval_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(eval_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in eval_features],
                                 dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': tr_loss / nb_tr_steps if args.do_train else None
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
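The SWAG loops above rely on two helpers that are not shown here, select_field and accuracy. A plausible sketch of both, assuming each feature carries a per-choice choices_features list of dicts as in the reference run_swag.py:

import numpy as np

def select_field(features, field):
    # Gather one field (e.g. 'input_ids') across all choices of all
    # examples, yielding a [num_examples, num_choices, ...] nested list.
    return [[choice[field] for choice in feature.choices_features]
            for feature in features]

def accuracy(out, labels):
    # Count correct predictions: argmax over the choice dimension.
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)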
Example #3
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Do not use CUDA even when it is available")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of update steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError(
            "Training is currently the only implemented execution option. Please set `do_train`."
        )

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    corpus_lines=None,
                                    on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            # TODO: check whether this works with the current from-disk data
            # generator, which relies on next(file) and cannot fetch items by index
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                             is_next)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model
        logger.info("***** Saving fine-tuned model *****")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)
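A quick illustration of the batch-size arithmetic this script (and Example #2) relies on, with hypothetical flag values: dividing --train_batch_size by --gradient_accumulation_steps gives the per-forward micro-batch, the loss division keeps gradient magnitudes comparable, and stepping every N-th micro-batch restores the requested effective batch size.

# Sketch with assumed values --train_batch_size 32, --gradient_accumulation_steps 4.
requested_batch_size = 32
gradient_accumulation_steps = 4

micro_batch = requested_batch_size // gradient_accumulation_steps  # 8 examples per forward pass
effective_batch = micro_batch * gradient_accumulation_steps        # 32 examples per optimizer.step()
assert effective_batch == requested_batch_size
# Each micro-batch loss is divided by 4, so gradients summed over four
# backward passes approximate the mean gradient of one 32-example batch.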
Example #4
class VQA:
    def __init__(self):
        # Datasets

        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=512,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers, args.model)
        # Load pre-trained weights
        if args.load_pretrained is not None:
            self.model.encoder.load(args.load_pretrained)
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from src.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = ((lambda x: tqdm(x, total=len(loader)))
                        if args.tqdm else (lambda x: x))
        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
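The VQA class above unpacks its tuples positionally (dset, loader, evaluator) and also reads .dataset and .loader as attributes, so get_data_tuple presumably returns a namedtuple. A minimal sketch of that assumed shape:

from collections import namedtuple

# Assumed return type of get_data_tuple(); this satisfies both the
# positional unpacking and the attribute access used by the VQA class.
DataTuple = namedtuple("DataTuple", ["dataset", "loader", "evaluator"])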
Example #5
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
    parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Do not use CUDA even when it is available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--version_2_with_negative',
                        action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold',
                        type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bit training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Hack to remove the pooler, which is not used here and thus produces
    # None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
        train_features = None
        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except (OSError, pickle.PickleError):
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s", cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    batch = tuple(t.to(device) for t in batch)  # multi-gpu does scattering itself
                input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used and handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForQuestionAnswering(config)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForQuestionAnswering.from_pretrained(args.bert_model)

    model.to(device)

    if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating"):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                          args.version_2_with_negative, args.null_score_diff_threshold)
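RawResult in the prediction loop above is not defined in this excerpt; in the reference run_squad.py it is a plain namedtuple holding the logits for one split feature, keyed by its unique_id:

import collections

# One record of start/end logits per split feature, matched back to
# eval_features by unique_id inside write_predictions().
RawResult = collections.namedtuple(
    "RawResult", ["unique_id", "start_logits", "end_logits"])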
Example #6
    def __init__(self):

        self.train_type = args.train_type
        self.device = torch.device(args.device)

        # Dataloaders for train and val set
        if not args.test:
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=args.batch_size,
                                              shuffle=False,
                                              drop_last=False)
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True)
            num_answers = self.train_tuple.dataset.num_answers
            file_name = args.train
            log_str = f"\n{ctime()} || Loaded train set of size {len(self.train_tuple[0])} and val set of size {len(self.valid_tuple[0])}."
        else:
            self.test_tuple = get_data_tuple(args.test,
                                             bs=args.batch_size,
                                             shuffle=False,
                                             drop_last=False)
            num_answers = self.test_tuple.dataset.num_answers
            file_name = args.test
            log_str = (
                f"\n{ctime()} || Loaded test set of size {len(self.test_tuple[0])}."
            )

        # get dataset name
        self.dtype = args.task

        # Model
        self.model = eUGModel(self.train_type, num_answers, self.dtype,
                              args.model)

        # Load pre-trained weights
        if self.train_type == "expl" and args.bb_path is not None:
            self.model.load_state_dict(torch.load(args.bb_path))
            # freeze backbone
            for name, param in self.model.named_parameters():
                if "decoder.model.transformer" not in name:
                    param.requires_grad = False
        elif args.load_pretrained is not None:
            self.model.encoder.load(args.load_pretrained)

        self.model = self.model.to(self.device)

        # Loss and Optimizer
        if not args.test:
            if self.dtype == "vqa_x":
                self.loss_func = nn.BCEWithLogitsLoss()
            else:
                self.loss_func = nn.CrossEntropyLoss()

            batch_per_epoch = len(self.train_tuple.loader) / args.grad_accum
            t_total = int(batch_per_epoch * args.epochs)

            if "bert" in args.optim:
                print("BertAdam Total Iters: %d" % t_total)
                from src.optimization import BertAdam

                self.optim = BertAdam(
                    list(self.model.parameters()),
                    lr=args.lr,
                    warmup=0.1,
                    t_total=t_total,
                )
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
                self.scheduler = get_linear_schedule_with_warmup(
                    self.optim,
                    num_warmup_steps=args.warmup_steps,
                    num_training_steps=t_total,
                )
        self.grad_accum = args.grad_accum

        # Output Directory
        self.output = args.output
        self.save_steps = args.save_steps
        os.makedirs(self.output, exist_ok=True)

        # print logs
        log_str += f"\n{ctime()} || Model loaded. Batch size {args.batch_size*args.grad_accum} | lr {args.lr} | task: {self.dtype} | type: {self.train_type}."
        print_log(args, log_str)
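print_log is not shown in these excerpts. Given that the log strings above are built incrementally and Example #4 appends its logs to output/log.log, a plausible sketch (the real helper may differ):

import os

def print_log(args, log_str):
    # Hypothetical helper: echo to stdout and append to a log file
    # in the run's output directory.
    print(log_str, end="")
    with open(os.path.join(args.output, "log.log"), "a") as f:
        f.write(log_str)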
Example #7
class VQA:
    def __init__(self):

        self.train_type = args.train_type
        self.device = torch.device(args.device)

        # Dataloaders for train and val set
        if not args.test:
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=args.batch_size,
                                              shuffle=False,
                                              drop_last=False)
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True)
            num_answers = self.train_tuple.dataset.num_answers
            file_name = args.train
            log_str = f"\n{ctime()} || Loaded train set of size {len(self.train_tuple[0])} and val set of size {len(self.valid_tuple[0])}."
        else:
            self.test_tuple = get_data_tuple(args.test,
                                             bs=args.batch_size,
                                             shuffle=False,
                                             drop_last=False)
            num_answers = self.test_tuple.dataset.num_answers
            file_name = args.test
            log_str = (
                f"\n{ctime()} || Loaded test set of size {len(self.test_tuple[0])}."
            )

        # dataset / task identifier (e.g. "vqa_x", "vcr")
        self.dtype = args.task

        # Model
        self.model = eUGModel(self.train_type, num_answers, self.dtype,
                              args.model)

        # Load pre-trained weights
        if self.train_type == "expl" and args.bb_path is not None:
            self.model.load_state_dict(torch.load(args.bb_path))
            # freeze the backbone: everything except the explanation decoder's transformer
            for name, param in self.model.named_parameters():
                if "decoder.model.transformer" not in name:
                    param.requires_grad = False
        elif args.load_pretrained is not None:
            self.model.encoder.load(args.load_pretrained)

        self.model = self.model.to(self.device)

        # Loss and Optimizer
        if not args.test:
            if self.dtype == "vqa_x":
                self.loss_func = nn.BCEWithLogitsLoss()
            else:
                self.loss_func = nn.CrossEntropyLoss()

            batch_per_epoch = len(self.train_tuple.loader) / args.grad_accum
            t_total = int(batch_per_epoch * args.epochs)

            if "bert" in args.optim:
                print("BertAdam Total Iters: %d" % t_total)
                from src.optimization import BertAdam

                self.optim = BertAdam(
                    list(self.model.parameters()),
                    lr=args.lr,
                    warmup=0.1,
                    t_total=t_total,
                )
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
                self.scheduler = get_linear_schedule_with_warmup(
                    self.optim,
                    num_warmup_steps=args.warmup_steps,
                    num_training_steps=t_total,
                )
        self.grad_accum = args.grad_accum

        # Output Directory
        self.output = args.output
        self.save_steps = args.save_steps
        os.makedirs(self.output, exist_ok=True)

        # print logs
        log_str += f"\n{ctime()} || Model loaded. Batch size {args.batch_size*args.grad_accum} | lr {args.lr} | task: {self.dtype} | type: {self.train_type}."
        print_log(args, log_str)

    def train(self, train_tuple, eval_tuple):

        tb_writer = SummaryWriter(self.output)

        dset, loader, evaluator = train_tuple
        iter_wrapper = ((lambda x: tqdm(x, total=len(loader)))
                        if args.tqdm else (lambda x: x))

        # logger initialisations
        best_task = 0.0  # best task score (S_T) seen so far
        best_expl = 0.0  # best explanation score (S_E) seen so far
        best_global = 0.0  # best overall score (S_O) seen so far
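        # (task, expl) loss histories for the optional DWA re-weighting below,
        # seeded with 1 so the first loss ratio is well defined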
        prev_losses = [[1], [1]]
        prev_task, prev_expl = 0, 0
        global_step = 0
        t_loss, tt_loss, te_loss = 0, 0, 0
        step_per_eval = 0

        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (
                    ques_id,
                    feats,
                    boxes,
                    sent,
                    target,
                    expl,
                    answer_choices,
            ) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                expl_gt = target
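                # expl_gt keeps the unflattened targets for the explanation decoder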

                if self.dtype == "vcr":
                    model_dict = answer_choices
                    target = target.flatten()
                else:
                    model_dict = dset.label2ans

                logit, output, _, _, _ = self.model(
                    feats.to(self.device),
                    boxes.to(self.device),
                    sent,
                    expl,
                    answer_choices,
                    model_dict,
                    expl_gt,
                )

                if self.dtype == "vqa_x":
                    loss_multiplier = logit.size(1)
                elif self.dtype == "vcr":
                    loss_multiplier = 4
                else:
                    loss_multiplier = 1

                if self.train_type == "all":
                    task_loss = (
                        self.loss_func(logit, target.to(self.device)) *
                        loss_multiplier)
                    expl_loss = output[0]
                    # loss_weights = dwa(prev_losses, temp=args.temperature)
                    loss_weights = {"task": 1, "expl": 1}
                    # loss = loss_weights['task']*task_loss + loss_weights['expl']*expl_loss
                    loss = weighted_loss(task_loss, expl_loss, loss_weights,
                                         args.classifier_weight)
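                    # scale so gradients summed over grad_accum micro-batches match one full batch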
                    loss /= self.grad_accum

                    prev_task += float(task_loss)
                    prev_expl += float(expl_loss)

                    # record loss for every 1024 datapoints
                    if (i + 1) % int((1024 / args.batch_size)) == 0:
                        prev_losses[0].append(prev_task /
                                              (1024 / args.batch_size))
                        prev_losses[1].append(prev_expl /
                                              (1024 / args.batch_size))
                        prev_task, prev_expl = 0, 0

                elif self.train_type == "bb":
                    loss = (self.loss_func(logit, target.to(self.device)) *
                            loss_multiplier)
                    loss /= self.grad_accum
                    task_loss = float(loss)
                    expl_loss = 0

                elif self.train_type == "expl":
                    loss = output[0]
                    loss /= self.grad_accum
                    task_loss = 0
                    expl_loss = float(loss)

                loss.backward()

                if self.dtype == "vcr":
                    logit = binary_to_mp(logit)

                score, label = logit.max(1)
                if not isinstance(ques_id, list):
                    ques_id = ques_id.cpu().numpy()

                if self.dtype == "vcr":  # vcr
                    for qid, l in zip(ques_id, label.cpu().numpy()):
                        ans = dset.label2ans[qid][l]
                        quesid2ans[qid] = ans
                else:
                    for qid, l in zip(ques_id, label.cpu().numpy()):
                        ans = dset.label2ans[l]
                        quesid2ans[qid] = ans

                t_loss += float(loss) * self.grad_accum
                tt_loss += float(task_loss)
                te_loss += float(expl_loss)
                step_per_eval += 1

                # global step
                # grad accum snippet: https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3
                if (i + 1) % self.grad_accum == 0:

                    nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                    self.optim.step()
                    if args.optim != "bert":
                        self.scheduler.step()  # Update learning rate schedule

                    # logging
                    tb_writer.add_scalar("task loss", task_loss, global_step)
                    tb_writer.add_scalar("explanation loss", expl_loss,
                                         global_step)
                    tb_writer.add_scalar("total loss",
                                         float(loss) * self.grad_accum,
                                         global_step)
                    if self.train_type == "all":
                        tb_writer.add_scalar("task weight",
                                             loss_weights["task"], global_step)
                        tb_writer.add_scalar("explanation weight",
                                             loss_weights["expl"], global_step)

                    global_step += 1

                    # do eval
                    if self.save_steps > 0 and global_step % self.save_steps == 0:
                        log_str = f"\n\n{ctime()} || EVALUATION TIME"
                        log_str += f"\nEpoch-step {epoch}-{global_step}: Loss {t_loss/step_per_eval:.2f} | Task loss {tt_loss/step_per_eval:.2f} | Expl loss {te_loss/step_per_eval:.2f} | Train acc {evaluator.evaluate(quesid2ans)[0]:.2f}"
                        print_log(args, log_str)
                        t_loss, tt_loss, te_loss = 0, 0, 0
                        step_per_eval = 0

                        if self.valid_tuple is not None:  # Do Validation
                            valid_score, valid_perplexity, nlg_scores = self.evaluate(
                                eval_tuple)

                            # no explanations generated
                            if not nlg_scores:

                                if valid_score > best_task:
                                    best_task = valid_score
                                    self.save("best_task")

                                log_str = f"\nEpoch-step {epoch}-{global_step}: Valid Score: {valid_score:.3f} | Best Valid Score: {best_task:.3f}"
                                tb_writer.add_scalar("valid_task_score",
                                                     valid_score * 100.0,
                                                     global_step)
                                tb_writer.add_scalar(
                                    "valid_expl_perplexity",
                                    valid_perplexity * 100.0,
                                    global_step,
                                )
                                print_log(args, log_str)
                                continue

                            if valid_score > best_task:
                                best_task = valid_score
                                self.save("best_task")

                            if self.train_type == "bb":
                                nlg_avg = 0
                                global_score = 0
                                valid_perplexity = 0
                            else:
                                global_score = nlg_scores["global_score"]
                                if global_score > best_global:
                                    best_global = global_score
                                    self.save("best_global")

                                nlg_avg = nlg_scores["avg_all"]
                                if nlg_avg > best_expl:
                                    best_expl = nlg_avg
                                    self.save("best_expl")

                            log_str = f"\nEpoch-step {epoch}-{global_step}: Valid Score: {valid_score:.3f} | NLG average: {nlg_avg:.3f} | Global score: {global_score:.3f}"
                            log_str += f"\nEpoch-step {epoch}-{global_step}: Best Valid Score: {best_task:.3f} | Best NLG: {best_expl:.3f} | Best overall: {best_global:.3f}"

                            tb_writer.add_scalar("valid_task_score",
                                                 valid_score * 100.0,
                                                 global_step)
                            tb_writer.add_scalar(
                                "valid_expl_perplexity",
                                valid_perplexity * 100.0,
                                global_step,
                            )

                            if nlg_scores:
                                log_str += f"\nEpoch-step {epoch}-{global_step}: {print_dict(nlg_scores)}"
                                for k, v in nlg_scores.items():
                                    tb_writer.add_scalar(k, v, global_step)

                        print(log_str, end="")

                        print_log(args, log_str)

                        tb_writer.flush()

        self.save("LAST")
        tb_writer.close()

    def predict(self,
                train_type,
                eval_tuple: DataTuple,
                dump=None,
                gen_dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """

        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        expl_loss = 0.0
        nb_eval_steps = 0
        generated_explanations = None
        test_output = []

        if "bb" not in train_type:
            # initialisations for NL evaluation
            try:
                bert_metric = load_metric(
                    "bertscore",
                    experiment_id=str(random.randrange(999999)),
                    device=self.device,
                )
            except Exception:
                bert_metric = None
            all_generated_explanations = []
            all_gt_expls = []
            tokenizer = VCRGpt2Tokenizer.from_pretrained("gpt2")
            gen_model = self.model.decoder.model.to(self.device)

        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, label, expl, answers = datum_tuple

            if args.gt_cond:
                gt = label
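                # condition the explanation decoder on the ground-truth answer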
            else:
                gt = None

            if self.dtype == "vcr":  # different label dict
                model_dict = answers
            else:
                model_dict = dset.label2ans

            if self.dtype == "vqa_x":  # multiple explanations
                triple_expl = [[x[y] for x in expl]
                               for y in range(len(expl[0]))]
                expl = expl[0]
            else:
                triple_expl = None

            with torch.no_grad():
                feats, boxes = feats.to(self.device), boxes.to(self.device)
                (
                    logit,
                    expl_output,
                    input_ids,
                    token_type_ids,
                    visual_representations,
                ) = self.model(feats, boxes, sent, expl, answers, model_dict,
                               gt)

                # get indices for when to generate explanations
                if self.dtype == "vqa_x":
                    if args.gt_cond:
                        logit = label
                    correct_indices = []
                    for idx, prediction in enumerate(
                            list(
                                torch.argmax(logit,
                                             1).detach().cpu().numpy())):
                        if float(label[idx][prediction]) != 0:
                            correct_indices.append(idx)
                    correct_indices = torch.tensor(correct_indices)
                elif self.dtype == "vcr":
                    logit = binary_to_mp(logit)  # transform binary labels into 4-way
                    correct_indices = (torch.where(
                        label.argmax(1) == logit.cpu().argmax(1))
                                       [0].detach().cpu())
                else:
                    correct_indices = (torch.where(
                        label.to(self.device) == torch.argmax(logit, 1))
                                       [0].detach().cpu())
                    if args.gt_cond:
                        correct_indices = torch.arange(label.size(0),
                                                       dtype=torch.long)

                # populate quesid2ans (where ans is predicted ans)
                if not isinstance(ques_id, list):
                    ques_id = ques_id.cpu().numpy()
                # keep `label` as the ground truth (it is reused when printing samples below)
                score, pred = logit.max(1)
                if self.dtype == "vcr":
                    for qid, l in zip(ques_id, pred.cpu().numpy()):
                        ans = dset.label2ans[qid][l]
                        quesid2ans[qid] = ans
                else:
                    for qid, l in zip(ques_id, pred.cpu().numpy()):
                        ans = dset.label2ans[l]
                        quesid2ans[qid] = ans

                # generate and evaluate explanations
                get_gen_expl = 0
                if "bb" not in train_type:
                    expl_loss += expl_output[0].mean().item()

                    # only evaluate random subset during validation to save time
                    if args.test:
                        get_gen_expl = 1
                    else:
                        get_gen_expl = np.random.choice(
                            np.arange(0, 2),
                            p=[1 - args.prob_eval, args.prob_eval])

                    # get subset where label was predicted correctly
                    (
                        input_ids,
                        token_type_ids,
                        visual_representations,
                        expl,
                        triple_expl,
                    ) = input_subset(
                        correct_indices,
                        input_ids,
                        token_type_ids,
                        visual_representations,
                        expl,
                        triple_expl,
                        self.device,
                    )
                    generated_explanations = None

                    if input_ids.shape[0] != 0:  # if not all predictions were wrong
                        if get_gen_expl:
                            generated_explanations = generate_text(
                                gen_model,
                                tokenizer,
                                input_ids,
                                token_type_ids,
                                visual_representations,
                                max_rationale_length=51,
                            )

                            if self.dtype == "vcr":
                                expl = [
                                    map_vcr_tag_to_num(x) for x in expl
                                ]  # to make sure same kind of explanations are compared

                            # free memory
                            input_ids, token_type_ids, visual_representations = (
                                None,
                                None,
                                None,
                            )

                            if self.dtype == "vqa_x":
                                try:
                                    bert_metric.add_batch(
                                        predictions=generated_explanations,
                                        references=triple_expl,
                                    )
                                except Exception:
                                    print("BertScore failed")
                                all_gt_expls.extend(triple_expl)
                            else:
                                try:
                                    bert_metric.add_batch(
                                        predictions=generated_explanations,
                                        references=expl,
                                    )
                                except Exception:
                                    print("BertScore failed")
                                all_gt_expls.extend(expl)

                            all_generated_explanations.extend(
                                generated_explanations)

                            # printing examples during eval
                            if not args.test:
                                if self.dtype == "vcr":
                                    labels = [
                                        label[i].max(0)[1].item()
                                        for i in correct_indices
                                    ]
                                    model_dict = [
                                        answers[i] for i in correct_indices
                                    ]
                                else:
                                    labels = [
                                        label[i].item()
                                        for i in correct_indices
                                    ]
                                random_print_samples(
                                    [sent[i] for i in correct_indices],
                                    labels,
                                    generated_explanations,
                                    model_dict,
                                )

                gen_expl_all = len(ques_id) * ["None"]
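                # "None" placeholders are overwritten below where an explanation was generated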
                if generated_explanations:
                    for ci, gen_expl in zip(correct_indices,
                                            generated_explanations):
                        gen_expl_all[ci] = gen_expl

                # write explanations to file
                if gen_dump:
                    for idx, (qid, gen_expl) in enumerate(
                            zip(list(ques_id), gen_expl_all)):
                        input_record = {}

                        input_record["question_id"] = str(qid)
                        input_record["question"] = dset.id2datum[qid]["sent"]
                        input_record["generated_explanation"] = gen_expl
                        if self.dtype == "vcr":
                            input_record["correct_explanations"] = (
                                dset.id2datum[qid]["explanation"].replace(
                                    "<|det", "").replace("|>", ""))
                        else:
                            input_record[
                                "correct_explanations"] = dset.id2datum[qid][
                                    "explanation"]
                        input_record["prediction"] = quesid2ans[qid]
                        input_record["gt"] = dset.id2datum[qid]["label"]
                        if self.dtype == "vcr":
                            input_record["img_id"] = dset.id2datum[qid][
                                "raw_img_id"]
                            input_record["movie"] = dset.id2datum[qid]["movie"]
                            input_record["answer_choices"] = [
                                x.replace("<|det", "").replace("|>", "")
                                for x in dset.id2datum[qid]["answer_choices"]
                            ]
                        elif self.dtype == "vqax":
                            input_record["img_id"] = dset.id2datum[qid][
                                "img_id"]
                        else:
                            input_record["img_id"] = str(qid)[:-5]
                        if idx in list(correct_indices.numpy()):
                            input_record["correct"] = 1
                        else:
                            input_record["correct"] = 0

                        test_output.append(input_record)

            nb_eval_steps += 1

        valid_score, correct_idx = eval_tuple.evaluator.evaluate(quesid2ans)
        # weight NLG scores by the fraction of fully correct answers
        # (for VQA-X, partially correct answers also count toward the task score)
        nlg_weight = correct_idx.count(1) / len(correct_idx)

        # getting perplexity
        expl_loss = expl_loss / nb_eval_steps
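        # exp(mean LM cross-entropy) = token-level perplexity of the explanation decoder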
        perplexity = torch.exp(torch.tensor(expl_loss)).item()

        if "bb" not in train_type and len(all_generated_explanations) != 0:

            # getting NLG metrics
            nlg_global_scores = get_nlg_scores(
                self.dtype,
                all_generated_explanations,
                all_gt_expls,
                bert_metric,
                self.device,
            )
            nlg_global_scores["global_score"] = (nlg_global_scores["avg_all"] *
                                                 nlg_weight)
            if not nlg_global_scores["global_score"]:
                nlg_global_scores["global_score"] = 0

            if gen_dump is not None:
                scores_to_print = nlg_global_scores
                scores_to_print["task_score"] = valid_score
                write_items(
                    [json.dumps(r) for r in ["scores", scores_to_print]],
                    os.path.join(args.output, "scores.json"),
                )
                write_items(
                    [json.dumps(r) for r in test_output],
                    os.path.join(args.output, "gen_test.json"),
                )

            return valid_score, perplexity, nlg_global_scores
        else:
            scores_to_print = {"task_score": valid_score}
            print("Task Score: ", valid_score)
            write_items(
                [json.dumps(r) for r in ["scores", scores_to_print]],
                os.path.join(args.output, "scores.json"),
            )
            return valid_score, perplexity, None

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        valid_score, expl_perplexity, nlg_global_scores = self.predict(
            self.train_type, eval_tuple, dump)
        return valid_score, expl_perplexity, nlg_global_scores

    @staticmethod
    def oracle_score(data_tuple):
        """
        Purpose:
        """
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path,
                                map_location=torch.device("cpu"))
        self.model.load_state_dict(state_dict, strict=False)
        self.model = self.model.to(self.device)
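For context, a minimal driver for the class above might look like the sketch below. The repo's actual entry point is not part of this example, so the args.load flag and the dispatch logic are assumptions inferred from the arguments the class reads (args.test, args.output, ...):

# Hypothetical driver; a sketch under assumptions, not the repo's real main().
if __name__ == "__main__":
    vqa = VQA()

    # args.load is assumed to exist; VQA.load() appends ".pth" to the path.
    if getattr(args, "load", None):
        vqa.load(args.load)

    if args.test:
        # gen_dump is used as a truthy flag inside predict()
        vqa.predict(vqa.train_type, vqa.test_tuple, gen_dump=True)
    else:
        vqa.train(vqa.train_tuple, vqa.valid_tuple)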