def get_model(args, version=None):
    """Build the BERT mixture-of-experts model and wrap it for training.

    Args:
        args: parsed training arguments; provides the model hyper-parameters
            passed to the constructor plus ``deepspeed`` and ``fp16``.
        version: ``None`` selects ``BertMixtureModel``; ``"v0"`` selects
            ``BertMixtureModel_v0``.

    Returns:
        The model moved to the current CUDA device, wrapped in
        ``FP16_Module`` when ``args.fp16`` is set, and finally in ``DDP``.

    Raises:
        ValueError: if ``version`` is neither ``None`` nor ``"v0"``.
    """
    # Resolve the model class first so an unsupported version fails with a
    # clear error instead of an UnboundLocalError on ``model`` further down.
    if version is None:
        model_cls = BertMixtureModel
    elif version == "v0":
        model_cls = BertMixtureModel_v0
    else:
        raise ValueError(
            "unknown model version {!r}; expected None or 'v0'".format(version))

    print_rank_0('building Bert model ...')
    # Both versions share the exact same constructor arguments.
    model = model_cls(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      layernorm_epsilon=args.layernorm_epsilon,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      num_experts=args.num_experts,
                      type_vocab_size=2)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full
    # precision, halve the weights before moving them to the GPU.
    if args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
コード例 #2
0
def get_model(args):
    """Build the GPT-3 model and wrap it for (mixed-precision) distributed training.

    Args:
        args: parsed training arguments; provides the model hyper-parameters,
            ``model_parallel_size``, sparse-attention options,
            ``load_huggingface`` and the ``deepspeed``/``fp16`` flags.

    Returns:
        The model on the current CUDA device, wrapped in ``FP16_Module`` when
        ``args.fp16`` is set, and finally in ``DDP``.

    Raises:
        ValueError: if ``args.num_attention_heads`` is not divisible by
            ``args.model_parallel_size``.
    """
    # Validate with an explicit raise: an ``assert`` would be stripped under
    # ``python -O`` and silently allow an uneven split of heads across ranks.
    if args.num_attention_heads % args.model_parallel_size != 0:
        raise ValueError(
            'num_attention_heads ({}) must be divisible by model_parallel_size '
            '({})'.format(args.num_attention_heads, args.model_parallel_size))

    print_rank_0('building GPT3 model ...')
    num_local_heads = args.num_attention_heads // args.model_parallel_size
    deepspeed_sparsity_config = None
    if DEEPSPEED_WRAP and args.deepspeed:
        deepspeed_sparsity_config = get_sparse_attention_config(args, num_local_heads)
    if deepspeed_sparsity_config is not None:
        print_rank_0(f"Use sparse attention with mode {args.sparse_mode}")
    model = GPT3Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      deepspeed_sparsity_config=deepspeed_sparsity_config,
                      sparse_mode=args.sparse_mode)

    if args.load_huggingface is not None:
        model = load_huggingface_model(model, args.load_huggingface, args.huggingface_double_pos_embeddings)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full
    # precision, halve the weights before moving them to the GPU.
    if DEEPSPEED_WRAP and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
コード例 #3
0
def DataParallel(module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 chunk_sizes=None):
    """Place *module* on *device_ids* and wrap it for data-parallel execution.

    When ``chunk_sizes`` is ``None``, or every entry equals the first one, a
    plain ``DDP`` wrapper is returned. Otherwise ``_DataParallel`` is used
    and receives the explicit ``chunk_sizes``.
    """
    torch.cuda.set_device(device_ids)
    module = module.cuda(device_ids)

    uniform_chunks = chunk_sizes is None or all(
        size == chunk_sizes[0] for size in chunk_sizes[1:])
    if uniform_chunks:
        return DDP(module, device_ids, output_device, dim)
    return _DataParallel(module, device_ids, output_device, dim, chunk_sizes)
コード例 #4
0
def get_model(args):
    """Build the ruGPT2048 model and wrap it for (mixed-precision) distributed training.

    Args:
        args: parsed training arguments; provides the model hyper-parameters,
            ``use_sparse`` and the ``deepspeed``/``fp16`` flags.

    Returns:
        The model on the current CUDA device, wrapped in ``FP16_Module`` when
        ``args.fp16`` is set, and finally in ``DDP``.

    Raises:
        NotImplementedError: if DeepSpeed fp16 is requested, which this build
            does not support.
    """
    # DeepSpeed is not installed in this build. Fail fast before spending
    # time and memory constructing the model. ``NotImplemented`` is a
    # constant rather than an exception class, so the exception type must
    # be ``NotImplementedError``.
    if args.deepspeed and args.fp16:
        raise NotImplementedError("No installed deep speed")

    print_rank_0('building ruGPT2048 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      use_sparse=args.use_sparse)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())),
              flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        i = torch.cuda.current_device()
        model = DDP(model,
                    device_ids=[i],
                    output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        model = DDP(model)

    return model
コード例 #5
0
def model_builder():
    """Build the video classifier selected by the module-level ``args``.

    Reads ``args.model`` (``'i3d'``, ``'r2plus1d'`` or ``'w3d'``),
    ``args.mode`` (``'flow'`` selects the 2-channel I3D; anything else the
    3-channel RGB I3D), ``args.resume`` (optional checkpoint path) and
    ``args.distributed``.

    Returns:
        The model on the default CUDA device, wrapped in ``DDP`` when
        ``args.distributed`` is true and in ``nn.DataParallel`` otherwise.

    Raises:
        ValueError: if ``args.model`` is not one of the supported names.
    """

    def _load_pretrained(net, path):
        # Load pretrained I3D weights non-strictly, skipping every parameter
        # whose name contains 'logits' (the classifier head; presumably its
        # class count differs from the 7 used here).
        pretrained = torch.load(path)
        net.load_state_dict(
            {k: v for k, v in pretrained.items() if 'logits' not in k},
            strict=False)

    # setup the model
    if args.model == 'i3d':
        if args.mode == 'flow':
            model = InceptionI3d(num_classes=7, in_channels=2)
            _load_pretrained(model, 'models/flow_imagenet.pt')
        else:
            model = InceptionI3d(num_classes=7,
                                 in_channels=3,
                                 dropout_keep_prob=0.5)
            _load_pretrained(model, 'models/rgb_imagenet.pt')
    elif args.model == 'r2plus1d':
        model = R2Plus1DClassifier(num_classes=7)
    elif args.model == 'w3d':
        model = W3D(num_classes=7)
    else:
        # Without this branch ``model`` would be unbound below.
        raise ValueError("unsupported model '{}'".format(args.model))

    if args.resume is not None:
        # Use a local scope to avoid dangling references to the checkpoint
        # dict once loading is done.
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume, map_location=lambda storage, loc: storage)
                model.load_state_dict(checkpoint)
                print("=> loaded checkpoint '{}' ".format(args.resume))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    model = model.cuda()

    if args.distributed:
        model = DDP(model)
    else:
        model = nn.DataParallel(model)

    return model
コード例 #6
0
def get_model(args):
    """Build the BERT model and prepare it for (mixed-precision) distributed training.

    Args:
        args: parsed arguments; reads ``fp16`` and, when it is set, the
            ``fp32_embedding``, ``fp32_tokentypes`` and ``fp32_layernorm``
            flags that keep selected sub-modules in fp32.

    Returns:
        The model on the current CUDA device, optionally wrapped in
        ``FP16_Module``, and finally wrapped in ``DDP``.
    """
    print_rank_0('building BERT model ...')
    model = BertModel(args)

    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        param_count = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, param_count),
              flush=True)

    # Move parameters to this process's GPU.
    model.cuda(torch.cuda.current_device())

    if args.fp16:
        model = FP16_Module(model)
        # Optionally keep embedding tables in fp32 inside the fp16 wrapper.
        if args.fp32_embedding or args.fp32_tokentypes:
            embeddings = model.module.model.bert.embeddings
            if args.fp32_embedding:
                embeddings.word_embeddings.float()
                embeddings.position_embeddings.float()
                embeddings.token_type_embeddings.float()
            if args.fp32_tokentypes:
                embeddings.token_type_embeddings.float()
        # Optionally keep every module whose name contains 'LayerNorm' in fp32.
        if args.fp32_layernorm:
            for module_name, submodule in model.named_modules():
                if 'LayerNorm' in module_name:
                    submodule.float()

    # Distributed wrapper: torch DDP bound to the data-parallel group, or
    # the default DDP implementation otherwise.
    if not USE_TORCH_DDP:
        return DDP(model)
    device_index = torch.cuda.current_device()
    return DDP(model,
               device_ids=[device_index],
               output_device=device_index,
               process_group=mpu.get_data_parallel_group())
