def main():
    """Train a BERT masked-LM wrapper on pregenerated data with interleaved "negated" batches.

    Command-line driven entry point. The function:

    1. Parses arguments and reads ``epoch_{i}_metrics.json`` files from
       ``--pregenerated_data`` to learn how many examples each epoch holds.
    2. Selects the device, seeds the RNGs and prepares the output directory
       (optionally prefixed by the ``PT_OUTPUT_DIR`` environment variable).
    3. Builds tokenizer, model, AdamW optimizer and linear warmup schedule,
       and saves an untrained snapshot under ``<output_dir>/before_training``.
    4. For every epoch, iterates the regular data with ``negated=False`` and,
       after each batch, with probability ``1 - kr_freq`` runs an extra pass
       on a batch drawn from ``--pregenerated_neg_data`` with ``negated=True``.
    5. Saves model and tokenizer at the start of every epoch and once more
       after training.

    Raises:
        AssertionError: if ``--pregenerated_data`` is not a directory.
        ValueError: if ``--gradient_accumulation_steps`` is smaller than 1.
        ImportError: if ``--fp16`` is requested but apex is not installed.
        SystemExit: if no data for epoch 0 is found.
    """
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_neg_data', type=Path, required=True)
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    parser.add_argument(
        "--bert_model",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )

    parser.add_argument("--max_seq_len", default=512, type=int)

    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--epochs",
                        type=int,
                        default=3,
                        help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    # NOTE(review): --kr_batch_size is parsed but never read below; the
    # negated-data loader uses --train_batch_size. Confirm which is intended.
    parser.add_argument("--kr_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for training.")
    # A negated batch is processed after a regular batch whenever
    # random.random() > kr_freq, i.e. with probability 1 - kr_freq.
    parser.add_argument("--kr_freq", default=0.7, type=float)
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--learning_rate",
                        default=1e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    # Count the examples of every pregenerated epoch that exists on disk.
    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        print("GPU Device: ", device)
        n_gpu = 1
        # NOTE(review): no torch.distributed.init_process_group(...) call is
        # made here, yet torch.distributed.get_world_size() and
        # DistributedSampler are used below and the model is never wrapped for
        # distributed training. Confirm how the distributed path is meant to
        # be initialised before relying on --local_rank.
    logging.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass (micro) batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    pt_output = Path(getenv('PT_OUTPUT_DIR', ''))
    args.output_dir = Path(os.path.join(pt_output, args.output_dir))

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(
            f"Output directory ({args.output_dir}) already exists and is not empty!"
        )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # The modulo takes into account the fact that we may loop over limited
    # epochs of data.
    total_train_examples = sum(
        samples_per_epoch[i % len(samples_per_epoch)]
        for i in range(args.epochs))

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = (
            num_train_optimization_steps //
            torch.distributed.get_world_size())

    # Prepare model
    config = BertConfig.from_pretrained(args.bert_model)
    model = FuckWrapper(config)
    model.to(device)

    # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # NOTE(review): t_total only counts the regular batches; the extra
    # negated-pass updates below also advance the scheduler, so the schedule
    # can be stepped more than t_total times.
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    model.train()

    # Snapshot of the model before any update, for later comparison.
    before_train_dir = os.path.join(args.output_dir, "before_training")
    before_train_path = Path(before_train_dir)
    print("Before training path: ", before_train_path)
    before_train_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(before_train_dir)
    tokenizer.save_pretrained(before_train_dir)

    neg_epoch_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.pregenerated_neg_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)
    if args.local_rank == -1:
        neg_train_sampler = RandomSampler(neg_epoch_dataset)
    else:
        neg_train_sampler = DistributedSampler(neg_epoch_dataset)

    neg_train_dataloader = DataLoader(neg_epoch_dataset,
                                      sampler=neg_train_sampler,
                                      batch_size=args.train_batch_size)

    def inf_train_gen():
        """Yield (index, batch) pairs from the negated loader forever."""
        while True:
            for kr_step, kr_batch in enumerate(neg_train_dataloader):
                yield kr_step, kr_batch

    kr_gen = inf_train_gen()

    def forward_backward(batch, negated):
        """Run one forward/backward pass with gradient clipping.

        Returns the (accumulation-scaled) loss as a float and the number of
        examples in the batch.
        """
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, lm_label_ids = batch

        outputs = model(input_ids=input_ids,
                        attention_mask=input_mask,
                        token_type_ids=segment_ids,
                        masked_lm_labels=lm_label_ids,
                        negated=negated)
        loss = outputs[0]
        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            # `amp` is only bound when --fp16 is set, which is the only case
            # in which this branch runs.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_grad_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
        return loss.item(), input_ids.size(0)

    # Neither n_gpu nor local_rank changes during training, so the save
    # condition is evaluated once.
    should_save = (n_gpu > 1 and args.local_rank == -1) or (n_gpu <= 1)

    for epoch in range(args.epochs):
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        # Checkpoint at the start of every epoch (state after the previous one).
        if should_save:
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            model.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                model.train()
                is_update_step = (
                    (step + 1) % args.gradient_accumulation_steps == 0)

                # Regular (non-negated) pass.
                loss_value, batch_size = forward_backward(batch,
                                                          negated=False)
                tr_loss += loss_value
                nb_tr_examples += batch_size

                if args.local_rank == 0 or args.local_rank == -1:
                    nb_tr_steps += 1
                    pbar.update(1)
                    mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                    pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                if is_update_step:
                    # optimizer.step() must precede scheduler.step(); the
                    # reverse order skips the first value of the LR schedule.
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    optimizer.zero_grad()
                    global_step += 1

                # Occasionally follow up with a pass on a negated batch.
                if random.random() > args.kr_freq:
                    _, kr_batch = next(kr_gen)
                    loss_value, batch_size = forward_backward(kr_batch,
                                                              negated=True)
                    tr_loss += loss_value
                    nb_tr_examples += batch_size

                    if args.local_rank == -1:
                        nb_tr_steps += 1
                        mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                        pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                    if is_update_step:
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        optimizer.zero_grad()
                        global_step += 1

    # Save a trained model
    if should_save:
        logging.info("** ** * Saving fine-tuned model ** ** * ")
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
                    tag_loss.item() / sum(lens),
                    class_loss.item() / len(lens)
                ])
                total_loss = opt.st_weight * tag_loss + (
                    1 - opt.st_weight) * class_loss
            else:
                losses.append([tag_loss.item() / sum(lens), 0])
                total_loss = tag_loss
            total_loss.backward()

            # Clips gradient norm of an iterable of parameters.
            if opt.optim.lower() != 'bertadam' and opt.max_norm > 0:
                torch.nn.utils.clip_grad_norm_(params, opt.max_norm)

            if opt.optim.lower() == 'adamw':
                scheduler.step()
            optimizer.step()

            if j % piece_sentences == 0:
                print('[learning] epoch %i >> %2.2f%%' %
                      (i, (j + opt.batchSize) * 100. / nsentences),
                      'completed in %.2f (sec) <<\r' %
                      (time.time() - start_time),
                      end='')
                sys.stdout.flush()
        print('')

        mean_loss = np.mean(losses, axis=0)
        logger.info(
            'Training:\tEpoch : %d\tTime : %.4fs\tLoss of tag : %.2f\tLoss of class : %.2f '
            % (i, time.time() - start_time, mean_loss[0], mean_loss[1]))
# Fine-tuning script: for every epoch, run one optimisation pass over
# `train_dataloader`, then measure accuracy on `test_dataloader`.
# Accuracies are accumulated per batch and averaged over the batch count.
for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        # valid_length is handed to the model unchanged (not moved to `device`).
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    model.eval()  # model evaluation part
    # No gradients are needed for evaluation; disabling autograd avoids
    # building and holding a graph for every test batch.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))

from google.colab import drive
# Example #4
# 0
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_base_6layer_6conect.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--train_iter_multiplier",
        default=1.0,
        type=float,
        help="multiplier for the multi-task training.",
    )
    parser.add_argument(
        "--train_iter_gap",
        default=4,
        type=int,
        help="forward every n iteration is the validation score is not improving over the last 3 epoch, -1 means will stop",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--in_memory",
        default=False,
        type=bool,
        help="whether use chunck for parallel training.",
    )
    parser.add_argument(
        "--optim", default="AdamW", type=str, help="what to use for the optimization."
    )
    parser.add_argument(
        "--tasks", default="", type=str, help="1-2-3... training task separate by -"
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--vision_scratch",
        action="store_true",
        help="whether pre-trained the image or not.",
    )
    parser.add_argument(
        "--evaluation_interval", default=1, type=int, help="evaluate very n epoch."
    )
    parser.add_argument(
        "--lr_scheduler",
        default="mannul",
        type=str,
        help="whether use learning rate scheduler.",
    )
    parser.add_argument(
        "--baseline", action="store_true", help="whether use single stream baseline."
    )
    parser.add_argument(
        "--resume_file", default="", type=str, help="Resume from checkpoint"
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--clean_train_sets",
        default=True,
        type=bool,
        help="whether clean train sets for multitask data.",
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )
    parser.add_argument(
        "--task_specific_tokens",
        action="store_true",
        help="whether to use task specific tokens for the multi-task learning.",
    )

    args = parser.parse_args()
    with open("vilbert_tasks.yml", "r") as f:
        task_cfg = edict(yaml.safe_load(f))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.baseline:
        from pytorch_transformers.modeling_bert import BertConfig
        from vilbert.basebert import BaseBertForVLTasks
    else:
        from vilbert.vilbert import BertConfig
        from vilbert.vilbert import VILBertForVLTasks

    task_names = []
    task_lr = []
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        name = task_cfg[task]["name"]
        task_names.append(name)
        task_lr.append(task_cfg[task]["lr"])

    base_lr = min(task_lr)
    loss_scale = {}
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        loss_scale[task] = task_lr[i] / base_lr

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timeStamp = (
        "-".join(task_names)
        + "_"
        + args.config_file.split("/")[1].split(".")[0]
        + prefix
    )
    savePath = os.path.join(args.output_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.bert_model + "_weight_name.json", "r")
    )

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    task_batch_size, task_num_iters, task_ids, task_datasets_train, task_datasets_val, task_dataloader_train, task_dataloader_val = LoadDatasets(
        args, task_cfg, args.tasks.split("-")
    )

    logdir = os.path.join(savePath, "logs")
    tbLogger = utils.tbLogger(
        logdir,
        savePath,
        task_names,
        task_ids,
        task_num_iters,
        args.gradient_accumulation_steps,
    )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if args.task_specific_tokens:
        config.task_specific_tokens = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_ave_iter = {}
    task_stop_controller = {}
    for task_id, num_iter in task_num_iters.items():
        task_ave_iter[task_id] = int(
            task_cfg[task]["num_epoch"]
            * num_iter
            * args.train_iter_multiplier
            / args.num_train_epochs
        )
        task_stop_controller[task_id] = utils.MultiTaskStopOnPlateau(
            mode="max",
            patience=1,
            continue_threshold=0.005,
            cooldown=1,
            threshold=0.001,
        )

    task_ave_iter_list = sorted(task_ave_iter.values())
    median_num_iter = task_ave_iter_list[-1]
    num_train_optimization_steps = (
        median_num_iter * args.num_train_epochs // args.gradient_accumulation_steps
    )
    num_labels = max([dataset.num_labels for dataset in task_datasets_train.values()])

    if args.dynamic_attention:
        config.dynamic_attention = True
    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.baseline:
        model = BaseBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )
    else:
        model = VILBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )

    task_losses = LoadLosses(args, task_cfg, args.tasks.split("-"))

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    optimizer_grouped_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = base_lr
                    else:
                        lr = 1e-4
                else:
                    lr = base_lr
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.0}
                ]
            if not any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.01}
                ]

    if default_gpu:
        print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))

    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters, lr=base_lr, correct_bias=False)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters, lr=base_lr)

    warmpu_steps = args.warmup_proportion * num_train_optimization_steps

    if args.lr_scheduler == "warmup_linear":
        warmup_scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=warmpu_steps, t_total=num_train_optimization_steps
        )
    else:
        warmup_scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmpu_steps)

    lr_reduce_list = np.array([5, 7])
    if args.lr_scheduler == "automatic":
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.2, patience=1, cooldown=1, threshold=0.001
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingLR(
            optimizer, T_max=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "cosine_warm":
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "mannul":

        def lr_lambda_fun(epoch):
            return pow(0.2, np.sum(lr_reduce_list <= epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    startIterID = 0
    global_step = 0
    start_epoch = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace("module.", "", 1)] = checkpoint[
                    "model_state_dict"
                ][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        warmup_scheduler.load_state_dict(checkpoint["warmup_scheduler_state_dict"])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        start_epoch = int(checkpoint["epoch_id"]) + 1
        task_stop_controller = checkpoint["task_stop_controller"]
        tbLogger = checkpoint["tb_logger"]
        del checkpoint

    model.to(device)

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", task_batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    task_iter_train = {name: None for name in task_ids}
    task_count = {name: 0 for name in task_ids}
    for epochId in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
        model.train()
        for step in range(median_num_iter):
            iterId = startIterID + step + (epochId * median_num_iter)
            first_task = True
            for task_id in task_ids:
                is_forward = False
                if (not task_stop_controller[task_id].in_stop) or (
                    iterId % args.train_iter_gap == 0
                ):
                    is_forward = True

                if is_forward:
                    loss, score = ForwardModelsTrain(
                        args,
                        task_cfg,
                        device,
                        task_id,
                        task_count,
                        task_iter_train,
                        task_dataloader_train,
                        model,
                        task_losses,
                    )

                    loss = loss * loss_scale[task_id]
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            lr_this_step = args.learning_rate * warmup_linear(
                                global_step / num_train_optimization_steps,
                                args.warmup_proportion,
                            )
                            for param_group in optimizer.param_groups:
                                param_group["lr"] = lr_this_step

                        if first_task and (
                            global_step < warmpu_steps
                            or args.lr_scheduler == "warmup_linear"
                        ):
                            warmup_scheduler.step()

                        optimizer.step()
                        model.zero_grad()
                        if first_task:
                            global_step += 1
                            first_task = False

                        if default_gpu:
                            tbLogger.step_train(
                                epochId,
                                iterId,
                                float(loss),
                                float(score),
                                optimizer.param_groups[0]["lr"],
                                task_id,
                                "train",
                            )

            if "cosine" in args.lr_scheduler and global_step > warmpu_steps:
                lr_scheduler.step()

            if (
                step % (20 * args.gradient_accumulation_steps) == 0
                and step != 0
                and default_gpu
            ):
                tbLogger.showLossTrain()

            # decided whether to evaluate on each tasks.
            for task_id in task_ids:
                if (iterId != 0 and iterId % task_num_iters[task_id] == 0) or (
                    epochId == args.num_train_epochs - 1 and step == median_num_iter - 1
                ):
                    evaluate(
                        args,
                        task_dataloader_val,
                        task_stop_controller,
                        task_cfg,
                        device,
                        task_id,
                        model,
                        task_losses,
                        epochId,
                        default_gpu,
                        tbLogger,
                    )

        if args.lr_scheduler == "automatic":
            lr_scheduler.step(sum(val_scores.values()))
            logger.info("best average score is %3f" % lr_scheduler.best)
        elif args.lr_scheduler == "mannul":
            lr_scheduler.step()

        if epochId in lr_reduce_list:
            for task_id in task_ids:
                # reset the task_stop_controller once the lr drop
                task_stop_controller[task_id]._reset()

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin"
            )
            output_checkpoint = os.path.join(savePath, "pytorch_ckpt_latest.tar")
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "warmup_scheduler_state_dict": warmup_scheduler.state_dict(),
                    # 'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    "global_step": global_step,
                    "epoch_id": epochId,
                    "task_stop_controller": task_stop_controller,
                    "tb_logger": tbLogger,
                },
                output_checkpoint,
            )
    tbLogger.txt_close()
