Example #1
def train(args, model, tokenizer, train_dataloader):
    """ Train the model """
    #if args.local_rank in [-1, 0]:
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
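    # Effective global batch size: per-process batch * gradient accumulation steps *
    # number of distributed workers (world size); used below to derive the total step count.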
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
        #args.num_train_epochs = args.max_steps // (args.expected_train_size // args.gradient_accumulation_steps) + 1 
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs    

    # layerwise optimization for lamb
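    # Build one parameter group per major sub-module (and, below, per encoder layer).
    # Parameters collected here are tracked in layer_optim_params so that any remaining
    # parameters can be added as a final catch-all group.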
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in ["roberta.embeddings", "score_out", "downsample1", "downsample2", "downsample3", "embeddingHead"]:
         layer = getattr_recursive(model, layer_name)
         if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append({"params": [p for p in model.parameters() if p not in layer_optim_params]})
    if len(optimizer_grouped_parameters)==0:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]

    if args.optimizer.lower()=="lamb":
        optimizer = Lamb(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower()=="adamw":
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception("optimizer {0} not recognized! Can only be lamb or adamW".format(args.optimizer))
    
    if args.scheduler.lower()=="linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    elif args.scheduler.lower()=="cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception("Scheduler {0} not recognized! Can only be linear or cosine".format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        # if is_first_worker():
        #     op_state = torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))
        #     print([len(x['params']) for x in op_state['param_groups']])
        #     real_op_state = optimizer.state_dict()
        #     print([len(x['params']) for x in real_op_state['param_groups']])
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    #logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
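    # The step count is parsed from a "checkpoint-<global_step>" directory name; if the
    # path does not follow that pattern, training simply starts from the loaded weights.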
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (args.expected_train_size // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model") 

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        #epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in tqdm(enumerate(train_dataloader), desc="Iteration", disable=args.local_rank not in [-1, 0]):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

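            # With DDP, gradients are all-reduced only on the micro-step that will call
            # optimizer.step(); model.no_sync() skips the all-reduce on the other
            # accumulation micro-steps to save communication.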
            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()          

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (args.logging_steps_per_eval*args.logging_steps)==0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {"reranking": float(reranking_mrr), "full_raking": float(full_ranking_mrr)}
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss


                    if is_first_worker():
                        for key, value in logs.items():
                            print(key, type(value))
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #2
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    #if args.local_rank in [-1, 0]:
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})

    if len(optimizer_grouped_parameters) == 0:
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                args.weight_decay,
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0
            },
        ]

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    #logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
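            # Poll the ANN directory once per logging interval; when a new negatives file
            # appears, rebuild the streaming dataset/dataloader from it.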
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

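                # Truncate so the data size is divisible by world_size and every
                # distributed worker sees the same number of examples.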
                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg,
                                         ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
Example #3
def train(args, model, tokenizer, shuffled_fh, train_fn, configObj, logger):
    """ Train the model """
    #if args.local_rank in [-1, 0]:
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    total_train_steps = len(
        shuffled_fh) * args.num_train_epochs // real_batch_size
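    # If no explicit warmup step count is given, derive it as a fraction of the total
    # number of optimization steps.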
    if args.warmup_steps <= 0:
        args.warmup_steps = int(total_train_steps * args.warmup_proportion)

    if args.max_steps > 0:
        t_total = args.max_steps
        #args.num_train_epochs = args.max_steps // (args.expected_train_size // args.gradient_accumulation_steps) + 1
    else:
        # t_total = args.expected_train_size // real_batch_size * args.num_train_epochs
        t_total = total_train_steps
        args.max_steps = total_train_steps

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    no_decay = ["bias", "LayerNorm.weight", "layer_norm", "LayerNorm"]
    layer_optim_params = set()
    for layer_name in [
            "bert.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    if getattr_recursive(model, "bert.encoder.layer") is not None:
        for layer in model.bert.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if len(optimizer_grouped_parameters) == 0:

        optimizer_grouped_parameters = [{
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01
        }, {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        }]
    logger.info("len(optimizer_grouped_parameters): {}".format(
        len(optimizer_grouped_parameters)))  # 1

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    # TODO: we find this consume huge amount of additional GPU memory with pytorch, thus disable for now
    # if os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) and args.resume:
    # Load in optimizer and scheduler states
    # if is_first_worker():
    #     op_state = torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))
    #     print([len(x['params']) for x in op_state['param_groups']])
    #     real_op_state = optimizer.state_dict()
    #     print([len(x['params']) for x in real_op_state['param_groups']])
    # optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
    # scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("   Train dataset size = %d", len(shuffled_fh))
    logger.info("   Num Epochs = %d", args.num_train_epochs)
    logger.info("   Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "   Total train batch size (w. parallel, distributed & accumulation) = %d",
        real_batch_size)
    logger.info("   Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("   Total optimization steps = %d", t_total)
    logger.info("   LR warmup steps = %d", args.warmup_steps)

    global_step = 0
    eval_cnt = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if (os.path.exists(args.model_name_or_path)
            and args.resume) or args.starting_step > 0:
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = args.starting_step

            if global_step <= 0:
                global_step = int(
                    args.model_name_or_path.split("-")[-1].split("/")[0])

            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "   Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("   Continuing training from epoch %d", epochs_trained)
            logger.info("   Continuing training from global step %d",
                        global_step)
            logger.info("   Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss = 0.0

    tensorboard_scalars = {}
    model.zero_grad()

    eval_cfg = args.eval_configObj  # this is also produced in the load_model_config() method
    eval_fn = wrapped_process_fn(tokenizer, args, eval_cfg)

    ideal_path = args.eval_ideal_path
    is_first_eval = (eval_cnt == 0)

    best_checkpoints = []
    set_seed(args)  # Added here for reproducibility

    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    for m_epoch in train_iterator:
        # shuffle input after first epoch
        if m_epoch > 0:
            shuffled_fh.change_seed(m_epoch)
        sds = SimplifiedStreamingDataset(shuffled_fh, train_fn,
                                         configObj.ix_func)
        train_dataloader = DataLoader(sds,
                                      batch_size=args.per_gpu_train_batch_size,
                                      num_workers=4,
                                      pin_memory=True)
        acc_accum = []
        model.train()
        for step, batch in tqdm(enumerate(train_dataloader),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            if step % 100 == 0 and step > 0:
                logger.info('train_step: {}'.format(step))
            # Skip past any already trained steps if resuming training
            # if steps_trained_in_current_epoch > 0:
            #     steps_trained_in_current_epoch -= 1
            #     continue

            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "query_ids": batch[0].long(),
                "query_attn_mask": batch[1].long(),
                "meta_ids": batch[3].long(),
                "meta_attn_mask": batch[4].long(),
                "labels": batch[6].float()
            }

            # sync gradients only at gradient accumulation step
            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(**inputs)
            else:
                with model.no_sync():
                    outputs = model(**inputs)

            loss_combine = outputs[0]
            # assert len(loss_combine) == 3
            loss = loss_combine["Loss/total_loss"]
            sim_combine = outputs[1]
            # assert len(sim_combine) == 8
            acc = outputs[2]
            acc_accum.append(acc.item())

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()
            tr_loss += loss.item()

            # if is_first_worker():
            # print("unique labels: ", torch.unique(inputs["labels"]).int())
            #    print("Similarity combinations: ", sim_combine)

            for key, value in loss_combine.items():
                tensorboard_scalars[key] = tensorboard_scalars.setdefault(
                    key, 0.0) + value.item()
            for key, value in sim_combine.items():
                # print(f"{key}: {value.mean().item()}")
                value = value.mean()
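                # value != value is True only for NaN entries; zero them so a NaN
                # similarity does not poison the running TensorBoard sums.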
                value[value != value] = 0
                tensorboard_scalars[key] = tensorboard_scalars.setdefault(
                    key, 0.0) + value.item()
                # print(f"tensorboardscalars: {key} : {tensorboard_scalars[key]}")

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    #for key, value in tensorboard_scalars.items():
                    #    print(f"{key}: {value}")
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval *
                            args.logging_steps) == 0:
                        if is_first_worker():
                            save_checkpoint(args,
                                            -1,
                                            model,
                                            tokenizer,
                                            logger=logger)

                        model.eval()
                        is_first_eval = (eval_cnt == 0)
                        args.global_step = global_step
                        init_time = time()
                        fidelity = eval_fidelity(args, model, eval_fn,
                                                 eval_cfg.path, ideal_path,
                                                 args.cache_dir, is_first_eval,
                                                 args.eval_full, logger)
                        logger.info("Eval cost time: {}".format(time() -
                                                                init_time))
                        eval_cnt += 1

                        model.train()

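                        # Keep only the three best checkpoints by fidelity: save while
                        # fewer than three exist, otherwise replace the current worst
                        # when the new score beats it.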
                        if is_first_worker():
                            if len(best_checkpoints) < 3:
                                save_checkpoint(args,
                                                global_step,
                                                model,
                                                tokenizer,
                                                optimizer,
                                                scheduler,
                                                logger=logger)
                                best_checkpoints.append(
                                    (global_step, fidelity))
                            else:
                                worst_checkpoint = sorted(
                                    best_checkpoints, key=lambda x: x[1])[0]
                                if fidelity > worst_checkpoint[1]:
                                    save_checkpoint(args,
                                                    global_step,
                                                    model,
                                                    tokenizer,
                                                    optimizer,
                                                    scheduler,
                                                    logger=logger)
                                    worst_cp_path = os.path.join(
                                        args.output_dir,
                                        "checkpoint-{}".format(
                                            str(worst_checkpoint[0])))
                                    shutil.rmtree(worst_cp_path)
                                    best_checkpoints.remove(worst_checkpoint)
                                    best_checkpoints.append(
                                        (global_step, fidelity))
                                else:
                                    logger.info("Fidelity not in top 3!")
                                assert len(best_checkpoints) == 3
                            tb_writer.add_scalar("fidelity", fidelity,
                                                 global_step)

                            logger.info("Fidelity: {0}".format(fidelity))
                        dist.barrier()

                    learning_rate_scalar = scheduler.get_lr()[0]
                    avg_acc = sum(acc_accum) * 1.0 / len(acc_accum)
                    logger.info("Train acc: {}".format(avg_acc))
                    if is_first_worker():
                        tb_writer.add_scalar("Training/learning_rate",
                                             learning_rate_scalar, global_step)
                        tb_writer.add_scalar("Training/epoch", m_epoch,
                                             global_step)
                        tb_writer.add_scalar("Training/accuracy", avg_acc,
                                             global_step)
                        for key, value in tensorboard_scalars.items():
                            tb_writer.add_scalar(key,
                                                 value / args.logging_steps,
                                                 global_step)
                        logger.info(
                            json.dumps({
                                **tensorboard_scalars,
                                **{
                                    "learning_rate": learning_rate_scalar,
                                    "Accuracy": avg_acc,
                                    "step": global_step
                                }
                            }))

                    tensorboard_scalars = {}
                    dist.barrier()

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #4
def worker_fn(rank, world_size):
    setup(rank, world_size)

    weights_filename = "weights.pt"
    batch_size = 512
    epochs = 240
    warmup_epochs = 8
    use_mixed_precision = True

    batch_size = batch_size // world_size #batch size per worker

    #Data
    all_data = os.listdir(datapath_preprocessed)
    train_filenames = [p for p in all_data if re.match(r'^PGM_' + re.escape(dataset_name) + r'_train_(\d+)\.npz$', p) is not None]
    val_filenames = [p for p in all_data if re.match(r'^PGM_' + re.escape(dataset_name) + r'_val_(\d+)\.npz$', p) is not None]
    train_dataset = PgmDataset(train_filenames)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, pin_memory=False, sampler=train_sampler)#shuffle is done by the sampler
    val_dataloader = DataLoader(PgmDataset(val_filenames), batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)

    #Model
    device_ids = [rank]

    model = WReN(2).to(device_ids[0])#3-layer MLRN

    if weights_filename is not None and os.path.isfile("./" + weights_filename):
        model.load_state_dict(torch.load(weights_filename, map_location='cpu'))
        print('Weights loaded')
        cold_start = False
    else:
        print('No weights found')
        cold_start = True

    #Loss and optimizer
    final_lr = 2e-3

    def add_module_params_with_decay(module, weight_decay, param_groups):#adds parameters with decay unless they are bias parameters, which shouldn't receive decay
        group_with_decay = []
        group_without_decay = []
        for name, param in module.named_parameters():
            if not param.requires_grad: continue
            if name == 'bias' or name.endswith('bias'):
                group_without_decay.append(param)
            else:
                group_with_decay.append(param)
        param_groups.append({"params": group_with_decay, "weight_decay": weight_decay})
        param_groups.append({"params": group_without_decay})

    optimizer_param_groups = [
    ]

    add_module_params_with_decay(model.conv, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.post_cnn_linear, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.g, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.h, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.f, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.f_final, 2e-1, optimizer_param_groups)

    optimizer = Lamb(optimizer_param_groups, lr=final_lr)

    base_model = model
    if use_mixed_precision:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1") #Mixed Precision

    lossFunc = torch.nn.CrossEntropyLoss()
    softmax = torch.nn.Softmax(dim=1)

    #Parallel distributed model
    device = device_ids[0]
    torch.cuda.set_device(device)
    parallel_model = torch.nn.parallel.DistributedDataParallel(model, device_ids)

    if rank == 0:
        #accuracy logging
        sess = tf.Session()
        train_acc_placeholder = tf.placeholder(tf.float32, shape=())
        train_acc_summary = tf.summary.scalar('training_acc', train_acc_placeholder)
        val_acc_placeholder = tf.placeholder(tf.float32, shape=())
        val_acc_summary = tf.summary.scalar('validation_acc', val_acc_placeholder)
        writer = tf.summary.FileWriter("log", sess.graph)

    #training loop
    acc = []
    global_step = 0
    for epoch in range(epochs): 
        train_sampler.set_epoch(epoch) 

        # Validation
        val_acc = []
        parallel_model.eval()
        with torch.no_grad():
            for i, (local_batch, local_labels) in enumerate(val_dataloader):
                local_batch, targets = local_batch.to(device), local_labels.to(device)

                #answer = model(local_batch.type(torch.float32))
                answer, _ = parallel_model(local_batch.type(torch.float32))

                #Calc accuracy
                answerSoftmax = softmax(answer)
                maxIndex = answerSoftmax.argmax(dim=1)

                correct = maxIndex.eq(targets)
                accuracy = correct.type(dtype=torch.float16).mean(dim=0)
                val_acc.append(accuracy)

                if i % 50 == 0 and rank == 0:
                    print("batch " + str(i))

        total_val_acc = sum(val_acc) / len(val_acc)
        print('Validation accuracy: ' + str(total_val_acc.item()))
        if rank == 0:
            summary = sess.run(val_acc_summary, feed_dict={val_acc_placeholder: total_val_acc.item()})
            writer.add_summary(summary, global_step=global_step)

        # Training
        parallel_model.train()
        for i, (local_batch, local_labels) in enumerate(train_dataloader):
            global_step = global_step + 1

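            # Warmup ramp: the denominator below approximates optimizer steps taken over
            # warmup_epochs, so the lr reaches final_lr by the end of the warmup period.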
            if cold_start and epoch < warmup_epochs:#linear scaling of the lr for warmup during the first few epochs
                lr = final_lr * global_step / (warmup_epochs*len(train_dataset) / (batch_size * world_size))
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            local_batch, targets = local_batch.to(device_ids[0]), local_labels.to(device_ids[0])

            optimizer.zero_grad()
            answer, activation_loss = parallel_model(local_batch.type(torch.float32))

            loss = lossFunc(answer, targets) + activation_loss * 2e-3

            #Calc accuracy
            answerSoftmax = softmax(answer)
            maxIndex = answerSoftmax.argmax(dim=1)

            correct = maxIndex.eq(targets)
            accuracy = correct.type(dtype=torch.float16).mean(dim=0)
            acc.append(accuracy)
            
            #Training step
            if use_mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss: #Mixed precision
                    scaled_loss.backward()
            else:
                loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), 1e1)

            optimizer.step()

            if i % 50 == 0 and rank == 0:
                print("epoch " + str(epoch) + " batch " + str(i))
                print("loss", loss)
                print("activation loss", activation_loss)
                print(grad_norm)

            #logging and saving weights
            if i % 1000 == 999:
                trainAcc = sum(acc) / len(acc)
                acc = []
                print('Training accuracy: ' + str(trainAcc.item()))
                if rank == 0:
                    if weights_filename is not None:
                        torch.save(base_model.state_dict(), weights_filename)
                        print('Weights saved')

                    summary = sess.run(train_acc_summary, feed_dict={train_acc_placeholder: trainAcc.item()})
                    writer.add_summary(summary, global_step=global_step)  

        if cold_start and weights_filename is not None and epoch % 10 == 0 and rank == 0:
            torch.save(base_model.state_dict(), weights_filename + "_cp" + str(epoch))
            print('Checkpoint saved')


    cleanup()
Example #5
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    #if args.local_rank in [-1, 0]:
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []

    no_decay = ["bias", "w", "b", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    #logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    eval_cnt = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        if "-" in args.model_name_or_path:
            try:
                global_step = int(
                    args.model_name_or_path.split("-")[-1].split("/")[0])
            except Exception:
                global_step = 0

        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    eval_path = os.path.join(args.data_dir, "eval_full.tsv")
    eval_cfg = L1InputConfig("eval",
                             model,
                             None,
                             eval_path,
                             args.configObj.chunk_cfg,
                             qid=0,
                             docid=1,
                             query=4,
                             title=5,
                             anchor=6,
                             url=7,
                             click=8,
                             desc=9,
                             rating=10,
                             market=12,
                             lang=13)

    def eval_fn(line, i):
        return L1_process_fn(line, i, tokenizer, args, eval_cfg.map,
                             eval_cfg.chunk_cfg)

    model.eval()
    ideal_path = os.path.join(args.data_dir, "ideal_map_UN.tsv")
    is_first_eval = (eval_cnt == 0)

    args.global_step = global_step
    # fidelity = eval_fidelity(args, model, eval_fn, eval_cfg.path, ideal_path, args.data_cache_dir, is_first_eval)
    # eval_cnt+=1
    # print("Fidelity: {0}".format(fidelity))
    # if is_first_worker():
    #     tb_writer.add_scalar("fidelity", fidelity, global_step)

    best_checkpoints = []
    acc_accum = []
    scheduler = None
    while global_step < args.max_steps:
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)

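            # Block here until usable ANN data exists: on any failure (or an empty file)
            # sleep five minutes and poll args.ann_dir again before continuing.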
            while ann_no == -1 or (ann_path is not None
                                   and ann_no != last_ann_no):
                try:
                    logger.info("Training on new add data at %s", ann_path)
                    with open(ann_path, 'r') as f:
                        ann_training_data = f.readlines()
                    dev_ndcg = ndcg_json['ndcg']
                    ann_checkpoint_path = ndcg_json['checkpoint']
                    ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                    aligned_size = (len(ann_training_data) //
                                    args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]

                    logger.info("Total ann queries: %d",
                                len(ann_training_data))
                    if len(ann_training_data) == 0:
                        time.sleep(300)
                        continue

                    train_dataset = SimplifiedStreamingDataset(
                        ann_training_data,
                        args.configObj.process_fn(args, query_cache,
                                                  passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset,
                        batch_size=args.train_batch_size,
                        num_workers=1)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if scheduler is None:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=args.max_steps)

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        # add ndcg at checkpoint step used instead of current step
                        # tb_writer.add_scalar("dev_ndcg", dev_ndcg, ann_checkpoint_no)
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
                    break
                except Exception:
                    if is_first_worker():
                        print("wait")
                    time.sleep(300)

                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device).long() for t in batch)
        step += 1
        model.train()
        # if args.triplet:
        #     inputs = {"query_ids": batch[0].long(), "attention_mask_q": batch[1].long(),
        #                 "input_ids_a": batch[3].long(), "attention_mask_a": batch[4].long(),
        #                 "input_ids_b": batch[6].long(), "attention_mask_b": batch[7].long()}
        # else:
        #     inputs = {"input_ids_a": batch[0].long(), "attention_mask_a": batch[1].long(),
        #                 "input_ids_b": batch[3].long(), "attention_mask_b": batch[4].long(),
        #                 "labels": batch[6]}

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(*batch)
        else:
            with model.no_sync():
                outputs = model(*batch)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
        acc = outputs[2]
        acc_accum.append(acc.item())

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            # print("w grad:", model.module.w.grad)
            # print("w:", model.module.w)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                if global_step % (args.logging_steps_per_eval *
                                  args.logging_steps) == 0:
                    print("Train acc:", sum(acc_accum) * 1.0 / len(acc_accum))
                    acc_accum = []
                    model.eval()
                    is_first_eval = (eval_cnt == 0)
                    args.global_step = global_step
                    fidelity = eval_fidelity(args, model, eval_fn,
                                             eval_cfg.path, ideal_path,
                                             args.data_cache_dir,
                                             is_first_eval)
                    eval_cnt += 1

                    if is_first_worker():
                        if len(best_checkpoints) < 10:
                            save_checkpoint(args, global_step, model,
                                            tokenizer, optimizer, scheduler)
                            best_checkpoints.append((global_step, fidelity))
                        else:
                            worst_checkpoint = sorted(best_checkpoints,
                                                      key=lambda x: x[1])[0]
                            if fidelity > worst_checkpoint[1]:
                                save_checkpoint(args, global_step, model,
                                                tokenizer, optimizer,
                                                scheduler)
                                worst_cp_path = os.path.join(
                                    args.output_dir, "checkpoint-{}".format(
                                        str(worst_checkpoint[0])))
                                shutil.rmtree(worst_cp_path)
                                best_checkpoints.remove(worst_checkpoint)
                                best_checkpoints.append(
                                    (global_step, fidelity))
                            else:
                                print("Fidelity not in top 10!")
                            assert len(best_checkpoints) == 10

                        save_checkpoint(args, -1, model, tokenizer)
                        print("Fidelity: {0}".format(fidelity))
                        logs["fidelity"] = fidelity
                    if args.local_rank != -1:
                        dist.barrier()
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
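
The training step above relies on DistributedDataParallel's no_sync() so that gradients are all-reduced only on the micro-batch that actually calls optimizer.step(). Below is a minimal sketch of that pattern, with a generic model and dataloader standing in for the ANN-specific plumbing above; the model is assumed to be wrapped in DDP and to return a tuple whose first element is the loss.

import contextlib
import torch

def accumulation_loop(model, dataloader, optimizer, scheduler,
                      accumulation_steps, max_grad_norm=1.0):
    """Sketch: accumulate gradients locally and sync/step only every
    `accumulation_steps` micro-batches (assumes `model` is a
    torch.nn.parallel.DistributedDataParallel instance returning (loss, ...))."""
    model.train()
    for step, batch in enumerate(dataloader, start=1):
        sync_now = step % accumulation_steps == 0
        # no_sync() suppresses DDP's gradient all-reduce for this backward pass.
        ctx = contextlib.nullcontext() if sync_now else model.no_sync()
        with ctx:
            loss = model(*batch)[0] / accumulation_steps
            loss.backward()
        if sync_now:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()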
Example no. 6
0
    model.train()
    for img, labels in dataloader_train:
        #img, labels = batch
        img, labels = img.to(device), labels.to(device)
        #print(labels[0])
        #labelsmat = F.one_hot(labels, num_classes=10).to(device)
        output = model(img)
        #loss = torch.sum((output-labelsmat)**2)
        loss = F.cross_entropy(output, labels)
        acc_train += torch.sum(torch.argmax(output, dim=-1) == labels)  # .item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.detach()

    # Testing
    model.eval()
    with torch.no_grad():  # no gradients are needed during evaluation
        for img, labels in dataloader_test:
            img, labels = img.to(device), labels.to(device)
            output = model(img)
            acc_test += torch.sum(torch.argmax(output, dim=-1) == labels)

    acc_train = acc_train.item() / n_data_train
    accliste_train[epoch - start_epoch] = acc_train
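
The snippet above is the body of a per-epoch loop; the counters it updates (acc_train, acc_test, total_loss) and the bookkeeping variables (n_data_train, accliste_train, epoch, start_epoch, device) are assumed to be set up by surrounding code that is not shown. A hypothetical scaffold consistent with those names, reusing the model and dataloaders from the snippet, might look like this:

import torch

# Hypothetical setup; the names and sizes below are assumptions, not part of the original.
n_epochs, start_epoch = 10, 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_data_train = len(dataloader_train.dataset)
accliste_train = torch.zeros(n_epochs - start_epoch)

for epoch in range(start_epoch, n_epochs):
    acc_train, acc_test, total_loss = 0, 0, 0.0  # promoted to tensors by the first +=
    # ... training and testing body as in the example above ...
    acc_train = acc_train.item() / n_data_train
    accliste_train[epoch - start_epoch] = acc_train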
Example no. 7
0
def train(args, train_dataset, model_d, model_g, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_d_grouped_parameters = [
        {
            "params": [p for n, p in model_d.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model_d.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
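    # Standard transformers recipe: biases and LayerNorm weights are excluded from
    # weight decay; the same grouping is repeated for the generator below.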
    # optimizer_d = AdamW(optimizer_d_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    optimizer_d = Lamb(optimizer_d_grouped_parameters,
                       lr=args.learning_rate,
                       betas=(0.9, 0.999),
                       eps=1e-6)
    scheduler_d = get_linear_schedule_with_warmup(
        optimizer_d,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    optimizer_g_grouped_parameters = [
        {
            "params": [p for n, p in model_g.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model_g.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    # optimizer_g = AdamW(optimizer_g_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    optimizer_g = Lamb(optimizer_g_grouped_parameters,
                       lr=args.learning_rate,
                       betas=(0.9, 0.999),
                       eps=1e-6)
    scheduler_g = get_linear_schedule_with_warmup(
        optimizer_g,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer_d.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler_d.pt")):
        # Load in optimizer and scheduler states
        optimizer_d.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "optimizer_d.pt")))
        scheduler_d.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "scheduler_d.pt")))
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer_g.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler_g.pt")):
        # Load in optimizer and scheduler states
        optimizer_g.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "optimizer_g.pt")))
        scheduler_g.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path,
                                    "scheduler_g.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model_d, optimizer_d = amp.initialize(model_d,
                                              optimizer_d,
                                              opt_level=args.fp16_opt_level)
        model_g, optimizer_g = amp.initialize(model_g,
                                              optimizer_g,
                                              opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_d = torch.nn.DataParallel(model_d)
        model_g = torch.nn.DataParallel(model_g)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_d = torch.nn.parallel.DistributedDataParallel(
            model_d,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )
        model_g = torch.nn.parallel.DistributedDataParallel(
            model_g,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    model_to_resize_d = model_d.module if hasattr(
        model_d,
        "module") else model_d  # Take care of distributed/parallel training
    # model_to_resize_d.resize_token_embeddings(len(tokenizer))
    model_to_resize_g = model_g.module if hasattr(
        model_g,
        "module") else model_g  # Take care of distributed/parallel training
    # model_to_resize_g.resize_token_embeddings(len(tokenizer))

    # model_to_resize_d.bert.embeddings = model_to_resize_g.bert.embeddings

    tr_loss, logging_loss = 0.0, 0.0
    tr_loss_d, logging_loss_d = 0.0, 0.0
    tr_loss_g, logging_loss_g = 0.0, 0.0
    model_d.zero_grad()
    model_g.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model_d.train()
            model_g.train()
            # batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            # outputs = model(**inputs)
            # loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            masked_input_ids, mask_labels = mask_tokens(
                inputs['input_ids'], tokenizer, args)
            outputs_g = model_g(
                input_ids=masked_input_ids.to(args.device),
                masked_lm_labels=mask_labels.to(args.device),
                attention_mask=inputs['attention_mask'].to(args.device),
                token_type_ids=inputs['token_type_ids'].to(args.device))
            masked_lm_loss, prediction_scores_g = outputs_g[0], outputs_g[1]

            prediction_g = prediction_scores_g.max(dim=-1)[1].cpu()
            acc_g = (prediction_g[mask_labels >= 0] == mask_labels[
                mask_labels >= 0]).float().mean().item()

            prediction_probs_g = F.softmax(prediction_scores_g, dim=-1).cpu()
            bsz, seq_len, vocab_size = prediction_probs_g.size()
            prediction_samples_g = torch.multinomial(prediction_probs_g.view(
                -1, vocab_size),
                                                     num_samples=1)
            prediction_samples_g = prediction_samples_g.view(bsz, seq_len)
            input_ids_replace = inputs['input_ids'].clone()
            input_ids_replace[mask_labels >= 0] = prediction_samples_g[
                mask_labels >= 0]
            labels_d = input_ids_replace.eq(inputs['input_ids']).long()

            special_tokens_mask = [
                tokenizer.get_special_tokens_mask(
                    val, already_has_special_tokens=True)
                for val in inputs['input_ids'].tolist()
            ]
            labels_d.masked_fill_(torch.tensor(special_tokens_mask,
                                               dtype=torch.bool),
                                  value=-100)
            padding_mask = inputs['input_ids'].eq(tokenizer.pad_token_id)
            labels_d.masked_fill_(padding_mask, value=-100)
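            # At this point labels_d is 1 where the generator kept the original token,
            # 0 where it substituted a different one, and -100 on special-token and
            # padding positions so the discriminator loss ignores them.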

            labels_d_ones = labels_d[labels_d >= 0].float().mean().item()
            acc_replace = 1 - ((labels_d == 0).sum().float() /
                               (mask_labels >= 0).sum().float()).item()

            outputs_d = model_d(
                input_ids=input_ids_replace.to(args.device),
                attention_mask=inputs['attention_mask'].to(args.device),
                token_type_ids=inputs['token_type_ids'].to(args.device),
                labels=labels_d.to(args.device))
            loss_d, prediction_scores_d = outputs_d[0], outputs_d[1]
            prediction_d = prediction_scores_d.max(dim=-1)[1].cpu()
            acc_d = (prediction_d[labels_d >= 0] == labels_d[labels_d >= 0]
                     ).float().mean().item()
            acc_d_0 = (prediction_d[labels_d == 0] == labels_d[labels_d == 0]
                       ).float().mean().item()
            acc_d_1 = (prediction_d[labels_d == 1] == labels_d[labels_d == 1]
                       ).float().mean().item()

            if args.n_gpu > 1:
                loss_d = loss_d.mean()  # mean() to average on multi-gpu parallel training
                masked_lm_loss = masked_lm_loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss_d = loss_d / args.gradient_accumulation_steps
                masked_lm_loss = masked_lm_loss / args.gradient_accumulation_steps

            lambd = 50
            loss = loss_d * lambd + masked_lm_loss
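            # ELECTRA-style joint objective: the replaced-token-detection loss is weighted
            # by lambd against the generator's masked-LM loss (the ELECTRA paper uses a
            # discriminator weight of 50).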
            if args.fp16:
                loss_d = loss_d * lambd
                with amp.scale_loss(loss_d, optimizer_d) as scaled_loss_d:
                    scaled_loss_d.backward()
                with amp.scale_loss(masked_lm_loss,
                                    optimizer_g) as scaled_loss_g:
                    scaled_loss_g.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            tr_loss_d += loss_d.item()
            tr_loss_g += masked_lm_loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer_d), args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer_g), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_d.parameters(),
                                                   args.max_grad_norm)
                    torch.nn.utils.clip_grad_norm_(model_g.parameters(),
                                                   args.max_grad_norm)

                optimizer_d.step()
                scheduler_d.step()  # Update learning rate schedule
                model_d.zero_grad()
                optimizer_g.step()
                scheduler_g.step()  # Update learning rate schedule
                model_g.zero_grad()

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    # if (
                    #     args.local_rank == -1 and args.evaluate_during_training
                    # ):  # Only evaluate when single GPU otherwise metrics may not average well
                    #     results = evaluate(args, model, tokenizer)
                    #     for key, value in results.items():
                    #         eval_key = "eval_{}".format(key)
                    #         logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    loss_scalar_d = (tr_loss_d -
                                     logging_loss_d) / args.logging_steps
                    loss_scalar_g = (tr_loss_g -
                                     logging_loss_g) / args.logging_steps
                    learning_rate_scalar_d = scheduler_d.get_lr()[0]
                    learning_rate_scalar_g = scheduler_g.get_lr()[0]
                    logs["learning_rate_d"] = learning_rate_scalar_d
                    logs["learning_rate_g"] = learning_rate_scalar_g
                    logs["loss"] = loss_scalar
                    logs["loss_d"] = loss_scalar_d
                    logs["loss_g"] = loss_scalar_g
                    logs["acc_repalce"] = acc_replace
                    logs["acc_d"] = acc_d
                    logs["acc_d_0"] = acc_d_0
                    logs["acc_d_1"] = acc_d_1
                    logs["acc_g"] = acc_g
                    logs["labels_d_ones"] = labels_d_ones
                    logs["masked_ratio"] = (mask_labels >= 0).float().sum(
                    ).item() / (labels_d >= 0).sum().float().item()
                    logging_loss = tr_loss
                    logging_loss_d = tr_loss_d
                    logging_loss_g = tr_loss_g

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                # print(args.save_steps)
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    output_dir_d = os.path.join(
                        output_dir, "checkpoint-d-{}".format(global_step))
                    output_dir_g = os.path.join(
                        output_dir, "checkpoint-g-{}".format(global_step))
                    if not os.path.exists(output_dir_d):
                        os.makedirs(output_dir_d)
                    if not os.path.exists(output_dir_g):
                        os.makedirs(output_dir_g)
                    model_to_save_d = (
                        model_d.module if hasattr(model_d, "module") else
                        model_d)  # Take care of distributed/parallel training
                    model_to_save_g = (
                        model_g.module if hasattr(model_g, "module") else
                        model_g)  # Take care of distributed/parallel training
                    model_to_save_d.save_pretrained(output_dir_d)
                    model_to_save_g.save_pretrained(output_dir_g)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer_d.state_dict(),
                               os.path.join(output_dir_d, "optimizer_d.pt"))
                    torch.save(scheduler_d.state_dict(),
                               os.path.join(output_dir_d, "scheduler_d.pt"))
                    torch.save(optimizer_g.state_dict(),
                               os.path.join(output_dir_d, "optimizer_g.pt"))
                    torch.save(scheduler_g.state_dict(),
                               os.path.join(output_dir_d, "scheduler_g.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
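
The mask_tokens helper used by the generator step above is not shown in this example. Given how it is called (mask_labels holds the original token ids at the sampled masked positions and -100 everywhere else, and the caller keeps using inputs['input_ids'] as the clean reference), a sketch modeled on the masking recipe from the Hugging Face language-modeling examples, assuming an args.mlm_probability field, could look like this:

import torch

def mask_tokens(inputs, tokenizer, args):
    """Sketch of an MLM masking helper consistent with the usage above:
    returns (masked_inputs, labels) with labels == -100 on unmasked positions."""
    labels = inputs.clone()
    inputs = inputs.clone()  # do not modify the caller's clean input_ids in place

    # Sample positions to mask with probability args.mlm_probability (assumed, e.g. 0.15),
    # never masking special tokens or padding.
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
        for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer.pad_token_id is not None:
        probability_matrix.masked_fill_(labels.eq(tokenizer.pad_token_id), value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # compute MLM loss only on masked positions

    # 80% of masked positions -> [MASK], 10% -> random token, 10% -> left unchanged.
    indices_replaced = (torch.bernoulli(torch.full(labels.shape, 0.8)).bool()
                        & masked_indices)
    inputs[indices_replaced] = tokenizer.mask_token_id
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]
    return inputs, labels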