コード例 #7
0
def get_model(args, config, do_fp16=False):
    """Build a GPT-2 model from *config* and wrap it for distributed training.

    Args:
        args: parsed arguments; reads the activation-checkpointing options
            and ``deepspeed``.
        config: keyword arguments forwarded verbatim to ``GPT2Model``.
        do_fp16: when true the model is wrapped in ``FP16_Module``; under
            DeepSpeed it is additionally halved before moving to the GPU.

    Returns:
        The model on the current CUDA device, wrapped in ``DDP``.
    """
    print_rank_0('building GPT2 model ...')
    model = GPT2Model(**config,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True)

    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        n_params = sum(p.nelement() for p in model.parameters())
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, n_params),
              flush=True)

    # Halving first keeps models too large for fp32 GPU memory from OOMing
    # during the device transfer below.
    if args.deepspeed and do_fp16:
        model.half()

    model.cuda(torch.cuda.current_device())

    if do_fp16:
        model = FP16_Module(model)

    if not USE_TORCH_DDP:
        return DDP(model)
    device_index = torch.cuda.current_device()
    return DDP(model,
               device_ids=[device_index],
               output_device=device_index,
               process_group=mpu.get_data_parallel_group())
コード例 #8
0
ファイル: controller.py プロジェクト: cleemesser/hetseq
 def model(self):
     """Return the training model, building its DDP wrapper lazily.

     The result is cached in ``self._wrapped_model``. A ``DDP`` wrapper is
     created only when ``args.distributed_world_size > 1`` and BMUF is not
     in use; otherwise the raw ``self._model`` is returned unchanged.
     """
     if self._wrapped_model is not None:
         return self._wrapped_model

     cfg = self.args
     wants_ddp = cfg.distributed_world_size > 1 and not cfg.use_bmuf
     if wants_ddp:
         self._wrapped_model = DDP(
             module=self._model,
             device_ids=[cfg.device_id],
             output_device=cfg.device_id,
             broadcast_buffers=False,
             bucket_cap_mb=cfg.bucket_cap_mb,
             check_reduction=True,
             find_unused_parameters=cfg.find_unused_parameters,
         )
     else:
         self._wrapped_model = self._model
     return self._wrapped_model
コード例 #9
0
ファイル: apex.py プロジェクト: banctilrobitaille/kerosene
    def initialize(self, amp_id: int, num_losses: int, use_amp: bool,
                   amp_opt_level: str, device: torch.device):
        """Configure mixed precision and multi-GPU wrapping for the model.

        Args:
            amp_id: identifier stored as ``self._amp_id``.
            num_losses: number of losses passed to ``amp.initialize``.
            use_amp: request Apex AMP; honoured only when Apex is installed.
            amp_opt_level: Apex AMP optimisation level passed as
                ``opt_level``.
            device: device used as the single ``device_ids`` entry when
                falling back to PyTorch ``DDP``.
        """
        self._amp_id = amp_id
        self._use_amp = use_amp

        if APEX_AVAILABLE and self._use_amp:
            self._model, self._optimizer = amp.initialize(
                self._model,
                self._optimizer,
                opt_level=amp_opt_level,
                num_losses=num_losses)
            if on_multiple_gpus(get_devices()):
                self._model = ApexDDP(self._model, delay_allreduce=True)
        elif on_multiple_gpus(get_devices()):
            # Either Apex is missing or AMP is disabled. In both cases use
            # PyTorch DDP so multi-GPU runs are still data-parallel.
            # Previously, Apex-installed-but-AMP-off left the model unwrapped.
            self._model = DDP(self._model, device_ids=[device])
コード例 #10
0
ファイル: base_task.py プロジェクト: mtli/llcv
    def __init__(self, args, loader, is_train):
        """Set up data, model and device placement, plus optimisation when training.

        Args:
            args: parsed options (device, CUDA usage, learning-rate settings,
                experiment directory, optimizer name, ...).
            loader: data loader; its ``dataset`` is used to build the model.
            is_train: whether this task trains; if so the optimizer and LR
                schedule are created.
        """
        self.loader = loader
        self.dataset = loader.dataset
        self.is_train = is_train
        self.device = args.device
        self.gather = False
        self.gpu_gather = args.gpu_gather
        self.resume_epoch = 0
        self.has_val_score = False
        self.exp_dir = args.exp_dir
        if is_train:
            self.last_lr = args.lr
            self.lr_update_per_epoch = args.lr_update_per_epoch

        self.model = build_model(args, self.dataset)
        logging.debug(str(self.model))
        n_params = sum(p.numel() for p in self.model.parameters())
        logging.debug(f'Total number of parameters: {n_params}')

        self.rank = dist_get_rank()
        if self.rank >= 0:
            # Distributed process: move to this process's device, then DDP.
            self.device = torch.cuda.current_device() if args.use_cuda else 'cpu'
            self.model = self.model.to(self.device)
            ddp_device_ids = [self.device] if args.use_cuda else None
            self.model = DDP(self.model, ddp_device_ids,
                             find_unused_parameters=True)
        elif args.use_cuda:
            # Single process: use DataParallel when several GPUs are visible.
            if torch.cuda.device_count() > 1:
                self.model = DP(self.model)
            self.model = self.model.to(self.device)
        self.output_device = args.device if args.gpu_gather else 'cpu'

        if is_train:
            logging.debug(
                f'Optimizer: {args.optim} with base learning rate {args.lr:.6g}'
            )
            self.set_optim(args)
            self.set_lr_schedule(args)

        self.auto_load(args)
コード例 #11
0
    def __init__(
        self,
        *,
        model,
        diffusion,
        data,
        batch_size,
        microbatch,
        lr,
        ema_rate,
        log_interval,
        save_interval,
        resume_checkpoint,
        use_fp16=False,
        fp16_scale_growth=1e-3,
        schedule_sampler=None,
        weight_decay=0.0,
        lr_anneal_steps=0,
    ):
        """Store the training configuration and set up optimizer, EMA and DDP.

        ``ema_rate`` is either a single float or a comma-separated string of
        rates; one EMA parameter set is kept per rate. A non-positive
        ``microbatch`` means the full ``batch_size`` is processed at once.
        """
        self.model = model
        self.diffusion = diffusion
        self.data = data
        self.batch_size = batch_size
        self.microbatch = batch_size if microbatch <= 0 else microbatch
        self.lr = lr
        if isinstance(ema_rate, float):
            self.ema_rate = [ema_rate]
        else:
            self.ema_rate = [float(rate) for rate in ema_rate.split(",")]
        self.log_interval = log_interval
        self.save_interval = save_interval
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps

        self.step = 0
        self.resume_step = 0
        self.global_batch = self.batch_size * dist.get_world_size()

        # Here both names refer to the same parameter list; ``_setup_fp16``
        # (defined elsewhere) may change master_params when fp16 is enabled.
        self.model_params = list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()
        if self.use_fp16:
            self._setup_fp16()

        self.opt = AdamW(self.master_params, lr=self.lr,
                         weight_decay=self.weight_decay)
        if not self.resume_step:
            # Fresh run: every EMA copy starts from the current weights.
            self.ema_params = [
                copy.deepcopy(self.master_params) for _ in self.ema_rate
            ]
        else:
            self._load_optimizer_state()
            # Model was resumed, either due to a restart or a checkpoint
            # being specified at the command line.
            self.ema_params = [
                self._load_ema_parameters(rate) for rate in self.ema_rate
            ]

        self.use_ddp = th.cuda.is_available()
        if self.use_ddp:
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            if dist.get_world_size() > 1:
                logger.warn(
                    "Distributed training requires CUDA. "
                    "Gradients will not be synchronized properly!"
                )
            self.ddp_model = self.model