Beispiel #5
0
class TRADE(nn.Module):
    def __init__(self,
                 hidden_size,
                 lang,
                 path,
                 task,
                 lr,
                 dropout,
                 slots,
                 gating_dict,
                 t_total,
                 device,
                 nb_train_vocab=0):
        """Build the encoder/decoder pair, optionally restore weights, and
        create the optimizer and scheduler.

        Args:
            hidden_size: Hidden width handed to the encoder and the decoder.
            lang: Sequence; element 0 is stored as ``self.lang`` and element 1
                as ``self.mem_lang``.
            path: Directory containing ``enc.th`` and ``dec.th``; a falsy
                value skips weight loading.
            task: Task identifier, stored and later used in the save path.
            lr: Learning rate for the Adam optimizer of the 'RNN' encoder
                setup (the other setups read ``args['learn']`` instead).
            dropout: Dropout value passed to the encoder and decoder.
            slots: Sequence; element 0 is stored as ``self.slots`` and
                element 2 as ``self.slot_temp``.
            gating_dict: Mapping from gate name to gate id.
            t_total: Total number of schedule steps for the warmup schedule;
                divided by the world size when ``args['local_rank'] != -1``.
            device: Device handed to the sub-modules and used as
                ``map_location`` when loading saved modules.
            nb_train_vocab: Not used by this constructor.
        """
        super(TRADE, self).__init__()
        self.name = "TRADE"
        self.task = task
        self.hidden_size = hidden_size
        self.lang, self.mem_lang = lang[0], lang[1]
        self.lr = lr
        self.dropout = dropout
        self.slots, self.slot_temp = slots[0], slots[2]
        self.gating_dict = gating_dict
        self.device = device
        self.nb_gate = len(gating_dict)
        self.cross_entorpy = nn.CrossEntropyLoss()
        self.cell_type = args['cell_type']

        # Pick the encoder; the recurrent encoders share their embedding
        # with the decoder, the remaining (BERT) encoder does not.
        encoder_kind = args['encoder']
        if encoder_kind == 'RNN':
            self.encoder = EncoderRNN(self.lang.n_words, hidden_size,
                                      self.dropout, self.device,
                                      self.cell_type)
            shared_embedding = self.encoder.embedding
        elif encoder_kind == 'TPRNN':
            self.encoder = EncoderTPRNN(self.lang.n_words, hidden_size,
                                        self.dropout, self.device,
                                        self.cell_type, args['nSymbols'],
                                        args['nRoles'], args['dSymbols'],
                                        args['dRoles'], args['temperature'],
                                        args['scale_val'], args['train_scale'])
            shared_embedding = self.encoder.embedding
        else:
            self.encoder = BERTEncoder(hidden_size, self.dropout, self.device)
            shared_embedding = None

        self.decoder = Generator(self.lang, shared_embedding,
                                 self.lang.n_words, hidden_size, self.dropout,
                                 self.slots, self.nb_gate, self.device,
                                 self.cell_type)

        if path:
            print("MODEL {} LOADED".format(str(path)))
            saved_encoder = torch.load(str(path) + '/enc.th',
                                       map_location=self.device)
            saved_decoder = torch.load(str(path) + '/dec.th',
                                       map_location=self.device)

            def _gru_to_rnn(state):
                # Older checkpoints name the recurrent layer 'gru.*';
                # the current modules expect 'rnn.*'.
                renamed = {}
                for name, value in state.items():
                    if name.startswith('gru.'):
                        name = 'rnn.' + name[len('gru.'):]
                    renamed[name] = value
                return renamed

            encoder_state = _gru_to_rnn(saved_encoder.state_dict())
            decoder_state = _gru_to_rnn(saved_decoder.state_dict())

            # Checkpoints saved without the slot-embedding projection get a
            # zero-initialised one so that load_state_dict finds the keys.
            if 'W_slot_embed.weight' not in decoder_state:
                decoder_state['W_slot_embed.weight'] = torch.zeros(
                    (hidden_size, 2 * hidden_size), requires_grad=False)
                decoder_state['W_slot_embed.bias'] = torch.zeros(
                    (hidden_size, ), requires_grad=False)

            self.encoder.load_state_dict(encoder_state)
            self.decoder.load_state_dict(decoder_state)

        # Initialize optimizers and criterion
        if encoder_kind == 'RNN':
            self.optimizer = optim.Adam(self.parameters(), lr=lr)
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                            mode='max',
                                                            factor=0.5,
                                                            patience=1,
                                                            min_lr=0.0001,
                                                            verbose=True)
        else:
            if args['local_rank'] != -1:
                t_total = t_total // torch.distributed.get_world_size()

            # Biases and LayerNorm parameters are excluded from weight decay.
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            decayed_params, undecayed_params = [], []
            for param_name, param in self.named_parameters():
                if any(marker in param_name for marker in no_decay):
                    undecayed_params.append(param)
                else:
                    decayed_params.append(param)

            optimizer_grouped_parameters = [
                {'params': decayed_params, 'weight_decay': 0.01},
                {'params': undecayed_params, 'weight_decay': 0.0},
            ]
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=args['learn'],
                                   correct_bias=False)
            self.scheduler = WarmupLinearSchedule(
                self.optimizer,
                warmup_steps=args['warmup_proportion'] * t_total,
                t_total=t_total)

        self.reset()

    def print_loss(self):
        print_loss_avg = self.loss / self.print_every
        print_loss_ptr = self.loss_ptr / self.print_every
        print_loss_gate = self.loss_gate / self.print_every
        print_loss_class = self.loss_class / self.print_every
        # print_loss_domain = self.loss_domain / self.print_every
        self.print_every += 1
        return 'L:{:.2f},LP:{:.2f},LG:{:.2f}'.format(print_loss_avg,
                                                     print_loss_ptr,
                                                     print_loss_gate)

    def save_model(self, dec_type):
        """Save the encoder and decoder modules under a run-specific directory.

        The directory name is assembled from ``args['addName']``,
        ``args['dataset']``, the task, the hidden size, ``args['batch']``,
        the dropout value and ``dec_type``. The complete module objects (not
        only their state dicts) are written to ``enc.th`` and ``dec.th``.

        Args:
            dec_type: Suffix appended to the directory name (callers in this
                class pass a metric tag such as ``'ACC-0.1234'``).
        """
        directory = 'save/TRADE-' + args["addName"] + args['dataset'] + str(
            self.task) + '/' + 'HDD' + str(self.hidden_size) + 'BSZ' + str(
                args['batch']) + 'DR' + str(self.dropout) + str(dec_type)
        # exist_ok avoids the check-then-create race when several processes
        # save into the same directory at the same time.
        os.makedirs(directory, exist_ok=True)
        torch.save(self.encoder, directory + '/enc.th')
        torch.save(self.decoder, directory + '/dec.th')

    def reset(self):
        self.loss, self.print_every, self.loss_ptr, self.loss_gate, self.loss_class = 0, 1, 0, 0, 0

    def forward(self, data, clip, slot_temp, reset=0, n_gpu=0):
        """Run one training forward pass and return the loss to back-propagate.

        Args:
            data: Batch dict; reads 'generate_y', 'y_lengths' and
                'gating_label' here, plus whatever ``encode_and_decode``
                reads.
            clip: Not used by this method.
            slot_temp: Slot list forwarded to ``encode_and_decode``.
            reset: When truthy, the running loss statistics are reset first.
            n_gpu: Not used by this method.

        Returns:
            The loss tensor (pointer loss, plus gate loss when
            ``args["use_gate"]`` is set). It is also stored as
            ``self.loss_grad``.
        """
        if reset:
            self.reset()

        # Clear gradients left over from the previous step.
        self.optimizer.zero_grad()

        # Decide once per batch whether the decoder is fed the gold tokens.
        teacher_forcing = random.random() < args["teacher_forcing_ratio"]
        point_outputs, gate_outputs, _, _ = self.encode_and_decode(
            data, teacher_forcing, slot_temp)

        pointer_loss = masked_cross_entropy_for_value(
            point_outputs.transpose(0, 1).contiguous(),
            data["generate_y"].contiguous(),
            data["y_lengths"])

        flat_gate_logits = gate_outputs.transpose(0, 1).contiguous().view(
            -1, gate_outputs.size(-1))
        flat_gate_labels = data["gating_label"].contiguous().view(-1)
        gate_loss = self.cross_entorpy(flat_gate_logits, flat_gate_labels)

        total_loss = (pointer_loss + gate_loss) if args["use_gate"] else pointer_loss

        self.loss_grad = total_loss
        self.loss_ptr_to_bp = pointer_loss

        # Running statistics for print_loss; the gate loss is tracked even
        # when it is not part of the optimised loss.
        self.loss += total_loss.item()
        self.loss_ptr += pointer_loss.item()
        self.loss_gate += gate_loss.item()

        return self.loss_grad

    def optimize_GEM(self, clip):
        """Clip gradients, apply one optimizer step and advance the warmup schedule.

        Args:
            clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        """
        torch.nn.utils.clip_grad_norm_(self.parameters(), clip)
        self.optimizer.step()
        # Only the step-based warmup schedule is advanced here; any other
        # scheduler type is left untouched by this method.
        if not isinstance(self.scheduler, WarmupLinearSchedule):
            return
        self.scheduler.step()

    def encode_and_decode(self, data, use_teacher_forcing, slot_temp):
        """Encode the dialogue context and decode slot values and gates.

        Args:
            data: Batch dict. Always reads 'context', 'context_len',
                'context_plain' and 'generate_y'; the non-recurrent encoder
                path additionally reads 'all_input_ids', 'all_input_mask',
                'all_segment_ids' and 'all_sub_word_masks'.
            use_teacher_forcing: Forwarded to the decoder.
            slot_temp: Slot list forwarded to the decoder.

        Returns:
            Tuple ``(all_point_outputs, all_gate_outputs, words_point_out,
            words_class_out)`` as produced by the decoder.
        """
        if args['encoder'] in ('RNN', 'TPRNN'):
            # Build unknown mask for memory to encourage generalization
            if args['unk_mask'] and self.decoder.training:
                story_size = data['context'].size()
                rand_mask = np.ones(story_size)
                # One Bernoulli draw per context position; a 0 zeroes the
                # token id at that position.
                bi_mask = np.random.binomial(
                    [np.ones(
                        (story_size[0], story_size[1]))], 1 - self.dropout)[0]
                rand_mask = rand_mask * bi_mask
                rand_mask = torch.Tensor(rand_mask).to(self.device)
                story = data['context'] * rand_mask.long()
            else:
                story = data['context']

            story = story.to(self.device)
            encoded_outputs, encoded_hidden = self.encoder(
                story, data['context_len'])
        else:
            # __init__ builds a BERTEncoder for every encoder setting other
            # than 'RNN'/'TPRNN', so this branch must cover all of them too;
            # matching only the literal 'BERT' left encoded_outputs unbound
            # for any other value.
            story = data['context']
            encoded_outputs, encoded_hidden = self.encoder(
                data['all_input_ids'], data['all_input_mask'],
                data['all_segment_ids'], data['all_sub_word_masks'])
            encoded_hidden = encoded_hidden.unsqueeze(0)

        # Get the words that can be copied from the memory
        batch_size = len(data['context_len'])
        self.copy_list = data['context_plain']
        # During training decode as many steps as the gold values have;
        # otherwise decode a fixed 10 steps.
        max_res_len = data['generate_y'].size(
            2) if self.encoder.training else 10

        all_point_outputs, all_gate_outputs, words_point_out, words_class_out = self.decoder.forward(
            batch_size, encoded_hidden, encoded_outputs, data['context_len'],
            story, max_res_len, data['generate_y'], use_teacher_forcing,
            slot_temp)

        return all_point_outputs, all_gate_outputs, words_point_out, words_class_out

    def evaluate(self,
                 dev,
                 matric_best,
                 slot_temp,
                 device,
                 save_dir="",
                 save_string="",
                 early_stop=None):
        """Evaluate on ``dev``, optionally dump predictions, and save on improvement.

        Args:
            dev: Iterable of batch dicts. Each batch must provide 'ID',
                'turn_id', 'turn_belief' and 'context_len' in addition to
                what ``encode_and_decode`` reads.
            matric_best: Best score so far; the model is saved when the
                current score is greater than or equal to it.
            slot_temp: List of slot names, indexed in decoder slot order.
            device: Device that tensor and numeric-list batch entries are
                moved to.
            save_dir: Directory for the prediction dump (only used when
                ``args["genSample"]`` is set); created if missing.
            save_string: Suffix used in the prediction file name.
            early_stop: ``'F1'`` to select on joint F1; anything else selects
                on joint accuracy.

        Returns:
            The joint F1 score when ``early_stop == 'F1'``, otherwise the
            joint accuracy.
        """

        def _decode_value(slot_words, batch_index):
            # Join the generated tokens of one slot for one example,
            # stopping at the first 'EOS'.
            tokens = []
            for token in np.transpose(slot_words)[batch_index]:
                if token == 'EOS':
                    break
                tokens.append(token)
            return " ".join(tokens)

        # Set to not-training mode to disable dropout
        self.encoder.train(False)
        self.decoder.train(False)
        print("STARTING EVALUATION")
        all_prediction = {}
        inverse_unpoint_slot = {v: k for k, v in self.gating_dict.items()}
        # List-valued entries that must stay plain Python lists.
        plain_keys = ('ID', 'turn_belief', 'context_plain', 'turn_uttr_plain')

        for data_dev in dev:
            # wrap all numerical values as tensors for multi-gpu training
            eval_data = {}
            for k, v in data_dev.items():
                if isinstance(v, torch.Tensor):
                    eval_data[k] = v.to(device)
                elif isinstance(v, list):
                    if k in plain_keys:
                        eval_data[k] = v
                    else:
                        eval_data[k] = torch.tensor(v).to(device)
                # Entries of any other type are not forwarded to the model.

            batch_size = len(data_dev['context_len'])
            with torch.no_grad():
                _, gates, words, _ = self.encode_and_decode(
                    eval_data, False, slot_temp)

            for bi in range(batch_size):
                dialogue_id = data_dev["ID"][bi]
                turn_record = {"turn_belief": data_dev["turn_belief"][bi]}
                all_prediction.setdefault(
                    dialogue_id, {})[data_dev["turn_id"][bi]] = turn_record

                predict_belief_bsz_ptr = []
                gate = torch.argmax(gates.transpose(0, 1)[bi], dim=1)

                # pointer-generator results
                if args["use_gate"]:
                    for si, sg in enumerate(gate):
                        if sg == self.gating_dict["none"]:
                            continue
                        elif sg == self.gating_dict["ptr"]:
                            st = _decode_value(words[si], bi)
                            if st == "none":
                                continue
                            predict_belief_bsz_ptr.append(slot_temp[si] +
                                                          "-" + str(st))
                        else:
                            # Any other gate predicts its own name as value.
                            predict_belief_bsz_ptr.append(
                                slot_temp[si] + "-" +
                                inverse_unpoint_slot[sg.item()])
                else:
                    for si, _ in enumerate(gate):
                        st = _decode_value(words[si], bi)
                        if st == "none":
                            continue
                        predict_belief_bsz_ptr.append(slot_temp[si] + "-" +
                                                      str(st))

                turn_record["pred_bs_ptr"] = predict_belief_bsz_ptr

        if args["genSample"]:
            # Compare by value: `is not ""` tests object identity, which is
            # implementation-dependent for strings.
            if save_dir != "" and not os.path.exists(save_dir):
                os.makedirs(save_dir, exist_ok=True)
            prediction_path = os.path.join(
                save_dir,
                "prediction_{}_{}.json".format(self.name, save_string))
            # Context manager so the file is flushed and closed.
            with open(prediction_path, 'w') as prediction_file:
                json.dump(all_prediction, prediction_file, indent=4)
            print("saved generated samples", prediction_path)

        joint_acc_score_ptr, F1_score_ptr, turn_acc_score_ptr = self.evaluate_metrics(
            all_prediction, "pred_bs_ptr", slot_temp)

        evaluation_metrics = {
            "Joint Acc": joint_acc_score_ptr,
            "Turn Acc": turn_acc_score_ptr,
            "Joint F1": F1_score_ptr
        }
        print(evaluation_metrics)

        # Set back to training mode
        self.encoder.train(True)
        self.decoder.train(True)

        joint_acc_score = joint_acc_score_ptr
        F1_score = F1_score_ptr

        if early_stop == 'F1':
            if F1_score >= matric_best:
                self.save_model('ENTF1-{:.4f}'.format(F1_score))
                print("MODEL SAVED")
            return F1_score

        if joint_acc_score >= matric_best:
            self.save_model('ACC-{:.4f}'.format(joint_acc_score))
            print("MODEL SAVED")
        return joint_acc_score

    def evaluate_metrics(self, all_prediction, from_which, slot_temp):
        total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
        for d, v in all_prediction.items():
            for t in range(len(v)):
                cv = v[t]
                if set(cv["turn_belief"]) == set(cv[from_which]):
                    joint_acc += 1
                total += 1

                # Compute prediction slot accuracy
                temp_acc = self.compute_acc(set(cv["turn_belief"]),
                                            set(cv[from_which]), slot_temp)
                turn_acc += temp_acc

                # Compute prediction joint F1 score
                temp_f1, temp_r, temp_p, count = self.compute_prf(
                    set(cv["turn_belief"]), set(cv[from_which]))
                F1_pred += temp_f1
                F1_count += count

        joint_acc_score = joint_acc / float(total) if total != 0 else 0
        turn_acc_score = turn_acc / float(total) if total != 0 else 0
        F1_score = F1_pred / float(F1_count) if F1_count != 0 else 0
        return joint_acc_score, F1_score, turn_acc_score

    def compute_acc(self, gold, pred, slot_temp):
        miss_gold = 0
        miss_slot = []
        for g in gold:
            if g not in pred:
                miss_gold += 1
                miss_slot.append(g.rsplit("-", 1)[0])
        wrong_pred = 0
        for p in pred:
            if p not in gold and p.rsplit("-", 1)[0] not in miss_slot:
                wrong_pred += 1
        ACC_TOTAL = len(slot_temp)
        ACC = len(slot_temp) - miss_gold - wrong_pred
        ACC = ACC / float(ACC_TOTAL)
        return ACC

    def compute_prf(self, gold, pred):
        TP, FP, FN = 0, 0, 0
        if len(gold) != 0:
            count = 1
            for g in gold:
                if g in pred:
                    TP += 1
                else:
                    FN += 1
            for p in pred:
                if p not in gold:
                    FP += 1
            precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
            recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
            F1 = 2 * precision * recall / float(precision + recall) if (
                precision + recall) != 0 else 0
        else:
            if len(pred) == 0:
                precision, recall, F1, count = 1, 1, 1, 1
            else:
                precision, recall, F1, count = 0, 0, 0, 1
        return F1, recall, precision, count