コード例 #12
0
ファイル: train.py プロジェクト: mlcommons/hpc
def main(pargs):
    """Run DeepCAM training and validation with MLPerf logging and checkpoints.

    Args:
        pargs: parsed command-line arguments (wire-up method, batch sizes,
            LR schedule, data/output directories, checkpoint and save
            frequency, ...).
    """

    #init distributed training
    comm_local_group = comm.init(pargs.wireup_method,
                                 pargs.batchnorm_group_size)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()
    comm_local_size = comm.get_local_size()

    # set up logging
    pargs.logging_frequency = max(pargs.logging_frequency, 1)
    log_file = os.path.normpath(
        os.path.join(pargs.output_dir, "logs", pargs.run_tag + ".log"))
    logger = mll.mlperf_logger(log_file, "deepcam", "Umbrella Corp.")
    logger.log_start(key="init_start", sync=True)
    logger.log_event(key="cache_clear")

    #set seed
    seed = pargs.seed
    logger.log_event(key="seed", value=seed)

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        torch.cuda.set_device(device)
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    if comm_rank == 0:
        os.makedirs(output_dir, exist_ok=True)

    # logging of rank information
    logger.log_event(key="number_of_ranks", value=comm_size)
    logger.log_event(key="number_of_nodes",
                     value=(comm_size // comm_local_size))
    logger.log_event(key="accelerators_per_node", value=comm_local_size)

    # Logging hyperparameters
    logger.log_event(key="global_batch_size",
                     value=(pargs.local_batch_size * comm_size))
    logger.log_event(key="batchnorm_group_size",
                     value=pargs.batchnorm_group_size)
    logger.log_event(key="gradient_accumulation_frequency",
                     value=pargs.gradient_accumulation_frequency)
    logger.log_event(key="checkpoint", value=pargs.checkpoint)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank,
                                          process_group=comm_local_group)
    net.to(device)

    #select loss
    #some magic numbers
    loss_pow = -0.125
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    # extract loss
    criterion = losses.CELoss(class_weights).to(device)
    criterion = torch.jit.script(criterion)

    #select optimizer
    optimizer = oh.get_optimizer(pargs, net, logger)

    #restart from checkpoint if desired
    if pargs.checkpoint is not None:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Checkpoints are saved from the DDP wrapper (net_train below), whose
        # state_dict keys carry a "module." prefix. Strip it so the weights
        # load into the bare network. Unprefixed keys are kept as-is.
        prefix = "module."
        model_state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint['model'].items()
        }
        net.load_state_dict(model_state)
    else:
        start_step = 0
        start_epoch = 0

    #broadcast model and optimizer state
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    if dist.is_initialized():
        dist.broadcast(steptens, src=0)

    #unpack the bcasted tensor
    start_step = int(steptens.cpu().numpy()[0])
    start_epoch = int(steptens.cpu().numpy()[1])

    #select scheduler
    scheduler = None
    if pargs.lr_schedule:
        pargs.lr_schedule["lr_warmup_steps"] = pargs.lr_warmup_steps
        pargs.lr_schedule["lr_warmup_factor"] = pargs.lr_warmup_factor
        scheduler = oh.get_lr_schedule(pargs.start_lr,
                                       pargs.lr_schedule,
                                       optimizer,
                                       logger,
                                       last_step=start_step)

    # print parameters
    if comm_rank == 0:
        print(net)
        print("Total number of elements:",
              sum(p.numel() for p in net.parameters() if p.requires_grad))

    # get input shapes for the upcoming model preprocessing
    # NOTE(review): input_shape is not used further in this function.
    tshape, _ = get_datashapes(pargs, root_dir)
    input_shape = tuple([tshape[2], tshape[0], tshape[1]])

    #distributed model parameters
    bucket_cap_mb = 25
    if pargs.batchnorm_group_size > 1:
        bucket_cap_mb = 220

    # get stream, relevant for graph capture
    ddp_net = DDP(net,
                  device_ids=[device.index],
                  output_device=device.index,
                  find_unused_parameters=False,
                  broadcast_buffers=False,
                  bucket_cap_mb=bucket_cap_mb,
                  gradient_as_bucket_view=False)

    # get stats handler here
    bnstats_handler = bns.BatchNormStatsSynchronize(ddp_net,
                                                    reduction="mean",
                                                    inplace=True)

    # create handles
    net_validate = ddp_net
    net_train = ddp_net

    # Set up the data feeder
    train_loader, train_size, validation_loader, validation_size = get_dataloaders(
        pargs, root_dir, device, seed, comm_size, comm_rank)

    # log size of datasets
    logger.log_event(key="train_samples", value=train_size)
    logger.log_event(key="eval_samples", value=validation_size)

    # get start steps
    step = start_step
    epoch = start_epoch
    net_train.train()

    # start trining
    logger.log_end(key="init_stop", sync=True)
    logger.log_start(key="run_start", sync=True)

    # training loop
    while True:

        # start epoch
        logger.log_start(key="epoch_start",
                         metadata={
                             'epoch_num': epoch + 1,
                             'step_num': step
                         },
                         sync=True)

        train_loader.sampler.set_epoch(epoch)

        # training
        step = train_step(pargs, comm_rank, comm_size, device, step, epoch,
                          net_train, criterion, optimizer, scheduler,
                          train_loader, logger)

        # average BN stats
        bnstats_handler.synchronize()

        # validation
        stop_training = validate(pargs, comm_rank, comm_size, device, step,
                                 epoch, net_validate, criterion,
                                 validation_loader, logger)

        # log the epoch
        logger.log_end(key="epoch_stop",
                       metadata={
                           'epoch_num': epoch + 1,
                           'step_num': step
                       },
                       sync=True)
        epoch += 1

        #save model if desired
        if (pargs.save_frequency > 0) and (epoch % pargs.save_frequency == 0):
            logger.log_start(key="save_start",
                             metadata={
                                 'epoch_num': epoch + 1,
                                 'step_num': step
                             },
                             sync=True)
            if comm_rank == 0:
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'model': net_train.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                torch.save(
                    checkpoint,
                    os.path.join(
                        output_dir,
                        pargs.model_prefix + "_step_" + str(step) + ".cpt"))
            # save_start was logged (with sync=True) on every rank, so every
            # rank must reach save_stop as well. Logging it only on rank 0
            # leaves the ranks out of step.
            logger.log_end(key="save_stop",
                           metadata={
                               'epoch_num': epoch + 1,
                               'step_num': step
                           },
                           sync=True)

        # are we done?
        if (epoch >= pargs.max_epochs) or stop_training:
            break

    # run done
    logger.log_end(key="run_stop", sync=True, metadata={'status': 'success'})