Beispiel #6
0
    def fine_tune(
        self,
        train_dataloader,
        get_inputs,
        device,
        max_steps=-1,
        num_train_epochs=1,
        max_grad_norm=1.0,
        gradient_accumulation_steps=1,
        n_gpu=1,
        optimizer=None,
        scheduler=None,
        weight_decay=0.0,
        learning_rate=5e-5,
        adam_epsilon=1e-8,
        warmup_steps=0,
        fp16=False,
        fp16_opt_level="O1",
        local_rank=-1,
        verbose=True,
        seed=None,
    ):
        """Fine-tune ``self.model`` on the batches of ``train_dataloader``.

        Args:
            train_dataloader: Sized iterable of batches; every batch is an
                iterable of tensors.
            get_inputs: Callable ``(batch, model_name) -> dict`` producing
                the keyword arguments for the model call.
            device: Device every batch tensor is moved to.
            max_steps: If positive, overrides ``num_train_epochs`` and is
                used as the scheduler's total step count.
            num_train_epochs: Number of passes over the data when
                ``max_steps`` is not positive.
            max_grad_norm: Gradient-norm clipping threshold.
            gradient_accumulation_steps: Number of batches accumulated per
                optimizer step.
            n_gpu: Number of GPUs; values above 1 wrap the model in
                ``DataParallel``.
            optimizer: Optimizer to use; an ``AdamW`` is built when None.
            scheduler: LR scheduler to use; a ``WarmupLinearSchedule`` is
                built when None.
            weight_decay: Weight decay for parameters other than biases and
                ``LayerNorm.weight`` (default optimizer only).
            learning_rate: Learning rate of the default optimizer.
            adam_epsilon: Epsilon of the default optimizer.
            warmup_steps: Warmup steps of the default scheduler.
            fp16: Enable apex mixed-precision training.
            fp16_opt_level: apex AMP optimisation level.
            local_rank: Distributed rank; -1 disables distributed training.
            verbose: Show progress bars and periodic loss lines.
            seed: Optional seed passed to ``Transformer.set_seed``.

        Returns:
            Tuple ``(global_step, average_loss)`` where ``average_loss`` is
            the accumulated loss divided by the number of optimizer steps
            (by 1 if no optimizer step was performed).

        Raises:
            ImportError: If ``fp16`` is requested and apex is not installed.
        """
        if seed is not None:
            Transformer.set_seed(seed, n_gpu > 0)

        if max_steps > 0:
            t_total = max_steps
            num_train_epochs = (
                max_steps // (len(train_dataloader) // gradient_accumulation_steps) + 1
            )
        else:
            t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs

        if optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)

        if scheduler is None:
            scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)

        if fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex")
            self.model, optimizer = amp.initialize(self.model, optimizer, opt_level=fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        # Distributed training (should be after apex fp16 initialization)
        if local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True,
            )

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        # Progress bars are shown on the main process only.
        hide_progress = local_rank not in [-1, 0] or not verbose
        train_iterator = trange(int(num_train_epochs), desc="Epoch", disable=hide_progress)

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=hide_progress)
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(device) for t in batch)
                inputs = get_inputs(batch, self.model_name)
                outputs = self.model(**inputs)
                loss = outputs[0]

                if n_gpu > 1:
                    # DataParallel returns one loss per replica.
                    loss = loss.mean()
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                if step % 10 == 0 and verbose:
                    tqdm.write("Loss:{:.6f}".format(loss))

                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                tr_loss += loss.item()
                if (step + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    global_step += 1

                # NOTE(review): `>` lets one optimizer step beyond max_steps
                # run before stopping; kept to preserve existing behaviour.
                if max_steps > 0 and global_step > max_steps:
                    epoch_iterator.close()
                    break
            if max_steps > 0 and global_step > max_steps:
                train_iterator.close()
                break

            # Drop the reference to the last batch before releasing cached
            # GPU memory. Rebinding (rather than `del`) also works when the
            # dataloader yielded no batch at all.
            batch = None
            torch.cuda.empty_cache()

        # Guard against division by zero when no optimizer step was taken
        # (empty dataloader, or fewer batches than accumulation steps).
        return global_step, tr_loss / max(global_step, 1)
Beispiel #7
0
def train(args, tokenizer, device):
    """Train a BertMouth model, validate after every epoch, and save it.

    Training stops early on KeyboardInterrupt; the model is saved in that
    case as well.

    Args:
        args: Namespace providing ``train_file``, ``valid_file``,
            ``max_seq_length``, ``train_batch_size``, ``bert_model``,
            ``num_train_epochs``, ``learning_rate`` and ``adam_epsilon``.
        tokenizer: Tokenizer passed to the data loaders, ``generate`` and
            ``save``; its ``vocab_size`` sets the number of output labels.
        device: Device the model and every batch are moved to.
    """
    # Local import: torch.no_grad() is needed for the validation pass.
    import torch

    logger.info("loading data")
    train_dataloader = make_dataloader(args.train_file, args.max_seq_length,
                                       args.train_batch_size, tokenizer)
    valid_dataloader = make_dataloader(args.valid_file, args.max_seq_length,
                                       args.train_batch_size, tokenizer)

    logger.info("building model")
    model = BertMouth.from_pretrained(args.bert_model,
                                      num_labels=tokenizer.vocab_size)
    model.to(device)

    # The pooler parameters are left out of optimisation.
    param_optimizer = [
        (n, p) for n, p in model.named_parameters() if 'pooler' not in n
    ]

    logger.info("setting optimizer")
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimization_steps = len(train_dataloader) * args.num_train_epochs
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=0,
                                     t_total=optimization_steps)
    # Label id 0 is excluded from the loss.
    loss_fct = CrossEntropyLoss(ignore_index=0)

    def calc_batch_loss(batch):
        # Move the batch to the device and return the token-level loss.
        batch = tuple(t.to(device) for t in batch)
        input_ids, y, input_mask, input_type_id, _ = batch

        logits = model(input_ids, input_type_id, input_mask)
        logits = logits.view(-1, tokenizer.vocab_size)
        y = y.view(-1)
        return loss_fct(logits, y)

    logger.info("train starts")
    model.train()
    summary_writer = SummaryWriter(log_dir="logs")
    generated_texts = []
    try:
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            train_loss = 0.
            running_num = 0
            for step, batch in enumerate(train_dataloader):
                loss = calc_batch_loss(batch)
                loss.backward()

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                train_loss += loss.item()
                running_num += len(batch[0])
            # max(..., 1) keeps an empty data loader from raising
            # ZeroDivisionError.
            mean_train_loss = train_loss / max(running_num, 1)
            logger.info("[{0} epochs] "
                        "train loss: {1:.3g} ".format(epoch + 1,
                                                      mean_train_loss))
            summary_writer.add_scalar("train_loss", mean_train_loss, epoch)

            model.eval()
            valid_loss = 0.
            valid_num = 0
            # No gradients are needed for validation; skipping the autograd
            # graph saves memory and time.
            with torch.no_grad():
                for batch in valid_dataloader:
                    valid_loss += calc_batch_loss(batch).item()
                    valid_num += len(batch[0])
            mean_valid_loss = valid_loss / max(valid_num, 1)

            generated_texts.append(generate(tokenizer=tokenizer,
                                            device=device,
                                            length=25,
                                            max_length=args.max_seq_length,
                                            model=model))
            logger.info("[{0} epochs] valid loss: {1:.3g}".format(epoch + 1,
                                                                  mean_valid_loss))
            summary_writer.add_scalar("val_loss", mean_valid_loss, epoch)

            model.train()
    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt")
    finally:
        # Close the event file even when training fails with another error.
        summary_writer.close()

    dt_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    save(args, model, tokenizer, str(dt_now))