コード例 #13
0
                                                    vocab_file=args.vocab_file)

# get the model: look up the model factory by name and instantiate it with
# the source/target fields of the dataloader and the parsed args.
model = get_model(args.model)(dataloader.SRC, dataloader.TRG, args)
watcher.info(model)
# Report the trainable parameter count with thousands separators.
watcher.info("total trainable parameters: {}".format(
    format(count_parameters(model), ',')))
watcher.info("Vocabulary size: {}/{}.".format(len(dataloader.SRC.vocab),
                                              len(dataloader.TRG.vocab)))

# use GPU: move the model to the current default CUDA device when one exists.
if torch.cuda.is_available():
    model.cuda()
# Wrap in DDP pinned to this process's device for distributed runs.
if args.distributed:
    model = DDP(model,
                device_ids=[args.device_id],
                output_device=args.device_id)
# NOTE(review): setup_pretrained_model receives the DDP-wrapped model when
# args.distributed is set; confirm it handles the "module." indirection.
model = setup_pretrained_model(args, model, watcher)

# start running
if args.mode == 'train':
    watcher.info('starting training')
    train_model(args,
                watcher,
                model,
                dataloader.train,
                dataloader.dev,
                decoding_path=None)

elif args.mode == 'test':
コード例 #14
0
def main(args):
    """
    Main function for training, evaluating, and checkpointing.

    Runs as one process of a DDP job. It builds the datasets, vocabulary and
    model, and optionally trains with early stopping, where rank 0 tracks
    evaluation history, checkpoints, and decides when to stop. Optionally it
    then writes predictions on rank 0.

    Args:
        args: `argparse` object.
    """
    import sys

    import torch.distributed as dist

    # Print arguments.
    print('\nusing arguments:')
    _print_arguments(args)
    print()

    # Check if GPU is available.
    if not args.use_gpu and torch.cuda.is_available():
        print('warning: GPU is available but args.use_gpu = False')
        print()

    local_rank = args.local_rank

    # Set up distributed process group; the returned rank doubles as this
    # process's CUDA device index below.
    rank = setup_dist(local_rank)

    # Set up datasets.
    train_dataset = QADataset(args, args.train_path)
    dev_dataset = QADataset(args, args.dev_path)

    # Create vocabulary and tokenizer.
    vocabulary = Vocabulary(train_dataset.samples, args.vocab_size)
    tokenizer = Tokenizer(vocabulary)
    for dataset in (train_dataset, dev_dataset):
        dataset.register_tokenizer(tokenizer)
    args.vocab_size = len(vocabulary)
    args.pad_token_id = tokenizer.pad_token_id
    print(f'vocab words = {len(vocabulary)}')

    # Print number of samples.
    print(f'train samples = {len(train_dataset)}')
    print(f'dev samples = {len(dev_dataset)}')
    print()

    # Select model.
    model = _select_model(args)

    num_pretrained = model.load_pretrained_embeddings(
        vocabulary, args.embedding_path
    )
    pct_pretrained = round(num_pretrained / len(vocabulary) * 100., 2)
    print(f'using pre-trained embeddings from \'{args.embedding_path}\'')
    print(
        f'initialized {num_pretrained}/{len(vocabulary)} '
        f'embeddings ({pct_pretrained}%)'
    )
    print()

    # Move to this process's device and wrap for distributed training.
    model = model.to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)

    if args.resume and args.model_path:
        # Checkpoints are written by rank 0 from cuda:0; remap to this rank.
        map_location = {"cuda:0": "cuda:{}".format(rank)}
        model.load_state_dict(torch.load(args.model_path, map_location=map_location))

    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'using model \'{args.model}\' ({params} params)')
    print(model)
    print()

    if args.do_train:
        # Track training statistics for checkpointing (rank 0 only).
        eval_history = []
        best_eval_loss = float('inf')

        # Begin training.
        for epoch in range(1, args.epochs + 1):
            # Perform training and evaluation steps.
            try:
                train_loss = train(args, epoch, model, train_dataset)
            except RuntimeError as err:
                # Usually an NCCL collective timeout. Also print the actual
                # error so other RuntimeErrors (e.g. CUDA OOM) are not hidden.
                print(f'NCCL Wait Timeout, rank: \'{args.local_rank}\' (exit)')
                print(f'error: {err}')
                sys.exit(1)
            eval_loss = evaluate(args, epoch, model, dev_dataset)

            should_stop = False
            # If the model's evaluation loss yields a global improvement,
            # checkpoint the model.
            if rank == 0:
                eval_history.append(eval_loss < best_eval_loss)
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    torch.save(model.state_dict(), args.model_path)

                print(
                    f'epoch = {epoch} | '
                    f'train loss = {train_loss:.6f} | '
                    f'eval loss = {eval_loss:.6f} | '
                    f"{'saving model!' if eval_history[-1] else ''}"
                )

                # If early stopping conditions are met, stop training.
                if _early_stop(args, eval_history):
                    suffix = 's' if args.early_stop > 1 else ''
                    print(
                        f'no improvement after {args.early_stop} epoch{suffix}. '
                        'early stopping...'
                    )
                    print()
                    should_stop = True

            # Only rank 0 decides when to stop, so share that decision. If
            # rank 0 left the loop alone, the other ranks would block forever
            # in the next DDP gradient all-reduce.
            if dist.is_available() and dist.is_initialized():
                stop_flag = torch.tensor([int(should_stop)], device=rank)
                dist.broadcast(stop_flag, src=0)
                should_stop = bool(stop_flag.item())
            if should_stop:
                cleanup_dist()
                break

    if args.do_test and rank == 0:
        # Write predictions to the output file. Use the printed command
        # below to obtain official EM/F1 metrics.
        write_predictions(args, model, dev_dataset)
        eval_cmd = (
            'python3 evaluate.py '
            f'--dataset_path {args.dev_path} '
            f'--output_path {args.output_path}'
        )
        print()
        print(f'predictions written to \'{args.output_path}\'')
        print(f'compute EM/F1 with: \'{eval_cmd}\'')
        print()
コード例 #15
0
def worker(gpu, npgpus_per_node, args):
    """Run one per-GPU process: set up distribution, build data/model, dispatch.

    Presumably launched once per local GPU (e.g. via a process spawner that
    passes the GPU index first) -- confirm against the caller.

    On entry ``args.local_rank`` and ``args.world_size`` are combined with
    ``npgpus_per_node`` to form the global rank and total world size, so they
    appear to be the node index and node count respectively.  Afterwards the
    function dispatches on ``args.mode``:

    * ``'train'``     -- call ``train_model`` on the train/dev loaders.
    * ``'test'``      -- decode the dev (or test, with ``args.decode_test``)
      sets with ``valid_model`` and write results under
      ``args.decoding_path``.
    * ``'valid_ppl'`` -- score the dev/test sets with ``valid_model_ppl``
      and, on rank 0, write one ``ppl<TAB>src<TAB>trg`` line per example to
      ``<decoding_path>/<task>-><target token>.txt``.

    Args:
        gpu: Local GPU index used for ``torch.cuda.set_device``.
        npgpus_per_node: Number of GPUs (processes) per node.
        args: Parsed command-line namespace. Mutated in place:
            ``device_id``, ``device``, ``local_rank`` and ``world_size`` are
            overwritten, and in ``'test'`` mode ``decoding_path`` gets the
            output file name appended.
    """
    start_time = time.time()

    args.device_id = gpu
    args.device = "cuda:{}".format(gpu)
    args.local_rank = args.local_rank * npgpus_per_node + gpu
    args.world_size = args.world_size * npgpus_per_node

    if args.distributed:
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=args.init_method,
                                             world_size=args.world_size,
                                             rank=args.local_rank)

    torch.cuda.set_device(args.device_id)

    # setup watcher settings
    watcher = Watcher(
        rank=args.local_rank,
        log_path=os.path.join(args.workspace_prefix, 'logs',
                              'log-{}.txt'.format(args.prefix)) if
        ((args.logfile is None) or (args.logfile == 'none')) else args.logfile)
    watcher.info('\n'.join([
        '{}:\t{}'.format(a, b)
        for a, b in sorted(args.__dict__.items(), key=lambda x: x[0])
    ]))
    watcher.info('Starting with HPARAMS: {}'.format(args.hp_str))
    # Arguments are ordered to match the placeholders: the device id goes
    # into DEVICE-ID and the rendezvous init_method into MASTER.
    watcher.info("RANK:{}, WORLD_SIZE:{}, DEVICE-ID:{}, MASTER={}".format(
        args.local_rank, args.world_size, args.device_id, args.init_method))

    # get the dataloader
    dataloader = get_dataloader(setup_dataloader(args))(
        args, watcher, vocab_file=args.vocab_file)

    # get the model
    model = get_model(args.model)(dataloader.SRC, dataloader.TRG, args)
    watcher.info(model)
    watcher.info("total trainable parameters: {}".format(
        format(count_parameters(model), ',')))
    watcher.info("Vocabulary size: {}/{}.".format(len(dataloader.SRC.vocab),
                                                  len(dataloader.TRG.vocab)))

    # use GPU
    if torch.cuda.is_available():
        model.cuda()
    if args.distributed:
        model = DDP(model,
                    device_ids=[args.device_id],
                    output_device=args.device_id)
    model = setup_pretrained_model(args, model, watcher)

    # start running
    if args.mode == 'train':
        watcher.info('starting training')
        train_model(args,
                    watcher,
                    model,
                    dataloader.train,
                    dataloader.dev,
                    decoding_path=None)

    elif args.mode == 'test':

        # Only rank 0 creates the output directory; makedirs also creates
        # missing parents and tolerates an already-existing directory.
        if args.local_rank == 0:
            os.makedirs(args.decoding_path, exist_ok=True)

        watcher.info(
            'starting decoding from the pre-trained model, on the test set...')
        assert args.load_from is not None, 'must decode from a pre-trained model.'

        with torch.no_grad():
            test_set = dataloader.test if args.decode_test else dataloader.dev
            name = '{}.b={}_a={}.txt'.format(
                args.test_set if args.decode_test else args.dev_set,
                args.beam_size, args.alpha)
            args.decoding_path += '/{}'.format(name)

            for set_i in test_set:
                valid_model(args,
                            watcher,
                            model,
                            set_i,
                            print_out=True,
                            decoding_path=args.decoding_path,
                            dataflow=['src', 'trg'])

    elif args.mode == 'valid_ppl':

        watcher.info(
            'starting to evaluate the model from the pre-trained model, on the test set...'
        )
        assert args.load_from is not None, 'must decode from a pre-trained model.'

        with torch.no_grad():
            test_set = dataloader.test if args.decode_test else dataloader.dev
            for set_i in test_set:
                if args.sweep_target_tokens is not None:
                    target_tokens = [
                        '<{}>'.format(a)
                        for a in args.sweep_target_tokens.split(',')
                    ]
                else:
                    target_tokens = [set_i.init_tokens['trg']]

                for trg_tok in target_tokens:
                    set_i.init_tokens['trg'] = trg_tok
                    watcher.info("{} -> {}".format(set_i.task,
                                                   set_i.init_tokens))
                    # valid_model_ppl is run on every rank (it may involve
                    # collective ops); only rank 0 writes the results.
                    outputs = valid_model_ppl(args,
                                              watcher,
                                              model,
                                              set_i,
                                              dataflow=['src', 'trg'],
                                              lm_only=args.lm_only)

                    if args.local_rank == 0:
                        # [1:-1] strips the surrounding '<' '>' of the token.
                        output_path = args.decoding_path + '/{}->{}.txt'.format(
                            set_i.task, set_i.init_tokens['trg'][1:-1])
                        # Open only on the writing rank and close it
                        # deterministically (the original leaked one handle
                        # per target token on every rank).
                        with open(output_path, 'w') as output_file:
                            for s, t, ppl in zip(outputs['src'],
                                                 outputs['trg'],
                                                 outputs['loss']):
                                line = '{}\t{}\t{}'.format(ppl, s, t)
                                print(line, file=output_file)
                        print('write done.')

    watcher.info("all done.  Total clock time = {}".format(
        str(datetime.timedelta(seconds=(time.time() - start_time)))))