Beispiel #8
0
def main():
    """Fine-tune a BERT/RoBERTa masked LM with negated-sentence objectives.

    Parses the command line, counts the pregenerated epochs available under
    ``--pregenerated_data``, sets up the device (optionally distributed),
    seeds the RNGs and builds the tokenizer, model, AdamW optimizer and
    warmup-linear schedule. Each epoch then interleaves MLM steps on the
    pregenerated data with steps on the negated (``--pregenerated_neg_data``)
    and positive (``--pregenerated_pos_data``) datasets. Training losses are
    logged to wandb from the main process, which also saves the model, the
    tokenizer and the wandb run id into ``--output_dir``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` is smaller than 1.
        NotImplementedError: if ``--method`` is not one of the methods
            implemented for BERT models ('neg_samebatch', 'unlikelihood').
        ImportError: if ``--fp16`` is requested but apex is not installed.
    """
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_neg_data', type=Path, required=True)
    parser.add_argument('--pregenerated_pos_data', type=Path, required=True)
    parser.add_argument('--validation_neg_data', type=Path, required=True)
    parser.add_argument('--validation_pos_data', type=Path, required=True)
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=Path, required=True)
    parser.add_argument('--exp_group', type=str, required=True)
    parser.add_argument(
        "--bert_model",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--method",
                        type=str,
                        choices=[
                            'neg_samebatch', 'distill_samebatch',
                            'distill_samebatch_lstm', 'distill', 'kl',
                            'unlikelihood'
                        ])
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--save_before", action='store_true')
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage"
    )

    parser.add_argument("--max_seq_len", default=512, type=int)

    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument("--epochs",
                        type=int,
                        default=3,
                        help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--port_idx", type=int)

    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--valid_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for validation.")
    parser.add_argument("--kr_freq", default=0.0, type=float)
    parser.add_argument("--mlm_freq", default=0, type=float)
    parser.add_argument("--kl_w", default=1000, type=float)
    parser.add_argument("--ul_w", default=1, type=float)
    parser.add_argument("--gamma",
                        default=0.5,
                        type=float,
                        help="coeff of UL and 1-coeff of LL")
    parser.add_argument('--no_mlm',
                        action='store_true',
                        help="don't do any MLM training")
    # NOTE(review): this help text looks copy-pasted from --no_cuda and the
    # flag is never read in this function; confirm its meaning before rewording.
    parser.add_argument("--no_tie",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--no_ul',
                        action='store_true',
                        help="don't do any UL training")
    parser.add_argument('--no_ll',
                        action='store_true',
                        help="don't do any LL training")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--learning_rate",
                        default=1e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    # Logging, wandb reporting and checkpointing happen only on the single
    # non-distributed process (-1) or on rank 0 of a distributed run.
    is_main_process = args.local_rank in (-1, 0)

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    # Count how many epochs of pregenerated data exist; stop at the first gap.
    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        print(torch.cuda.is_available())
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
        print("Num of gpus: ", n_gpu)
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        print("GPU Device: ", device)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        dist_comms.init_distributed_training(args.local_rank, args.port_idx)
    logging.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass (micro) batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # if n_gpu > 0:
    torch.cuda.manual_seed_all(args.seed)

    pt_output = Path(getenv('PT_OUTPUT_DIR', ''))
    args.output_dir = Path(os.path.join(pt_output, args.output_dir))

    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(
            f"Output directory ({args.output_dir}) already exists and is not empty!"
        )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    if args.bert_model != "roberta-base":
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
    else:
        tokenizer = RobertaTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
        tokenizer.vocab = tokenizer.encoder

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    # Prepare model
    if args.bert_model != "roberta-base":
        if args.method == "neg_samebatch":
            config = BertConfig.from_pretrained(args.bert_model)
            config.bert_model = args.bert_model
            core_model = BertForNegSameBatch.from_pretrained(args.bert_model,
                                                             args.gamma,
                                                             config=config)
            core_model.init_orig_bert()
        elif args.method == "unlikelihood":
            config = BertConfig.from_pretrained(args.bert_model)
            core_model = BertForNegPreTraining.from_pretrained(args.bert_model,
                                                               config=config)
        else:
            raise NotImplementedError(
                f"method {args.method} is not implemented")
    else:
        config = RobertaConfig.from_pretrained(args.bert_model)
        core_model = RobertaForNegPreTraining.from_pretrained(args.bert_model)

    core_model = core_model.to(device)

    # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(core_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        core_model, optimizer = amp.initialize(core_model,
                                               optimizer,
                                               opt_level=args.fp16_opt_level)

    # Wrap in DistributedDataParallel only for distributed runs: with the
    # default local_rank of -1 there is no process group and device_ids=[-1]
    # is invalid. `core_model` stays the unwrapped module and is what gets
    # saved, which is the same object as `model.module` under DDP.
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            core_model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        model = core_model

    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    model.train()

    if is_main_process:
        if args.save_before:
            before_train_path = Path(
                os.path.join(args.output_dir, "before_training"))
            print("Before training path: ", before_train_path)
            before_train_path.mkdir(parents=True, exist_ok=True)
            core_model.save_pretrained(
                os.path.join(args.output_dir, "before_training"))
            tokenizer.save_pretrained(
                os.path.join(args.output_dir, "before_training"))

        # writer = SummaryWriter(log_dir=args.output_dir)
        wandb.init(project="neg_v2",
                   name=str(args.output_dir).split("/")[-1],
                   group=args.exp_group,
                   entity='negation')
        mlm_averagemeter = AverageMeter()
        ul_averagemeter = AverageMeter()
        ll_averagemeter = AverageMeter()
        kl_averagemeter = AverageMeter()

    neg_epoch_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.pregenerated_neg_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)

    pos_epoch_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.pregenerated_pos_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)

    neg_validation_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.validation_neg_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)
    pos_validation_dataset = PregeneratedDataset(
        epoch=0,
        training_path=args.validation_pos_data,
        tokenizer=tokenizer,
        num_data_epochs=num_data_epochs,
        reduce_memory=args.reduce_memory)

    if args.local_rank == -1:
        neg_train_sampler = RandomSampler(neg_epoch_dataset)
        pos_train_sampler = RandomSampler(pos_epoch_dataset)

        neg_valid_sampler = RandomSampler(neg_validation_dataset)
        pos_valid_sampler = RandomSampler(pos_validation_dataset)
    else:
        neg_train_sampler = DistributedSampler(neg_epoch_dataset)
        pos_train_sampler = DistributedSampler(pos_epoch_dataset)

        neg_valid_sampler = DistributedSampler(neg_validation_dataset)
        pos_valid_sampler = DistributedSampler(pos_validation_dataset)

    neg_train_dataloader = DataLoader(neg_epoch_dataset,
                                      sampler=neg_train_sampler,
                                      batch_size=args.train_batch_size)
    pos_train_dataloader = DataLoader(pos_epoch_dataset,
                                      sampler=pos_train_sampler,
                                      batch_size=args.train_batch_size)

    neg_valid_dataloader = DataLoader(neg_validation_dataset,
                                      sampler=neg_valid_sampler,
                                      batch_size=args.valid_batch_size)
    pos_valid_dataloader = DataLoader(pos_validation_dataset,
                                      sampler=pos_valid_sampler,
                                      batch_size=args.valid_batch_size)

    def inf_train_gen():
        """Yield (step, batch) from the negated dataloader forever."""
        while True:
            for kr_step, kr_batch in enumerate(neg_train_dataloader):
                yield kr_step, kr_batch

    kr_gen = inf_train_gen()

    def pos_inf_train_gen():
        """Yield (step, batch) from the positive dataloader forever."""
        while True:
            for kr_step, kr_batch in enumerate(pos_train_dataloader):
                yield kr_step, kr_batch

    pos_kr_gen = pos_inf_train_gen()

    mlm_loss, neg_loss = 0, 0
    mlm_nb_it, neg_nb_it = 1, 1
    mlm_nb_ex, neg_nb_ex = 0, 0

    for epoch in range(args.epochs):
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)

        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0

        ul_tr_loss = 0
        nb_ul_tr_examples, nb_ul_tr_steps = 0, 1
        ll_tr_loss = 0
        nb_ll_tr_examples, nb_ll_tr_steps = 0, 1
        kl_tr_loss = 0
        nb_kl_tr_examples, nb_kl_tr_steps = 0, 1

        # Checkpoint the weights as they stand at the start of every epoch.
        if is_main_process:
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            core_model.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                # --- MLM step on the regular pregenerated data ---
                if not args.no_mlm and (random.random() > args.mlm_freq):
                    model.train()
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_mask, segment_ids, lm_label_ids = batch

                    outputs = model(input_ids=input_ids,
                                    attention_mask=input_mask,
                                    token_type_ids=segment_ids,
                                    masked_lm_labels=lm_label_ids,
                                    negated=False)

                    loss = outputs[1]
                    loss_dict = outputs[0]
                    mlm_loss += loss_dict['mlm'].item()

                    mlm_nb_it += 1
                    mlm_nb_ex += input_ids.size(0)

                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.

                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                    tr_loss += loss.item()

                    if is_main_process:
                        mlm_averagemeter.update(loss_dict['mlm'].item())
                        # writer.add_scalar('MLM/train', loss_dict['mlm'].item(), mlm_nb_it)
                        wandb.log({'MLM/train': loss_dict['mlm'].item()})

                        nb_tr_steps += 1
                        nb_ll_tr_steps += 1
                        mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps

                        pbar.set_postfix_str(
                            f"MLM: {mlm_averagemeter:.6f}, UL: {ul_averagemeter:.6f}, LL: {ll_averagemeter:.6f}, KL: {kl_averagemeter:.6f}"
                        )

                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        # The optimizer must step before the LR scheduler,
                        # otherwise the first schedule value is skipped.
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        optimizer.zero_grad()
                        global_step += 1
                pbar.update(1)
                # --- unlikelihood / likelihood step on neg / pos data ---
                random_num = random.random()
                if random_num > args.kr_freq:
                    if args.method in ["neg_samebatch"]:
                        # Negated and positive examples share one forward
                        # pass; batch_mask marks the negated rows with 1.
                        ul_step, ul_batch = next(kr_gen)
                        ul_batch = tuple(t.to(device) for t in ul_batch)
                        ul_input_ids, ul_input_mask, ul_segment_ids, ul_lm_label_ids = ul_batch

                        ll_step, ll_batch = next(pos_kr_gen)
                        ll_batch = tuple(t.to(device) for t in ll_batch)
                        ll_input_ids, ll_input_mask, ll_segment_ids, ll_lm_label_ids = ll_batch

                        batch_mask = torch.zeros(
                            (ul_input_ids.size(0) + ll_input_ids.size(0)),
                            dtype=ll_input_mask.dtype,
                            device=device)
                        batch_mask[:ul_input_ids.size(0)] = 1.

                        outputs = model(
                            input_ids=torch.cat([ul_input_ids, ll_input_ids],
                                                0),
                            attention_mask=torch.cat(
                                [ul_input_mask, ll_input_mask], 0),
                            token_type_ids=torch.cat(
                                [ul_segment_ids, ll_segment_ids], 0),
                            masked_lm_labels=torch.cat(
                                [ul_lm_label_ids, ll_lm_label_ids], 0),
                            negated=True,
                            batch_neg_mask=batch_mask)

                        loss = outputs[1] * args.ul_w
                        loss_dict = outputs[0]

                        if is_main_process:
                            wandb.log({
                                'UL/train': loss_dict['neg'].item(),
                                'LL/train': loss_dict['pos'].item()
                            })
                            ul_averagemeter.update(loss_dict['neg'].item())
                            ll_averagemeter.update(loss_dict['pos'].item())
                        neg_nb_it += 1

                    elif random.random() > 0.5 and not args.no_ul:
                        kr_step, kr_batch = next(kr_gen)
                        kr_batch = tuple(t.to(device) for t in kr_batch)
                        input_ids, input_mask, segment_ids, lm_label_ids = kr_batch

                        outputs = model(input_ids=input_ids,
                                        attention_mask=input_mask,
                                        token_type_ids=segment_ids,
                                        masked_lm_labels=lm_label_ids,
                                        negated=True)

                        loss = outputs[1] * args.ul_w

                        loss_dict = outputs[0]
                        nb_ul_tr_steps += 1

                        neg_loss += loss_dict['neg'].item()
                        if is_main_process:
                            wandb.log({
                                'UL/train':
                                loss_dict['neg'].item(),
                                'KL/train':
                                loss_dict['kl'].item() * args.kl_w
                            })
                            ul_averagemeter.update(loss_dict['neg'].item())
                            kl_averagemeter.update(loss_dict['kl'].item() *
                                                   args.kl_w)

                        neg_nb_it += 1
                    elif not args.no_ll:
                        kr_step, kr_batch = next(pos_kr_gen)
                        kr_batch = tuple(t.to(device) for t in kr_batch)
                        input_ids, input_mask, segment_ids, lm_label_ids = kr_batch

                        outputs = model(input_ids=input_ids,
                                        attention_mask=input_mask,
                                        token_type_ids=segment_ids,
                                        masked_lm_labels=lm_label_ids,
                                        negated=False)
                        loss = outputs[1]
                        loss_dict = outputs[0]
                        nb_ll_tr_steps += 1

                        mlm_loss += loss_dict['mlm'].item()

                        mlm_nb_it += 1
                        if is_main_process:
                            wandb.log({'LL/train': loss_dict['mlm'].item()})
                            ll_averagemeter.update(loss_dict['mlm'].item())

                        mlm_nb_ex += input_ids.size(0)
                    else:
                        # Neither a UL nor an LL step was taken for this batch.
                        continue

                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                    tr_loss += loss.item()
                    if is_main_process:
                        nb_tr_steps += 1
                        mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                        pbar.set_postfix_str(
                            f"MLM: {mlm_averagemeter:.6f}, UL: {ul_averagemeter:.6f}, LL: {ll_averagemeter:.6f}, KL: {kl_averagemeter:.6f}"
                        )
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        # Same ordering as above: optimizer first, then scheduler.
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        optimizer.zero_grad()
                        global_step += 1
                if is_main_process:
                    # Periodic validation is deliberately disabled (`if False`).
                    if False and (step + 1) % 100 == 0:
                        neg_valid_res = validate(
                            model=model,
                            dataloader=neg_valid_dataloader,
                            device=device,
                            negated=True)
                        pos_valid_res = validate(
                            model=model,
                            dataloader=pos_valid_dataloader,
                            device=device,
                            negated=False)
                        wandb.log({
                            'neg/valid/p@1': neg_valid_res % 100.,
                            'pos/valid/p@1': pos_valid_res % 100.
                        })

    # Save a trained model
    if is_main_process:
        print("Saving model")
        logging.info("** ** * Saving fine-tuned model ** ** * ")
        core_model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        print(str(wandb.run.id))
        # Use a context manager so the file handle is flushed and closed.
        with open(os.path.join(args.output_dir, 'wandb_run_id.pkl'),
                  'wb') as run_id_file:
            pickle.dump(str(wandb.run.id), run_id_file)