コード例 #16
0
    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        norm_type: Optional[Union[float, int]] = None,
        init_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        A reference model and an FSDP-wrapped model are both built from
        ``model_class`` with ``deterministic=True`` and trained for
        ``num_iters`` steps via ``self._train_for_several_steps``. Their losses
        are then compared. When ``mixed_precision`` is ``None``, their final
        (unsharded) parameters are compared as well.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            cuda_init_mode (CUDAInitMode): Forwarded to ``model_class.init``
                for the FSDP model; with ``CUDA_AFTER`` the resulting model is
                additionally moved with ``.cuda()`` after wrapping.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
            num_iters (int): Training iterations run for each model.
            save_model (bool): Forwarded to the FSDP training run only.
            cpu_offload (CPUOffload): FSDP CPU-offload config. With
                ``offload_params`` set, FSDP parameters are asserted to be on
                CPU, or an error is expected when combined with ``CUDA_AFTER``.
            backward_prefetch, forward_prefetch, sharding_strategy: Forwarded
                to FSDP as constructor kwargs.
            mixed_precision (Optional[MixedPrecision]): Forwarded to FSDP; if
                set, the reference run uses autocast and parameter parity is
                not checked.
            enable_sharded_grad_scaler, use_pure_fp16, norm_type: Forwarded
                to both training runs; ``use_pure_fp16`` also calls ``.half()``
                on both models.
            init_kwargs (Optional[Dict[str, Any]]): Extra kwargs passed to
                ``model_class.init`` for both models.
            **fsdp_kwargs: Additional FSDP constructor kwargs.

        Raises:
            ValueError: If constructing the FSDP model raises; the original
                exception is chained as ``__cause__``.
        """
        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP (or ``ref_init_fn``)
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            # NOTE(review): uses the process-group rank as the CUDA device
            # index, i.e. assumes one GPU per rank on a single node.
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update({
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        })
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(
                f"Initializing {model_class} raised error {str(e)}") from e
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        # Bound unconditionally: it is read both here and in the post-training
        # offload check below.
        cpu_device = torch.device("cpu")
        if expects_cpu_device:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (self.assertRaisesRegex(AssertionError,
                                          "Expected param to be on CPU")
                   if expects_device_error else suppress())
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                norm_type=norm_type,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        torch.testing.assert_allclose(ref_loss, fsdp_loss)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32.
        # (The guard was previously inverted and compared parameters ONLY when
        # mixed precision was enabled.)
        if mixed_precision is None:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )
コード例 #17
0
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--src_file",
                        default=None,
                        type=str,
                        help="The input data file name.")
    parser.add_argument("--tgt_file",
                        default=None,
                        type=str,
                        help="The output data file name.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--config_path",
                        default=None,
                        type=str,
                        help="Bert config file path.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--log_dir",
        default='',
        type=str,
        required=True,
        help="The output directory where the log will be written.")
    parser.add_argument("--model_recover_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The file of fine-tuned pretraining model.")
    parser.add_argument("--optim_recover_path",
                        default=None,
                        type=str,
                        help="The file of pretraining optimizer.")

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--label_smoothing",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="The weight decay rate for Adam.")
    parser.add_argument("--finetune_decay",
                        action='store_true',
                        help="Weight decay to the original weights.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--hidden_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for hidden states.")
    parser.add_argument("--attention_probs_dropout_prob",
                        default=0.1,
                        type=float,
                        help="Dropout rate for attention probabilities.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp32_embedding',
        action='store_true',
        help=
        "Whether to use 32-bit float precision instead of 16-bit for embeddings"
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--amp',
                        action='store_true',
                        help="Whether to use amp for fp16")
    parser.add_argument(
        '--from_scratch',
        action='store_true',
        help=
        "Initialize parameters with random values (i.e., training from scratch)."
    )
    parser.add_argument('--new_segment_ids',
                        action='store_true',
                        help="Use new segment ids for bi-uni-directional LM.")
    parser.add_argument('--new_pos_ids',
                        action='store_true',
                        help="Use new position ids for LMs.")
    parser.add_argument('--tokenized_input',
                        action='store_true',
                        help="Whether the input is tokenized.")
    parser.add_argument('--max_len_a',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment A.")
    parser.add_argument('--max_len_b',
                        type=int,
                        default=0,
                        help="Truncate_config: maximum length of segment B.")
    parser.add_argument(
        '--trunc_seg',
        default='',
        help="Truncate_config: first truncate segment A/B (option: a, b).")
    parser.add_argument(
        '--always_truncate_tail',
        action='store_true',
        help="Truncate_config: Whether we should always truncate tail.")
    parser.add_argument(
        "--mask_prob",
        default=0.15,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument(
        "--mask_prob_eos",
        default=0,
        type=float,
        help=
        "Number of prediction is sometimes less than max_pred when sequence is short."
    )
    parser.add_argument('--max_pred',
                        type=int,
                        default=20,
                        help="Max tokens of prediction.")
    parser.add_argument("--num_workers",
                        default=0,
                        type=int,
                        help="Number of workers for the data loader.")

    parser.add_argument('--mask_source_words',
                        action='store_true',
                        help="Whether to mask source words for training")
    parser.add_argument('--skipgram_prb',
                        type=float,
                        default=0.0,
                        help='prob of ngram mask')
    parser.add_argument('--skipgram_size',
                        type=int,
                        default=1,
                        help='the max size of ngram mask')
    parser.add_argument('--mask_whole_word',
                        action='store_true',
                        help="Whether masking a whole word.")
    parser.add_argument('--do_l2r_training',
                        action='store_true',
                        help="Whether to do left to right training")
    parser.add_argument(
        '--has_sentence_oracle',
        action='store_true',
        help="Whether to have sentence level oracle for training. "
        "Only useful for summary generation")
    parser.add_argument('--max_position_embeddings',
                        type=int,
                        default=None,
                        help="max position embeddings")
    parser.add_argument('--relax_projection',
                        action='store_true',
                        help="Use different projection layers for tasks.")
    parser.add_argument('--ffn_type',
                        default=0,
                        type=int,
                        help="0: default mlp; 1: W((Wx+b) elem_prod x);")
    parser.add_argument('--num_qkv',
                        default=0,
                        type=int,
                        help="Number of different <Q,K,V>.")
    parser.add_argument('--seg_emb',
                        action='store_true',
                        help="Using segment embedding for self-attention.")
    parser.add_argument(
        '--s2s_special_token',
        action='store_true',
        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment',
                        action='store_true',
                        help="Additional segmental for the encoder of S2S.")
    parser.add_argument(
        '--s2s_share_segment',
        action='store_true',
        help=
        "Sharing segment embeddings for the encoder of S2S (used with --s2s_add_segment)."
    )
    parser.add_argument('--pos_shift',
                        action='store_true',
                        help="Using position shift for fine-tuning.")

    args = parser.parse_args()

    assert Path(
        args.model_recover_path).exists(), "--model_recover_path doesn't exist"

    args.output_dir = args.output_dir.replace('[PT_OUTPUT_DIR]',
                                              os.getenv('PT_OUTPUT_DIR', ''))
    args.log_dir = args.log_dir.replace('[PT_OUTPUT_DIR]',
                                        os.getenv('PT_OUTPUT_DIR', ''))

    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    json.dump(args.__dict__,
              open(os.path.join(args.output_dir, 'opt.json'), 'w'),
              sort_keys=True,
              indent=2)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        dist.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    if args.max_position_embeddings:
        tokenizer.max_len = args.max_position_embeddings
    data_tokenizer = WhitespaceTokenizer(
    ) if args.tokenized_input else tokenizer
    if args.local_rank == 0:
        dist.barrier()

    if args.do_train:
        print("Loading Train Dataset", args.data_dir)
        bi_uni_pipeline = [
            seq2seq_loader.Preprocess4Seq2seq(
                args.max_pred,
                args.mask_prob,
                list(tokenizer.vocab.keys()),
                tokenizer.convert_tokens_to_ids,
                args.max_seq_length,
                new_segment_ids=args.new_segment_ids,
                truncate_config={
                    'max_len_a': args.max_len_a,
                    'max_len_b': args.max_len_b,
                    'trunc_seg': args.trunc_seg,
                    'always_truncate_tail': args.always_truncate_tail
                },
                mask_source_words=args.mask_source_words,
                skipgram_prb=args.skipgram_prb,
                skipgram_size=args.skipgram_size,
                mask_whole_word=args.mask_whole_word,
                mode="s2s",
                has_oracle=args.has_sentence_oracle,
                num_qkv=args.num_qkv,
                s2s_special_token=args.s2s_special_token,
                s2s_add_segment=args.s2s_add_segment,
                s2s_share_segment=args.s2s_share_segment,
                pos_shift=args.pos_shift)
        ]
        file_oracle = None
        if args.has_sentence_oracle:
            file_oracle = os.path.join(args.data_dir, 'train.oracle')
        fn_src = os.path.join(args.data_dir,
                              args.src_file if args.src_file else 'train.src')
        fn_tgt = os.path.join(args.data_dir,
                              args.tgt_file if args.tgt_file else 'train.tgt')
        train_dataset = seq2seq_loader.Seq2SeqDataset(
            fn_src,
            fn_tgt,
            args.train_batch_size,
            data_tokenizer,
            args.max_seq_length,
            file_oracle=file_oracle,
            bi_uni_pipeline=bi_uni_pipeline)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset, replacement=False)
            _batch_size = args.train_batch_size
        else:
            train_sampler = DistributedSampler(train_dataset)
            _batch_size = args.train_batch_size // dist.get_world_size()
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=_batch_size,
            sampler=train_sampler,
            num_workers=args.num_workers,
            collate_fn=seq2seq_loader.batch_list_to_batch_tensors,
            pin_memory=False)

    # note: args.train_batch_size has been changed to (/= args.gradient_accumulation_steps)
    # t_total = int(math.ceil(len(train_dataset.ex_list) / args.train_batch_size)
    t_total = int(
        len(train_dataloader) * args.num_train_epochs /
        args.gradient_accumulation_steps)

    amp_handle = None
    if args.fp16 and args.amp:
        from apex import amp
        amp_handle = amp.init(enable_caching=True)
        logger.info("enable fp16 with amp")

    # Prepare model
    recover_step = _get_max_epoch_model(args.output_dir)
    cls_num_labels = 2
    type_vocab_size = 6 + \
        (1 if args.s2s_add_segment else 0) if args.new_segment_ids else 2
    num_sentlvl_labels = 2 if args.has_sentence_oracle else 0
    relax_projection = 4 if args.relax_projection else 0
    if args.local_rank not in (-1, 0):
        # Make sure only the first process in distributed training will download model & vocab
        dist.barrier()
    if (recover_step is None) and (args.model_recover_path is None):
        # if _state_dict == {}, the parameters are randomly initialized
        # if _state_dict == None, the parameters are initialized with bert-init
        _state_dict = {} if args.from_scratch else None
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=_state_dict,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
        global_step = 0
    else:
        if recover_step:
            logger.info("***** Recover model: %d *****", recover_step)
            model_recover = torch.load(os.path.join(
                args.output_dir, "model.{0}.bin".format(recover_step)),
                                       map_location='cpu')
            # recover_step == number of epochs
            global_step = math.floor(recover_step * t_total /
                                     args.num_train_epochs)
        elif args.model_recover_path:
            logger.info("***** Recover model: %s *****",
                        args.model_recover_path)
            model_recover = torch.load(args.model_recover_path,
                                       map_location='cpu')
            global_step = 0
        model = BertForPreTrainingLossMask.from_pretrained(
            args.bert_model,
            state_dict=model_recover,
            num_labels=cls_num_labels,
            num_rel=0,
            type_vocab_size=type_vocab_size,
            config_path=args.config_path,
            task_idx=3,
            num_sentlvl_labels=num_sentlvl_labels,
            max_position_embeddings=args.max_position_embeddings,
            label_smoothing=args.label_smoothing,
            fp32_embedding=args.fp32_embedding,
            relax_projection=relax_projection,
            new_pos_ids=args.new_pos_ids,
            ffn_type=args.ffn_type,
            hidden_dropout_prob=args.hidden_dropout_prob,
            attention_probs_dropout_prob=args.attention_probs_dropout_prob,
            num_qkv=args.num_qkv,
            seg_emb=args.seg_emb)
    if args.local_rank == 0:
        dist.barrier()

    if args.fp16:
        model.half()
        if args.fp32_embedding:
            model.bert.embeddings.word_embeddings.float()
            model.bert.embeddings.position_embeddings.float()
            model.bert.embeddings.token_type_embeddings.float()
    model.to(device)
    if args.local_rank != -1:
        try:
            from torch.nn.parallel.distributed import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("DistributedDataParallel")
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank,
                    find_unused_parameters=True)
    elif n_gpu > 1:
        #model = torch.nn.DataParallel(model)
        model = DataParallelImbalance(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            #from apex.optimizers.fp16_optimizer import FP16_Optimizer
            from pytorch_pretrained_bert.optimization_fp16 import FP16_Optimizer_State
            from apex.optimizers.fused_adam import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer_State(optimizer,
                                             dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer_State(optimizer,
                                             static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    if recover_step:
        logger.info("***** Recover optimizer: %d *****", recover_step)
        optim_recover = torch.load(os.path.join(
            args.output_dir, "optim.{0}.bin".format(recover_step)),
                                   map_location='cpu')
        if hasattr(optim_recover, 'state_dict'):
            optim_recover = optim_recover.state_dict()
        optimizer.load_state_dict(optim_recover)
        if args.loss_scale == 0:
            logger.info("***** Recover optimizer: dynamic_loss_scale *****")
            optimizer.dynamic_loss_scale = True

    logger.info("***** CUDA.empty_cache() *****")
    torch.cuda.empty_cache()

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", t_total)

        model.train()
        if recover_step:
            start_epoch = recover_step + 1
        else:
            start_epoch = 1
        for i_epoch in trange(start_epoch,
                              int(args.num_train_epochs) + 1,
                              desc="Epoch",
                              disable=args.local_rank not in (-1, 0)):
            if args.local_rank != -1:
                train_sampler.set_epoch(i_epoch)
            iter_bar = tqdm(train_dataloader,
                            desc='Iter (loss=X.XXX)',
                            disable=args.local_rank not in (-1, 0))
            for step, batch in enumerate(iter_bar):
                batch = [
                    t.to(device) if t is not None else None for t in batch
                ]
                if args.has_sentence_oracle:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx, oracle_pos, oracle_weights, oracle_labels = batch
                else:
                    input_ids, segment_ids, input_mask, mask_qkv, lm_label_ids, masked_pos, masked_weights, is_next, task_idx = batch
                    oracle_pos, oracle_weights, oracle_labels = None, None, None
                loss_tuple = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   lm_label_ids,
                                   is_next,
                                   masked_pos=masked_pos,
                                   masked_weights=masked_weights,
                                   task_idx=task_idx,
                                   masked_pos_2=oracle_pos,
                                   masked_weights_2=oracle_weights,
                                   masked_labels_2=oracle_labels,
                                   mask_qkv=mask_qkv)
                masked_lm_loss, next_sentence_loss = loss_tuple
                if n_gpu > 1:  # mean() to average on multi-gpu.
                    # loss = loss.mean()
                    masked_lm_loss = masked_lm_loss.mean()
                    next_sentence_loss = next_sentence_loss.mean()
                loss = masked_lm_loss + next_sentence_loss

                # logging for each step (i.e., before normalization by args.gradient_accumulation_steps)
                iter_bar.set_description('Iter (loss=%5.3f)' % loss.item())

                # ensure that accumlated gradients are normalized
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                    if amp_handle:
                        amp_handle._clear_cache()
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    lr_this_step = args.learning_rate * \
                        warmup_linear(global_step/t_total,
                                      args.warmup_proportion)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model
            if (args.local_rank == -1 or torch.distributed.get_rank() == 0):
                logger.info(
                    "** ** * Saving fine-tuned model and optimizer ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    args.output_dir, "model.{0}.bin".format(i_epoch))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_optim_file = os.path.join(
                    args.output_dir, "optim.{0}.bin".format(i_epoch))
                torch.save(optimizer.state_dict(), output_optim_file)

                logger.info("***** CUDA.empty_cache() *****")
                torch.cuda.empty_cache()
コード例 #18
0
def count_parameters(model):
    """Return the total element count of all parameters of *model* that require grad."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