Beispiel #9
0
def train(args, tokenizer, device):
    """Train a BertMouth model, validate after each epoch, and save it.

    Args:
        args: Parsed options. This function reads ``train_file``,
            ``valid_file``, ``max_seq_length``, ``train_batch_size``,
            ``bert_model``, ``learning_rate``, ``adam_epsilon`` and
            ``num_train_epochs``, and forwards ``args`` to ``save``.
        tokenizer: Tokenizer passed to the data loaders and ``generate``;
            its ``vocab_size`` sets the number of output labels.
        device: Device the model and every batch are moved to.

    A ``KeyboardInterrupt`` during training is logged and ends the loop
    early; the model is still saved afterwards.
    """
    # Function-local import keeps this block self-contained; it is a no-op
    # lookup if the module already imported torch.
    import torch

    logger.info("loading data")

    # Load the training data.
    train_dataloader = make_dataloader(args.train_file, args.max_seq_length,
                                       args.train_batch_size, tokenizer)
    # Load the validation data (the training batch size is reused for it).
    valid_dataloader = make_dataloader(args.valid_file, args.max_seq_length,
                                       args.train_batch_size, tokenizer)

    logger.info("building model")

    # Load pretrained weights into BertMouth, with one output label per
    # vocabulary entry.
    model = BertMouth.from_pretrained(args.bert_model,
                                      num_labels=tokenizer.vocab_size)
    # Move the model to the requested GPU/CPU device.
    model.to(device)

    # Collect every parameter whose name does not contain 'pooler'.
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    logger.info("setting optimizer")

    # Split parameters into those that get weight decay and those
    # (biases, LayerNorm) that do not.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    # Total optimizer steps = batches per epoch * number of epochs.
    optimization_steps = len(train_dataloader) * args.num_train_epochs
    # Optimization algorithm.
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate, eps=args.adam_epsilon)
    # The scheduler adjusts the learning rate over the course of training.
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=0,
                                     t_total=optimization_steps)
    # Targets equal to 0 are ignored by the loss.
    loss_fct = CrossEntropyLoss(ignore_index=0)

    def calc_batch_loss(batch):
        """Return the cross-entropy loss of the model on one batch."""
        # Move every tensor of the batch to the target device.
        batch = tuple(t.to(device) for t in batch)
        input_ids, y, input_mask, input_type_id, masked_pos = batch

        # Forward pass; flatten logits and targets for the loss.
        logits = model(input_ids, input_type_id, input_mask)
        logits = logits.view(-1, tokenizer.vocab_size)
        y = y.view(-1)

        # Compute the loss.
        loss = loss_fct(logits, y)
        return loss

    logger.info("train starts")
    # Put the model in training mode.
    model.train()
    # Directory the summary logs are written to.
    summary_writer = SummaryWriter(log_dir="logs")

    generated_texts = []
    try:
        # trange shows a progress bar over the epochs.
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            train_loss = 0.
            running_num = 0

            # Process the training data batch by batch.
            for step, batch in enumerate(train_dataloader):

                # Compute the loss.
                loss = calc_batch_loss(batch)

                # Back-propagate and update parameters, then the LR schedule.
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Clear gradients before the next batch.
                optimizer.zero_grad()

                # For reporting: accumulated loss and number of examples.
                # NOTE(review): this sums per-batch losses but divides by the
                # example count below — confirm that normalisation is intended.
                train_loss += loss.item()
                running_num += len(batch[0])
            logger.info("[{0} epochs] "
                        "train loss: {1:.3g} ".format(epoch + 1,
                                                      train_loss / running_num))
            summary_writer.add_scalar("train_loss",
                                      train_loss / running_num, epoch)

            # Switch the model to evaluation mode.
            model.eval()
            valid_loss = 0.
            valid_num = 0
            # Validation never back-propagates, so skip building autograd
            # graphs to save memory and time.
            with torch.no_grad():
                for batch in valid_dataloader:
                    # Accumulate the loss on the validation data for reporting.
                    valid_loss += calc_batch_loss(batch).item()
                    valid_num += len(batch[0])

            # Keep a sample generated after this epoch.
            generated_texts.append(generate(tokenizer=tokenizer,
                                            device=device,
                                            length=25,
                                            max_length=args.max_seq_length,
                                            model=model))
            logger.info("[{0} epochs] valid loss: {1:.3g}".format(epoch + 1,
                                                                  valid_loss / valid_num))
            summary_writer.add_scalar("val_loss",
                                      valid_loss / valid_num, epoch)

            # Switch back to training mode for the next epoch.
            model.train()
    except KeyboardInterrupt:
        logger.info("KeyboardInterrupt")

    # Finish writing the summary logs.
    summary_writer.close()
    dt_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    # Save the model, tagged with the current timestamp.
    save(args, model, tokenizer, str(dt_now))
Beispiel #10
0
def main():
    """Command-line entry point: train, predict and/or label evidence sentences.

    Depending on the flags given on the command line this function

    * ``--do_train``: fine-tunes the model on ``--train_file``, evaluating on
      ``--predict_file`` every ``--per_eval_step`` optimizer steps, keeping
      the checkpoints with the best accuracy (``best_model``) and the lowest
      loss (``best_loss_model``) and stopping early after ``--patience``
      evaluations without an accuracy improvement;
    * ``--do_predict``: runs both saved checkpoints on ``--test_file`` and
      writes ``predictions.json`` / ``loss_predictions.json`` into
      ``--predict_dir``;
    * ``--do_label``: runs the checkpoint selected by ``--metric`` over the
      training features and writes ``sentence_id_file.json`` into
      ``--predict_dir``.

    Raises:
        ValueError: if no mode flag is given, a required input file is
            missing, ``--gradient_accumulation_steps`` is < 1, or the output
            directory is non-empty when training.
        RuntimeError: if ``--metric`` is neither ``accuracy`` nor ``loss``.
        ImportError: if ``--fp16`` or distributed training is requested but
            apex is not installed.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    # NOTE: these two options are required, so their defaults are never used.
    parser.add_argument("--vocab_file",
                        default='bert-base-uncased-vocab.txt',
                        type=str,
                        required=True)
    parser.add_argument("--model_file",
                        default='bert-base-uncased.tar.gz',
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument("--test_file", default=None, type=str)
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--view_id',
                        type=int,
                        default=1,
                        help="view id of multi-view co-training(two-view)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    # NOTE: default=True combined with store_true means this flag is always
    # True; it is kept unchanged for command-line compatibility.
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--save_all', default=False, action='store_true')
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument('--weight_decay', default=0.0, type=float)
    parser.add_argument('--adam_epsilon', default=1e-8, type=float)
    parser.add_argument('--patience', default=5, type=int)

    # Base setting
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--max_ctx', type=int, default=2)
    parser.add_argument('--task_name', type=str, default='race')
    parser.add_argument('--bert_name', type=str, default='pool-race')
    parser.add_argument('--reader_name', type=str, default='race')
    parser.add_argument('--per_eval_step', type=int, default=10000000)
    # model parameters
    parser.add_argument('--evidence_lambda', type=float, default=0.8)
    # Parameters for running labeling model
    parser.add_argument('--do_label', default=False, action='store_true')
    parser.add_argument('--sentence_id_file', nargs='*')
    parser.add_argument('--weight_threshold', type=float, default=0.0)
    parser.add_argument('--only_correct', default=False, action='store_true')
    parser.add_argument('--label_threshold', type=float, default=0.0)
    parser.add_argument('--multi_evidence', default=False, action='store_true')
    parser.add_argument('--metric', default='accuracy', type=str)
    parser.add_argument('--num_evidence', default=1, type=int)
    parser.add_argument('--power_length', default=1., type=float)
    parser.add_argument('--num_choices', default=4, type=int)
    parser.add_argument('--split_type', default=0, type=int)

    args = parser.parse_args()

    logger = setting_logger(args.output_dir)
    logger.info('================== Program start. ========================')

    model_params = prepare_model_params(args)
    read_params = prepare_read_params(args)

    # Single-process mode uses every visible GPU (or the CPU); distributed
    # mode pins this process to the GPU given by --local_rank.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass (micro) batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict and not args.do_label:
        raise ValueError(
            "At least one of `do_train` or `do_predict` or `do_label` must be True."
        )

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "best_model"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "best_loss_model"),
                    exist_ok=True)

    if args.do_predict or args.do_label:
        os.makedirs(args.predict_dir, exist_ok=True)

    # tokenizer = BertTokenizer.from_pretrained(args.vocab_file)
    tokenizer = get_tokenizer(args.bert_model).from_pretrained(args.vocab_file)

    data_reader = initialize_reader(args.reader_name)

    num_train_steps = None
    if args.do_train or args.do_label:
        train_examples = data_reader.read(input_file=args.train_file,
                                          **read_params)

        # The cache file name encodes every setting that is passed into it
        # below, so a different configuration gets a different cache file.
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_ctx), str(args.task_name))

        try:
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only use cache files produced by this program.
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except FileNotFoundError:
            train_features = data_reader.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length)
            # Only one process writes the cache in distributed mode.
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    if args.pretrain is not None:
        logger.info('Load pretrained model from {}'.format(args.pretrain))
        # Map the checkpoint onto the device selected above instead of a
        # hard-coded 'cuda:0', so --no_cuda / CPU-only hosts and non-zero
        # local ranks can load it too.
        model_state_dict = torch.load(args.pretrain, map_location=device)
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
    else:
        model = initialize_model(args.bert_name, args.model_file,
                                 **model_params)

    # if args.fp16:
    #     model.half()
    model.to(device)

    # -1 is a placeholder total used when no training data was loaded.
    t_total = num_train_steps if num_train_steps is not None else -1
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()

    # Prepare optimizer and schedule (linear warmup and decay)
    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=int(args.warmup_proportion *
                                                      t_total),
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare data
    # The evaluation set is built unconditionally from --predict_file; it is
    # consumed by the periodic evaluation inside the training loop.
    eval_examples = data_reader.read(input_file=args.predict_file,
                                     **read_params)
    eval_features = data_reader.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    eval_tensors = data_reader.data_to_tensors(eval_features)
    eval_data = TensorDataset(*eval_tensors)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.predict_batch_size)

    if args.do_train:

        if args.do_label:
            logger.info('Training in State Wise.')
            sentence_label_file = args.sentence_id_file
            if sentence_label_file is not None:
                # Each sentence-id file is applied on top of the previous result.
                for file in sentence_label_file:
                    train_features = data_reader.generate_features_sentence_ids(
                        train_features, file)
            else:
                logger.info('No sentence id supervision is found.')
        else:
            logger.info('Training in traditional way.')

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Num train total optimization steps = %d", t_total)
        logger.info("  Batch size = %d", args.train_batch_size)
        train_loss = AverageMeter()
        best_acc = 0.0
        best_loss = 1000000
        summary_writer = SummaryWriter(log_dir=args.output_dir)
        global_step = 0
        eval_loss = AverageMeter()
        eval_accuracy = CategoricalAccuracy()
        eval_epoch = 0
        # Index (in units of evaluations) of the last accuracy improvement;
        # used for the patience-based early stopping below.
        last_update = 0

        train_tensors = data_reader.data_to_tensors(train_features)
        train_data = TensorDataset(*train_tensors)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        for epoch in range(int(args.num_train_epochs)):
            logger.info(f'Running at Epoch {epoch}')
            # Train
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         dynamic_ncols=True)):
                # Re-enter train mode every step because the periodic
                # evaluation below switches the model to eval mode.
                model.train()
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, train_features, model_state=ModelState.Train)
                model_output = model(**inputs)
                loss = model_output['loss']
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                # Only update the weights once enough micro-batches have
                # been accumulated.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    lr_this_step = scheduler.get_lr()[0]
                    summary_writer.add_scalar('lr', lr_this_step, global_step)

                    batch_size = inputs["labels"].size(0)
                    train_loss.update(loss.item() * batch_size, batch_size)
                    summary_writer.add_scalar('train_loss', train_loss.avg,
                                              global_step)

                    if global_step % args.per_eval_step == 0:
                        # Evaluation
                        model.eval()
                        logger.info("Start evaluating")
                        for _, eval_batch in enumerate(
                                tqdm(eval_dataloader,
                                     desc="Evaluating",
                                     dynamic_ncols=True)):
                            if n_gpu == 1:
                                eval_batch = batch_to_device(
                                    eval_batch, device
                                )  # multi-gpu does scattering it-self
                            inputs = data_reader.generate_inputs(
                                eval_batch,
                                eval_features,
                                model_state=ModelState.Evaluate)
                            batch_size = inputs["labels"].size(0)
                            with torch.no_grad():
                                output_dict = model(**inputs)
                                loss, choice_logits = output_dict[
                                    'loss'], output_dict['choice_logits']
                                eval_loss.update(loss.item() * batch_size,
                                                 batch_size)
                                eval_accuracy(choice_logits, inputs["labels"])

                        eval_epoch_loss = eval_loss.avg
                        summary_writer.add_scalar('eval_loss', eval_epoch_loss,
                                                  global_step)
                        eval_loss.reset()
                        current_acc = eval_accuracy.get_metric(reset=True)
                        summary_writer.add_scalar('eval_acc', current_acc,
                                                  global_step)
                        torch.cuda.empty_cache()

                        if args.save_all:
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Only save the model it-self
                            output_model_file = os.path.join(
                                args.output_dir,
                                f"checkpoint-{global_step}.bin")
                            model_to_save.save_pretrained(output_model_file)
                            # torch.save(model_to_save.state_dict(), output_model_file)

                        if current_acc > best_acc:
                            best_acc = current_acc
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Only save the model it-self
                            output_model_file = os.path.join(
                                args.output_dir, "best_model")
                            model_to_save.save_pretrained(output_model_file)
                            # torch.save(model_to_save.state_dict(), output_model_file)
                            last_update = global_step // args.per_eval_step
                        if eval_epoch_loss < best_loss:
                            best_loss = eval_epoch_loss
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # Only save the model it-self
                            output_model_file = os.path.join(
                                args.output_dir, "best_loss_model")
                            model_to_save.save_pretrained(output_model_file)
                            # torch.save(model_to_save.state_dict(), output_model_file)

                        logger.info(
                            'Global Step: %d, Accuracy: %.4f (Best Accuracy: %.4f)'
                            % (global_step, current_acc, best_acc))
                        eval_epoch += 1

                        # Early stopping: too many evaluations since the
                        # last accuracy improvement.
                        if global_step // args.per_eval_step - last_update >= args.patience:
                            logger.info(
                                f"Training reach patience: {args.patience}, training stopped."
                            )
                            break

            # Propagate the early stop out of the epoch loop as well.
            if global_step // args.per_eval_step - last_update >= args.patience:
                break

            logger.info(
                f'Epoch {epoch}: Accuracy: {best_acc}, Train Loss: {train_loss.avg}'
            )
        summary_writer.close()

    for output_model_name in ["best_model", "best_loss_model"]:
        # Loading trained model
        output_model_file = os.path.join(args.output_dir, output_model_name)
        # model_state_dict = torch.load(output_model_file, map_location='cuda:0')
        # model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params)
        model = initialize_model(args.bert_name, output_model_file,
                                 **model_params)
        model.to(device)

        # Write Yes/No predictions
        if args.do_predict and (args.local_rank == -1
                                or torch.distributed.get_rank() == 0):

            test_examples = data_reader.read(args.test_file)
            test_features = data_reader.convert_examples_to_features(
                test_examples, tokenizer, args.max_seq_length)

            test_tensors = data_reader.data_to_tensors(test_features)
            test_data = TensorDataset(*test_tensors)
            test_sampler = SequentialSampler(test_data)
            test_dataloader = DataLoader(test_data,
                                         sampler=test_sampler,
                                         batch_size=args.predict_batch_size)

            logger.info("***** Running predictions *****")
            logger.info("  Num orig examples = %d", len(test_examples))
            logger.info("  Num split examples = %d", len(test_features))
            logger.info("  Batch size = %d", args.predict_batch_size)

            model.eval()
            all_results = []
            test_acc = CategoricalAccuracy()
            logger.info("Start predicting yes/no on Dev set.")
            for batch in tqdm(test_dataloader, desc="Testing"):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, test_features, model_state=ModelState.Evaluate)
                with torch.no_grad():
                    batch_choice_logits = model(**inputs)['choice_logits']
                    test_acc(batch_choice_logits, inputs['labels'])
                # The last tensor of a batch holds the feature indices.
                example_indices = batch[-1]
                for i, example_index in enumerate(example_indices):
                    choice_logits = batch_choice_logits[i].detach().cpu(
                    ).tolist()

                    test_feature = test_features[example_index.item()]
                    unique_id = int(test_feature.unique_id)

                    all_results.append(
                        RawResultChoice(unique_id=unique_id,
                                        choice_logits=choice_logits))

            if "loss" in output_model_name:
                logger.info(
                    'Predicting question choice on test set using model with lowest loss on validation set.'
                )
                output_prediction_file = os.path.join(args.predict_dir,
                                                      'loss_predictions.json')
            else:
                logger.info(
                    'Predicting question choice on test set using model with best accuracy on validation set,'
                )
                output_prediction_file = os.path.join(args.predict_dir,
                                                      'predictions.json')
            data_reader.write_predictions(test_examples, test_features,
                                          all_results, output_prediction_file)
            logger.info(
                f"Accuracy on Test set: {test_acc.get_metric(reset=True)}")

    # Loading trained model.
    if args.metric == 'accuracy':
        logger.info("Load model with best accuracy on validation set.")
        output_model_file = os.path.join(args.output_dir, "best_model")
    elif args.metric == 'loss':
        logger.info("Load model with lowest loss on validation set.")
        output_model_file = os.path.join(args.output_dir, "best_loss_model")
    else:
        raise RuntimeError(
            f"Wrong metric type for {args.metric}, which must be in ['accuracy', 'loss']."
        )
    # model_state_dict = torch.load(output_model_file, map_location='cuda:0')
    # model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params)
    model = initialize_model(args.bert_name, output_model_file, **model_params)
    model.to(device)

    # Labeling sentence id.
    if args.do_label and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):

        def softmax(x):
            """Compute softmax values for each sets of scores in x."""
            # Subtracting the max keeps np.exp from overflowing.
            e_x = np.exp(x - np.max(x))
            return e_x / e_x.sum()

        def topk(sentence_sim):
            """Pick the best-scoring prefix of the sentences ranked by score.

            For every prefix length up to ``args.num_evidence`` the mean
            log-probability of picking the ranked sentences one after another
            (softmax over the not-yet-picked ones) is computed; the prefix
            with the highest mean wins.

            :param sentence_sim: 1-D numpy array of sentence scores
            :return: dict with ``sentences`` (selected indices into
                ``sentence_sim``) and ``value`` (exp of the winning mean
                log-probability); empty selection with value 0.0 when there
                is nothing to select from
            """
            max_length = min(args.num_evidence, len(sentence_sim))
            if max_length <= 0:
                # No unmasked sentence (or no evidence requested):
                # np.argmax on an empty list would raise ValueError.
                return {'sentences': [], 'value': 0.0}
            sorted_scores = np.array(sorted(sentence_sim, reverse=True))
            scores = []
            for idx in range(max_length):
                scores.append(np.log(softmax(sorted_scores[idx:])[0]))
            scores = [np.mean(scores[:(j + 1)]) for j in range(max_length)]
            top_k = int(np.argmax(scores) + 1)
            sorted_scores = sorted(enumerate(sentence_sim),
                                   key=lambda x: x[1],
                                   reverse=True)
            evidence_ids = [x[0] for x in sorted_scores[:top_k]]
            sentence = {
                'sentences': evidence_ids,
                'value': float(np.exp(scores[top_k - 1]))
            }
            return sentence

        def batch_topk(sentence_sim, sentence_mask):
            """Apply ``topk`` to every (example, choice) row of a batch.

            Both tensors are indexed as [batch, choice, sentence]; only the
            first ``sum(mask)`` scores of each row are passed to ``topk``.
            """
            batch_size = sentence_sim.size(0)
            num_choices = sentence_sim.size(1)
            # Tensor.numpy() only works on CPU tensors, so move the model
            # outputs off the GPU first.
            sentence_sim = sentence_sim.detach().cpu().numpy() + 1e-15
            sentence_mask = sentence_mask.detach().cpu().numpy()
            sentence_ids = []
            for b in range(batch_size):
                choice_sentence_ids = [
                    topk(_sim[:int(sum(_mask))])
                    for _sim, _mask in zip(sentence_sim[b], sentence_mask[b])
                ]
                assert len(choice_sentence_ids) == num_choices
                sentence_ids.append(choice_sentence_ids)
            return sentence_ids

        # Labeling runs over the training set itself.
        test_examples = train_examples
        test_features = train_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits, batch_sentence_logits = output_dict[
                    "choice_logits"], output_dict["sentence_logits"]
                batch_sentence_mask = output_dict["sentence_mask"]
            example_indices = batch[-1]
            # batch_beam_results = batch_choice_beam_search(batch_sentence_logits, batch_sentence_mask)
            batch_topk_results = batch_topk(batch_sentence_logits,
                                            batch_sentence_mask)
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu()
                evidence_list = batch_topk_results[i]

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawOutput(unique_id=unique_id,
                              model_output={
                                  "choice_logits": choice_logits,
                                  "evidence_list": evidence_list
                              }))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'sentence_id_file.json')
        data_reader.predict_sentence_ids(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            only_correct=args.only_correct,
            label_threshold=args.label_threshold)
Beispiel #11
0
def train(net):
    """Train ``net`` for ``C.epoch_number`` epochs, validating after each one.

    Every epoch walks over all batches of ``data[C.train_data]``, fetches one
    batch per device listed in ``C.gpus``, re-pads the gold sequences to a
    common length, runs the forward pass (directly for a single device, via
    ``net.replicate`` / ``net.parallel_apply`` otherwise), and performs one
    Adam update with gradient clipping and a warmup-linear LR schedule.

    After each epoch ``valid(net)`` is called and, when ``C.save`` is set, a
    pickle ``[model, next_epoch, optimizer]`` is written to
    ``<C.save>/epoch_<n>.pkl`` and copied to ``<C.save>/last.pkl``.

    Args:
        net: the model to train. With more than one entry in ``C.gpus`` it is
            expected to expose ``module``, ``device_ids``, ``replicate`` and
            ``parallel_apply`` (presumably an ``nn.DataParallel`` wrapper --
            TODO confirm against the caller).
    """
    # Local import: only needed to refresh the "last.pkl" checkpoint copy.
    import shutil

    train_starttime = gettime()

    train_data = data[C.train_data]
    # Ceiling division: a trailing partial batch still counts as a batch.
    batch_number = (len(train_data) // C.batch_size) + int(
        (len(train_data) % C.batch_size) != 0)

    optim = tc.optim.Adam(params=net.parameters(), lr=C.lr)
    sched = WarmupLinearSchedule(optim,
                                 warmup_steps=400,
                                 t_total=batch_number * C.epoch_number)

    # Targets equal to 0 are ignored by the loss; 0 is also the value used
    # below to pad the gold sequences.
    loss_func = nn.NLLLoss(ignore_index=0)

    step = 0
    tot_loss = 0
    for epoch_n in range(C.epoch_number):

        lprint("epoch %d started." % (epoch_n))

        pbar = tqdm(range(batch_number), ncols=70)
        for batch_n in pbar:
            pbar.set_description_str("(Train)Epoch %d" % (epoch_n + 1))

            #-----------------get data-----------------
            inputs = []
            golds = []
            for data_device in C.gpus:
                inp, gold = get_a_batch(train_data, batch_n, data_device)
                inputs.append(inp)
                golds.append(gold)

            #------------------repadding-----------------
            # Pad every gold sequence (on every device) to the same length so
            # the per-device outputs can be concatenated later. The
            # second-to-last input field receives the same amount of padding
            # as its gold sequence.
            maxlen_gold = max([max([len(x) for x in gold]) for gold in golds])
            for _i in range(len(inputs)):
                for _j in range(len(golds[_i])):  #batch
                    pad = [0] * (maxlen_gold - len(golds[_i][_j]))
                    inputs[_i][-2][_j] += pad
                    golds[_i][_j] += pad
                golds[_i] = tc.LongTensor(golds[_i]).cuda(C.gpus[_i])
                for _j in range(1, len(inputs[_i])):  #first one is graph
                    inputs[_i][_j] = tc.LongTensor(inputs[_i][_j]).cuda(
                        C.gpus[_i])

            #-----------------get output-----------------
            if len(inputs) == 1:
                y = net(*inputs[0], attn_method=C.attn_method)
                gold = golds[0]
            else:
                replicas = net.replicate(net.module,
                                         net.device_ids[:len(inputs)])
                outputs = net.parallel_apply(replicas, inputs,
                                             [{
                                                 "attn_method": C.attn_method
                                             }] * len(inputs))

                # Gather everything on the first device to compute one loss.
                y = tc.cat([x.to(C.gpus[0]) for x in outputs], dim=0)
                gold = tc.cat([x.to(C.gpus[0]) for x in golds], dim=0)

            #-----------------get loss-----------------
            # NLLLoss expects log-probabilities, hence the explicit log of
            # the model output; both tensors are flattened over all positions.
            y = tc.log(y).view(-1, y.size(-1))
            gold = gold.view(-1)
            loss = loss_func(y, gold)

            tot_loss += float(loss)

            step += 1

            #-----------------back prop-----------------
            # Parameters are updated on every batch (no gradient accumulation).
            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), C.clip)
            optim.step()
            sched.step()

            # avg_loss is the running mean over all steps since training began.
            pbar.set_postfix_str("loss: %.4f , avg_loss: %.4f" %
                                 (float(loss), tot_loss / step))

        lprint("epoch %d ended." % (epoch_n))
        valid(net)

        if C.save:
            # Build the path only when saving is enabled so that an unset
            # save directory cannot break os.path.join.
            save_path = os.path.join(C.save, "epoch_%d.pkl" % epoch_n)

            if len(C.gpus) > 1:
                _net = net.module
            else:
                # Pickle the model from the CPU in the non-parallel case.
                net = net.cpu()
                _net = net

            # NOTE: checkpoints are pickles of whole objects; only load them
            # from trusted sources.
            with open(save_path, "wb") as fil:
                pickle.dump([_net, epoch_n + 1, optim], fil)

            if len(C.gpus) == 1:
                net = net.cuda(C.gpus[0])

            # Refresh the "latest checkpoint" copy without going through a shell.
            shutil.copyfile(save_path, os.path.join(C.save, "last.pkl"))
            lprint("saved...")

    lprint("tot train time = %.2fs" % (gettime() - train_starttime))
def _str2bool(value):
    """Convert a command-line token to a bool (``type=bool`` treats any non-empty string as True)."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r." % (value,))