watcher.info("total trainable parameters: {}".format(
    format(count_parameters(model), ',')))
watcher.info("Vocabulary size: {}/{}.".format(len(dataloader.SRC.vocab),
                                              len(dataloader.TRG.vocab)))

# use GPU
if torch.cuda.is_available():
    model.cuda()

if args.distributed:
    model = DDP(model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

# load pre-trained parameters
if args.load_from != 'none':
    with torch.cuda.device(args.local_rank):
        pretrained_dict = torch.load(
            os.path.join(args.workspace_prefix, 'models',
                         args.load_from + '.pt'),
            map_location=lambda storage, loc: storage.cuda())
        model_dict = model.state_dict()
        # Keep only checkpoint entries whose names exist in the current model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # NOTE(review): when the model is DDP-wrapped its keys carry a
        # "module." prefix; a checkpoint saved without it would match nothing.
        watcher.info("loaded {}/{} pretrained tensors.".format(
            len(pretrained_dict), len(model_dict)))
        model_dict.update(pretrained_dict)
        # state_dict() returns a new mapping; updating it does not modify the
        # model, so the merged weights must be loaded back explicitly.
        model.load_state_dict(model_dict)
コード例 #19
0
def main():
    """Train an image classifier (optionally on diffusion-noised inputs).

    Parses command-line options and sets up distributed execution. Builds the
    classifier and diffusion process, and optionally resumes model and
    optimizer state from ``args.resume_checkpoint``. It then runs the
    remaining ``args.iterations - resume_step`` optimisation steps. Along the
    way it periodically evaluates on ``args.val_data_dir``, dumps logs, and
    saves checkpoints from rank 0. A final checkpoint is written by rank 0
    before all ranks meet at a barrier.
    """
    args = create_argparser().parse_args()

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args,
                       classifier_and_diffusion_defaults().keys()))
    model.to(dist_util.dev())
    if args.noised:
        # Only used when training on noised inputs (see forward_backward_log).
        schedule_sampler = create_named_schedule_sampler(
            args.schedule_sampler, diffusion)

    resume_step = 0
    if args.resume_checkpoint:
        resume_step = parse_resume_step_from_filename(args.resume_checkpoint)
        # Only rank 0 reads the model checkpoint; sync_params below is
        # presumably what propagates these weights to the other ranks.
        if dist.get_rank() == 0:
            logger.log(
                f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step"
            )
            model.load_state_dict(
                dist_util.load_state_dict(args.resume_checkpoint,
                                          map_location=dist_util.dev()))

    # Needed for creating correct EMAs and fp16 parameters.
    dist_util.sync_params(model.parameters())

    mp_trainer = MixedPrecisionTrainer(model=model,
                                       use_fp16=args.classifier_use_fp16,
                                       initial_lg_loss_scale=16.0)

    model = DDP(
        model,
        device_ids=[dist_util.dev()],
        output_device=dist_util.dev(),
        broadcast_buffers=False,
        bucket_cap_mb=128,
        find_unused_parameters=False,
    )

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=True,
        random_crop=True,
    )
    if args.val_data_dir:
        val_data = load_data(
            data_dir=args.val_data_dir,
            batch_size=args.batch_size,
            image_size=args.image_size,
            class_cond=True,
        )
    else:
        val_data = None

    logger.log("creating optimizer...")
    opt = AdamW(mp_trainer.master_params,
                lr=args.lr,
                weight_decay=args.weight_decay)
    if args.resume_checkpoint:
        # Optimizer state lives next to the model checkpoint as opt<step>.pt.
        opt_checkpoint = bf.join(bf.dirname(args.resume_checkpoint),
                                 f"opt{resume_step:06}.pt")
        logger.log(
            f"loading optimizer state from checkpoint: {opt_checkpoint}")
        opt.load_state_dict(
            dist_util.load_state_dict(opt_checkpoint,
                                      map_location=dist_util.dev()))

    logger.log("training classifier model...")

    def forward_backward_log(data_loader, prefix="train"):
        """Run one batch through the model, log loss/accuracy, and backprop.

        Gradients are accumulated across microbatches. Backprop happens only
        when the loss requires grad, so the same helper serves validation
        under ``th.no_grad()``.
        """
        batch, extra = next(data_loader)
        labels = extra["y"].to(dist_util.dev())

        batch = batch.to(dist_util.dev())
        # Noisy images
        if args.noised:
            t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev())
            batch = diffusion.q_sample(batch, t)
        else:
            t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev())

        for i, (sub_batch, sub_labels, sub_t) in enumerate(
                split_microbatches(args.microbatch, batch, labels, t)):
            logits = model(sub_batch, timesteps=sub_t)
            loss = F.cross_entropy(logits, sub_labels, reduction="none")

            losses = {}
            losses[f"{prefix}_loss"] = loss.detach()
            losses[f"{prefix}_acc@1"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=1,
                                                      reduction="none")
            losses[f"{prefix}_acc@5"] = compute_top_k(logits,
                                                      sub_labels,
                                                      k=5,
                                                      reduction="none")
            log_loss_dict(diffusion, sub_t, losses)
            del losses
            loss = loss.mean()
            if loss.requires_grad:
                if i == 0:
                    mp_trainer.zero_grad()
                # Weight each microbatch by its share of the full batch so the
                # accumulated gradient matches the full-batch mean loss.
                mp_trainer.backward(loss * len(sub_batch) / len(batch))

    # Bound up front so the final save below is well-defined even when the
    # loop body never runs (resume_step >= args.iterations); otherwise
    # `step` would be unbound there and raise NameError.
    step = 0
    for step in range(args.iterations - resume_step):
        logger.logkv("step", step + resume_step)
        logger.logkv(
            "samples",
            (step + resume_step + 1) * args.batch_size * dist.get_world_size(),
        )
        if args.anneal_lr:
            set_annealed_lr(opt, args.lr,
                            (step + resume_step) / args.iterations)
        forward_backward_log(data)
        mp_trainer.optimize(opt)
        if val_data is not None and not step % args.eval_interval:
            with th.no_grad():
                with model.no_sync():
                    model.eval()
                    forward_backward_log(val_data, prefix="val")
                    model.train()
        if not step % args.log_interval:
            logger.dumpkvs()
        if (step and dist.get_rank() == 0
                and not (step + resume_step) % args.save_interval):
            logger.log("saving model...")
            save_model(mp_trainer, opt, step + resume_step)

    if dist.get_rank() == 0:
        logger.log("saving model...")
        save_model(mp_trainer, opt, step + resume_step)
    dist.barrier()
コード例 #20
0
def get_model(args):
    """Build the GPT3 model and wrap it for (mixed-precision) distributed training.

    Args:
        args: parsed training options. Fields read here include the model
            shape (num_layers, vocab_size, hidden_size, num_attention_heads,
            max_position_embeddings), dropout settings, activation
            checkpointing options, model_parallel_size, sparse_mode,
            deepspeed, fp16 and load_huggingface.

    Returns:
        The model placed on the current CUDA device. It is wrapped in
        FP16_Module when ``args.fp16`` is set, and always wrapped in DDP.

    Raises:
        ValueError: if num_attention_heads is not divisible by
            model_parallel_size.
    """
    print_rank_0('building GPT3 model ...')
    # Attention heads are split evenly across model-parallel ranks. Use an
    # explicit check rather than `assert`, which is stripped under -O and
    # would let the floor division below silently drop heads.
    if args.num_attention_heads % args.model_parallel_size != 0:
        raise ValueError(
            "num_attention_heads ({}) must be divisible by "
            "model_parallel_size ({})".format(args.num_attention_heads,
                                              args.model_parallel_size))
    num_local_heads = args.num_attention_heads // args.model_parallel_size

    deepspeed_sparsity_config = None
    if DEEPSPEED_WRAP and args.deepspeed:
        print_rank_0("DeepSpeed is enabled; building the sparse attention config")
        deepspeed_sparsity_config = get_sparse_attention_config(args, num_local_heads)
    if deepspeed_sparsity_config is not None:
        print_rank_0(f"Using sparse attention with mode {args.sparse_mode}")

    print_rank_0("Calling GPT3Model constructor...")
    model = GPT3Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=True,
                      deepspeed_sparsity_config=deepspeed_sparsity_config,
                      sparse_mode=args.sparse_mode)

    if args.load_huggingface is not None:
        print_rank_0("Loading huggingface model...")
        model = load_huggingface_model(model, args.load_huggingface, args.huggingface_double_pos_embeddings)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
    if DEEPSPEED_WRAP and args.deepspeed and args.fp16:
        print_rank_0("DeepSpeed with fp16: converting the model to half precision before GPU placement")
        model.half()

    # GPU allocation.
    device = torch.cuda.current_device()
    print_rank_0(f"placing the model on device {device}")
    model.cuda(device)

    # Fp16 conversion.
    # NOTE(review): the model has only been halved above when DeepSpeed is
    # active; FP16_Module presumably performs the cast itself otherwise — confirm.
    if args.fp16:
        print_rank_0("Wrapping the model in FP16_Module")
        model = FP16_Module(model)

    # Wrap model for distributed training.
    print_rank_0("Setting up distributed training...")
    if USE_TORCH_DDP:
        print_rank_0(f"Using classic pytorch DDP with device {device}")
        model = DDP(model, device_ids=[device], output_device=device,
                    process_group=mpu.get_data_parallel_group())
    else:
        print_rank_0("Using non-torch (local) DDP implementation")
        model = DDP(model)

    print_rank_0("ready to return model")
    return model