def main():
    """Command-line entry point for multi-modal pre-training on Conceptual Captions.

    Parses the command line, selects the device / distributed setup, builds the
    tokenizer, the train and validation loaders, the model, the optimizer and
    the warmup-linear scheduler, optionally resumes from ``--resume_file``,
    then for every epoch runs the training loop, a validation pass and (on the
    default process only) writes ``pytorch_model_<epoch>.bin`` and
    ``pytorch_ckpt_<epoch>.tar`` under ``<output_dir>/<config name>[-save_name]``.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--file_path",
        default="data/conceptual_caption/",
        type=str,
        help="The input train corpus.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-base-uncased, roberta-base, roberta-large, ",
    )
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, roberta-base",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        # required=True,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default="config/bert_base_6layer_6conect.json",
        help="The config file which specified the model details.",
    )
    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=36,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.",
    )
    parser.add_argument(
        "--train_batch_size",
        default=512,
        type=int,
        help="Total batch size for training.",
    )
    parser.add_argument(
        "--learning_rate",
        default=1e-4,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=10.0,
        type=float,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--start_epoch",
        default=0,
        type=float,
        help="Epoch index to start training from.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--img_weight", default=1, type=float, help="weight for image loss"
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--on_memory",
        action="store_true",
        help="Whether to load train samples into memory or use disk",
    )
    parser.add_argument(
        "--do_lower_case",
        # ``type=bool`` would turn "False" into True; parse the text instead.
        type=_str2bool,
        default=True,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=25,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--baseline",
        action="store_true",
        help="Wheter to use the baseline model (single bert).",
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="whether use chunck for parallel training.",
    )
    parser.add_argument(
        "--without_coattention", action="store_true", help="whether pair loss."
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )

    parser.add_argument(
        "--objective",
        default=0,
        type=int,
        help="which objective to use \
        0: with ICA loss, \
        1: with ICA loss, for the not aligned pair, no masking objective, \
        2: without ICA loss, do not sample negative pair.",
    )
    parser.add_argument(
        "--num_negative", default=255, type=int, help="num of negative to use"
    )

    parser.add_argument(
        "--resume_file", default="", type=str, help="Resume from checkpoint"
    )
    parser.add_argument(
        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
    )

    args = parser.parse_args()

    # The model implementation is chosen at run time, so these imports stay local.
    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from vilbert.basebert import BertForMultiModalPreTraining
    else:
        from vilbert.vilbert import BertForMultiModalPreTraining, BertConfig

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""

    # Run name = config file name up to its first dot (+ optional save name).
    # basename() copes with config paths that are not exactly "<dir>/<file>".
    timeStamp = os.path.basename(args.config_file).split(".")[0] + prefix
    savePath = os.path.join(args.output_dir, timeStamp)

    # Names of the pretrained BERT weights; used below for freezing and for
    # assigning a reduced learning rate.
    with open("config/" + args.from_pretrained + "_weight_name.json", "r") as f:
        bert_weight_name = json.load(f)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        # Report no GPUs when running on the CPU so that --no_cuda does not
        # trigger CUDA seeding or DataParallel further down.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    # Only one process (rank 0, or the sole process when not distributed)
    # writes logs and checkpoints.
    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        os.makedirs(savePath, exist_ok=True)

    config = BertConfig.from_json_file(args.config_file)

    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    # From here on train_batch_size is the per-forward-pass batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    cache = 5000
    if dist.is_available() and args.local_rank != -1:
        # Split batch size, loader workers and cache across the replicas.
        num_replicas = dist.get_world_size()
        args.train_batch_size = args.train_batch_size // num_replicas
        args.num_workers = args.num_workers // num_replicas
        cache = cache // num_replicas

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # exist_ok avoids a check-then-create race when every rank reaches this line.
    os.makedirs(args.output_dir, exist_ok=True)

    if "roberta" in args.bert_model:
        tokenizer = RobertaTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case
        )
    else:
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case
        )

    train_dataset = ConceptCapLoaderTrain(
        args.file_path,
        tokenizer,
        args.bert_model,
        seq_len=args.max_seq_length,
        batch_size=args.train_batch_size,
        visual_target=args.visual_target,
        num_workers=args.num_workers,
        local_rank=args.local_rank,
        objective=args.objective,
        cache=cache,
    )

    validation_dataset = ConceptCapLoaderVal(
        args.file_path,
        tokenizer,
        args.bert_model,
        seq_len=args.max_seq_length,
        batch_size=args.train_batch_size,
        visual_target=args.visual_target,
        num_workers=2,
        objective=args.objective,
    )

    # Optimizer updates per epoch times the number of epochs left to run.
    # The epoch arguments are floats, so this value is a float as well.
    num_train_optimization_steps = int(
        train_dataset.num_dataset
        / args.train_batch_size
        / args.gradient_accumulation_steps
    ) * (args.num_train_epochs - args.start_epoch)

    task_names = ["Conceptual_Caption"]
    task_ids = ["TASK0"]
    task_num_iters = {"TASK0": train_dataset.num_dataset / args.train_batch_size}

    logdir = os.path.join("logs", timeStamp)
    if default_gpu:
        tbLogger = utils.tbLogger(
            logdir,
            savePath,
            task_names,
            task_ids,
            task_num_iters,
            args.gradient_accumulation_steps,
        )

    # Size of the visual prediction target depends on the chosen visual target.
    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.freeze > config.t_biattention_id[0]:
        config.fixed_t_layer = config.t_biattention_id[0]

    if args.without_coattention:
        config.with_coattention = False

    if args.dynamic_attention:
        config.dynamic_attention = True

    if args.from_pretrained:
        model = BertForMultiModalPreTraining.from_pretrained(
            args.from_pretrained, config=config, default_gpu=default_gpu
        )
    else:
        model = BertForMultiModalPreTraining(config)

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        # Freeze the embeddings and every encoder layer up to --freeze.
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        for key, value in model.named_parameters():
            # key[12:] drops the first 12 characters of the parameter name
            # (presumably a fixed module prefix -- TODO confirm) before the
            # comparison with the pretrained weight names.
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    if not args.from_pretrained:
        param_optimizer = list(model.named_parameters())
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
            },
            {
                "params": [
                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
    else:
        # One parameter group per trainable tensor: weights listed in
        # bert_weight_name get a 10x smaller learning rate, and bias /
        # LayerNorm parameters get no weight decay.
        optimizer_grouped_parameters = []
        for key, value in model.named_parameters():
            if not value.requires_grad:
                continue

            if key[12:] in bert_weight_name:
                lr = args.learning_rate * 0.1
            else:
                lr = args.learning_rate

            if any(nd in key for nd in no_decay):
                weight_decay = 0.0
            else:
                weight_decay = 0.01

            optimizer_grouped_parameters.append(
                {"params": [value], "lr": lr, "weight_decay": weight_decay}
            )

        if default_gpu:
            print(
                len(list(model.named_parameters())), len(optimizer_grouped_parameters)
            )

    # set different parameters for vision branch and lanugage branch.
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            ) from err

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            bias_correction=False,
            max_grad_norm=1.0,
        )
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:

        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            eps=args.adam_epsilon,
            betas=(0.9, 0.98),
        )

    scheduler = WarmupLinearSchedule(
        optimizer,
        warmup_steps=args.warmup_proportion * num_train_optimization_steps,
        t_total=num_train_optimization_steps,
    )

    startIterID = 0
    global_step = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        # Drop a leading "module." from the saved parameter names so they
        # match the unwrapped model.
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace("module.", "", 1)] = checkpoint[
                    "model_state_dict"
                ][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        del checkpoint

    # Use the selected device instead of an unconditional .cuda() so that
    # --no_cuda (or a machine without CUDA) actually runs on the CPU.
    model.to(device)

    # Optimizer state restored from a checkpoint was mapped to the CPU above;
    # move it next to the parameters.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)

    if args.fp16:
        model.half()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            ) from err
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", train_dataset.num_dataset)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

    for epochId in range(int(args.start_epoch), int(args.num_train_epochs)):
        model.train()
        for step, batch in enumerate(train_dataset):

            iterId = startIterID + step + (epochId * len(train_dataset))
            # The last batch element is not moved to the device and is not
            # passed to the model.
            batch = tuple(t.to(device, non_blocking=True) for t in batch[:-1])

            (
                input_ids,
                input_mask,
                segment_ids,
                lm_label_ids,
                is_next,
                image_feat,
                image_loc,
                image_target,
                image_label,
                image_mask,
            ) = batch

            if args.objective == 1:
                # Keep the masked-prediction labels only where is_next == 0;
                # every label that ends up as 0 is rewritten to -1.
                image_label = image_label * (is_next == 0).long().unsqueeze(1)
                image_label[image_label == 0] = -1

                lm_label_ids = lm_label_ids * (is_next == 0).long().unsqueeze(1)
                lm_label_ids[lm_label_ids == 0] = -1

            masked_loss_t, masked_loss_v, next_sentence_loss = model(
                input_ids,
                image_feat,
                image_loc,
                segment_ids,
                input_mask,
                image_mask,
                lm_label_ids,
                image_label,
                image_target,
                is_next,
            )

            if args.objective == 2:
                # Objective 2 trains without the third loss term.
                next_sentence_loss = next_sentence_loss * 0

            masked_loss_v = masked_loss_v * args.img_weight
            loss = masked_loss_t + masked_loss_v + next_sentence_loss

            if n_gpu > 1:
                # DataParallel returns one loss value per GPU.
                loss = loss.mean()
                masked_loss_t = masked_loss_t.mean()
                masked_loss_v = masked_loss_v.mean()
                next_sentence_loss = next_sentence_loss.mean()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / num_train_optimization_steps,
                        args.warmup_proportion,
                    )
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr_this_step

                # NOTE(review): scheduler.step() runs before optimizer.step().
                # Current PyTorch expects the opposite order; it is kept as-is
                # because swapping would also change which learning rate the
                # fp16 path applies -- confirm before changing.
                scheduler.step()
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                if default_gpu:
                    tbLogger.step_train_CC(
                        epochId,
                        iterId,
                        float(masked_loss_t),
                        float(masked_loss_v),
                        float(next_sentence_loss),
                        optimizer.param_groups[0]["lr"],
                        "TASK0",
                        "train",
                    )

            if (
                step % (20 * args.gradient_accumulation_steps) == 0
                and step != 0
                and default_gpu
            ):
                tbLogger.showLossTrainCC()

        # Do the evaluation
        torch.set_grad_enabled(False)
        numBatches = len(validation_dataset)

        model.eval()
        for step, batch in enumerate(validation_dataset):
            batch = tuple(t.to(device, non_blocking=True) for t in batch[:-1])

            (
                input_ids,
                input_mask,
                segment_ids,
                lm_label_ids,
                is_next,
                image_feat,
                image_loc,
                image_target,
                image_label,
                image_mask,
            ) = batch

            batch_size = input_ids.size(0)
            masked_loss_t, masked_loss_v, next_sentence_loss = model(
                input_ids,
                image_feat,
                image_loc,
                segment_ids,
                input_mask,
                image_mask,
                lm_label_ids,
                image_label,
                image_target,
                is_next,
            )

            masked_loss_v = masked_loss_v * args.img_weight

            if n_gpu > 1:
                masked_loss_t = masked_loss_t.mean()
                masked_loss_v = masked_loss_v.mean()
                next_sentence_loss = next_sentence_loss.mean()

            if default_gpu:
                tbLogger.step_val_CC(
                    epochId,
                    float(masked_loss_t),
                    float(masked_loss_v),
                    float(next_sentence_loss),
                    "TASK0",
                    batch_size,
                    "val",
                )
                sys.stdout.write("%d / %d \r" % (step, numBatches))
                sys.stdout.flush()

        if default_gpu:
            tbLogger.showLossValCC()

        torch.set_grad_enabled(True)

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin"
            )
            output_checkpoint = os.path.join(
                savePath, "pytorch_ckpt_" + str(epochId) + ".tar"
            )
            torch.save(model_to_save.state_dict(), output_model_file)
            # Full checkpoint in the format read back by --resume_file.
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "global_step": global_step,
                },
                output_checkpoint,
            )

    if default_gpu:
        tbLogger.txt_close()