Example #1
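A training-class method that builds a RetinaNet model (through a project-specific `model.resnet50` wrapper), creates an Apex `FusedAdam` optimizer, wraps both with `amp.initialize` for mixed precision, loads a checkpoint, and attaches a `ReduceLROnPlateau` scheduler.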
 def initialize_training(self):
     self.retinanet = model.resnet50(num_classes=12,
                                     pretrained=True).to(self.device)
     self.optimizer = optimizers.FusedAdam(
         params=self.retinanet.parameters(), lr=1e-5)
     self.amp = amp
     self.retinanet, self.optimizer = self.amp.initialize(
         models=self.retinanet, optimizers=self.optimizer)
     self.load_checkpoint()
     self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
         self.optimizer, patience=3, verbose=self.verbose)
Example #2
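Chooses between Apex's fused optimizers (`FusedAdam`, `FusedLAMB`) and plain `torch.optim.Adam` based on the hyperparameters; the non-fused LAMB path is left unimplemented.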
def getOptimizer(model, hparams):
    if hparams.optimizer_fused:
        from apex import optimizers as apexopt
        if hparams.optimizer == "Adam":
            optimizer = apexopt.FusedAdam(model.parameters(),
                                          lr=hparams.learning_rate)
        elif hparams.optimizer == "LAMB":
            optimizer = apexopt.FusedLAMB(model.parameters(),
                                          lr=hparams.learning_rate)
    else:
        if hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=hparams.learning_rate)
        elif hparams.optimizer == "LAMB":
            raise NotImplementedError  # PyTorch doesn't currently include LAMB optimizer.
    return optimizer
Example #3
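Builds one optimizer per network from parsed command-line options, importing the Apex fused variants (`FusedAdam`, `FusedSGD`) lazily only when they are requested.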
    def get_optimizers(self, args):
        # Initialize utility lists and dicts for the optimizers
        nets_optims_names = utils.parse_str_to_dict(args.optims)
        nets_lrs = utils.parse_str_to_dict(args.lrs, value_type=float)

        # Initialize optimizers
        optims = {}

        for net_name, optim_name in nets_optims_names.items():
            # Prepare the options
            lr = nets_lrs[net_name]
            optim_name = optim_name.lower()
            params = self.nets[net_name].parameters()

            # Choose the required optimizer
            if optim_name == 'adam':
                opt = optim.Adam(params,
                                 lr=lr,
                                 eps=args.eps,
                                 betas=(args.adam_beta1, 0.999))

            elif optim_name == 'sgd':
                opt = optim.SGD(params, lr=lr)

            elif optim_name == 'fusedadam':
                from apex import optimizers

                opt = optimizers.FusedAdam(params,
                                           lr=lr,
                                           eps=args.eps,
                                           betas=(args.adam_beta1, 0.999))

            elif optim_name == 'fusedsgd':
                from apex import optimizers

                opt = optimizers.FusedSGD(params, lr=lr)

            elif optim_name == 'lbfgs':
                opt = optim.LBFGS(params, lr=lr)

            else:
                raise ValueError('Unsupported optimizer name')

            optims[net_name] = opt

        return optims
Example #4
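Training entry point for what appears to be a Tacotron-style TTS model; it selects `apexopt.FusedAdam` or `torch.optim.Adam` depending on `hparams.Apex_optimizer`.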
def train(output_directory, log_directory, checkpoint_path, warm_start, warm_start_force, n_gpus,
          rank, group_name, hparams):
    """Training and validation logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): comma separated list of "name=value" pairs.
    """
    hparams.n_gpus = n_gpus
    hparams.rank = rank
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    train_loader, valset, collate_fn, train_sampler, trainset = prepare_dataloaders(hparams)
    speaker_lookup = trainset.speaker_ids
    
    if hparams.drop_frame_rate > 0.:
        if rank != 0: # if global_mean is not yet calculated, wait for the main thread to do it
            while not os.path.exists(hparams.global_mean_npy): time.sleep(1)
        global_mean = calculate_global_mean(train_loader, hparams.global_mean_npy, hparams)
        hparams.global_mean = global_mean
    
    model = load_model(hparams)

    model.eval() # test if this is needed anymore

    learning_rate = hparams.learning_rate
    if hparams.Apex_optimizer: # the Apex fused optimizer is slightly faster with slightly more VRAM usage in my testing; helps in both fp32 and fp16.
        optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=hparams.weight_decay)
Example #5
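Helper that returns a `(model, optimizer)` pair: it optionally uses the Apex fused optimizers and, for fp16 runs, wraps model and optimizer with `amp.initialize` at opt level O1.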
def get_optimizer(model, optimizer: str, fp16_run=True, optimizer_fused=True, learning_rate=1e-4, max_grad_norm=200):
    optimizer = optimizer.lower()
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=max_grad_norm)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
    
    global amp
    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    return model, optimizer
Example #6
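ImageNet-style classification training script (it appears to be adapted from the Apex examples) that can swap `torch.optim.SGD` for `FusedAdam`, initializes AMP, and supports distributed training via `apex.parallel.DistributedDataParallel`.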
def main():
    global best_prec1, args

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256. 
    if args.fused_adam:
        optimizer = optimizers.FusedAdam(model.parameters())
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(
        model, optimizer,
        # enabled=False,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale
        )

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with 
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if(args.arch == "inception_v3"):
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
            # transforms.ToTensor(), Too slow
            # normalize,
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
Example #7
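WaveGlow training loop that picks Adam/LAMB or their Apex fused counterparts, optionally initializes AMP for fp16, and handles LR scheduling, checkpointing, validation, and automatic restarts on loss explosions.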
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, loss_empthasis, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard, logdirname, datedlogdir, warm_start=False, optimizer='ADAM', start_zero=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======
    
    global WaveGlow
    global WaveGlowLoss
    
    ax = True # this is **really** bad coding practice :D
    if ax:
        from efficient_model_ax import WaveGlow
        from efficient_loss import WaveGlowLoss
    else:
        if waveglow_config["yoyo"]: # efficient_mode # TODO: Add to Config File
            from efficient_model import WaveGlow
            from efficient_loss import WaveGlowLoss
        else:
            from glow import WaveGlow, WaveGlowLoss
    
    criterion = WaveGlowLoss(sigma, loss_empthasis)
    model = WaveGlow(**waveglow_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======
    STFTs = [STFT.TacotronSTFT(filter_length=window,
                                 hop_length=data_config['hop_length'],
                                 win_length=window,
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=160,
                                 mel_fmin=0, mel_fmax=16000) for window in data_config['validation_windows']]
    
    loader_STFT = STFT.TacotronSTFT(filter_length=data_config['filter_length'],
                                 hop_length=data_config['hop_length'],
                                 win_length=data_config['win_length'],
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=data_config['n_mel_channels'] if 'n_mel_channels' in data_config.keys() else 160,
                                 mel_fmin=data_config['mel_fmin'], mel_fmax=data_config['mel_fmax'])
    
    #optimizer = "Adam"
    optimizer = optimizer.lower()
    optimizer_fused = False # use the Apex fused optimizer; should be identical to normal but slightly faster, and only works on RTX cards
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=200)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
            #import torch_optimizer as optim
            #optimizer = optim.Lamb(model.parameters(), lr=learning_rate)
            #raise# PyTorch doesn't currently include LAMB optimizer.
    
    if fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-8
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True, threshold=0.0001, threshold_mode='abs')
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else: scheduler=False
    
    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                      optimizer, scheduler, fp16_run, warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1
    if start_zero:
        iteration = 0
    
    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=3, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)
    
    moving_average = int(min(len(train_loader), 100)) # average loss over entire Epoch
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_iter = time.time()
    start_time_dekaiter = time.time()
    model.train()
    
    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model")+".txt"):
        best_model_loss = float(str(open(os.path.join(output_directory, "best_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -6.20
    
    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(os.path.join(output_directory, "best_val_model")+".txt"):
        best_MSE = float(str(open(os.path.join(output_directory, "best_val_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9
    
    epoch_offset = max(0, int(iteration / len(train_loader)))
    
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters in model".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))
    
    print(f"Segment Length: {data_config['segment_length']:,}\nBatch Size: {batch_size:,}\nNumber of GPUs: {num_gpus:,}\nSamples/Iter: {data_config['segment_length']*batch_size*num_gpus:,}")
    
    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs), initial=epoch_offset, total=epochs, smoothing=0.01, desc="Epoch", position=1, unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1:
                    train_sampler.set_epoch(epoch)
                
                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader), desc=" Iter", smoothing=0, total=len(train_loader), position=0, unit="iter", leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'seconds_elapsed': time.time()-start_time}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50: # check the actual learning rate every 50 iters (because the learning_rate variable sometimes goes out-of-sync with the real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                        if override_scheduler_last_lr or override_scheduler_best:
                            print("scheduler._last_lr =", scheduler._last_lr, "scheduler.best =", scheduler.best, "  |", end='')
                    model.zero_grad()
                    mel, audio, speaker_ids = batch
                    mel = torch.autograd.Variable(mel.cuda(non_blocking=True))
                    audio = torch.autograd.Variable(audio.cuda(non_blocking=True))
                    speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                    outputs = model(mel, audio, speaker_ids)
                    
                    loss = criterion(outputs)
                    if num_gpus > 1:
                        reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                    else:
                        reduced_loss = loss.item()
                    
                    if fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    
                    if (reduced_loss > LossExplosionThreshold) or (math.isnan(reduced_loss)):
                        model.zero_grad()
                        raise LossExplosion(f"\nLOSS EXPLOSION EXCEPTION ON RANK {rank}: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n")
                    
                    if use_grad_clip:
                        if fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), grad_clip_thresh)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), grad_clip_thresh)
                        if type(grad_norm) == torch.Tensor:
                            grad_norm = grad_norm.item()
                        is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
                    else: is_overflow = False; grad_norm=0.00001
                    
                    optimizer.step()
                    if not is_overflow and rank == 0:
                        # get current Loss Scale of first optimizer
                        loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if fp16_run else 32768
                        
                        if with_tensorboard:
                            if (iteration % 100000 == 0):
                                # plot distribution of parameters
                                for tag, value in model.named_parameters():
                                    tag = tag.replace('.', '/')
                                    logger.add_histogram(tag, value.data.cpu().numpy(), iteration)
                            logger.add_scalar('training_loss', reduced_loss, iteration)
                            logger.add_scalar('training_loss_samples', reduced_loss, iteration*batch_size)
                            if (iteration % 20 == 0):
                                logger.add_scalar('learning.rate', learning_rate, iteration)
                            if (iteration % 10 == 0):
                                logger.add_scalar('duration', ((time.time() - start_time_dekaiter)/10), iteration)
                        
                        average_loss = rolling_sum.process(reduced_loss)
                        if (iteration % 10 == 0):
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), (time.time() - start_time_dekaiter)/10, ((time.time() - start_time_dekaiter)/10)/(batch_size*num_gpus)))
                            start_time_dekaiter = time.time()
                        else:
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {}LS".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), loss_scale))
                        start_time_iter = time.time()
                    
                    if rank == 0 and (len(rolling_sum.values) > moving_average-2):
                        if (average_loss+best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            except KeyboardInterrupt: # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                            text_file.write(str(average_loss)+"\n"+str(iteration))
                            text_file.close()
                            best_model_loss = average_loss #Only save the model if X better than the current loss.
                    if rank == 0 and iteration > 0 and ((iteration % iters_per_checkpoint == 0) or (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                        checkpoint_path)
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)
                    
                    if (iteration % validation_interval == 0):
                        if rank == 0:
                            MSE, MAE = validate(model, loader_STFT, STFTs, logger, iteration, data_config['validation_files'], speaker_lookup, sigma, output_directory, data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    except KeyboardInterrupt: # Avoid corrupting the model.
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                                    text_file.write(str(MSE.item())+"\n"+str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item() #Only save the model if X better than the current loss.
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                        learning_rate = optimizer.param_groups[0]['lr'] #check actual learning rate (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    iteration += 1
            training = False # exit the While loop
        
        except LossExplosion as ex: # print the exception and continue from the last best checkpoint (restarting like this takes under 4 seconds)
            print(ex) # print Loss
            checkpoint_path = os.path.join(output_directory, "best_model")
            assert os.path.exists(checkpoint_path), "best_val_model must exist for automatic restarts"
            
            # clearing VRAM for load checkpoint
            audio = mel = speaker_ids = loss = None
            torch.cuda.empty_cache()
            
            model.eval()
            model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model, optimizer, scheduler, fp16_run)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass # and continue training.
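Example #8
Distributed DLRM training entry point: `apex_optim.FusedAdam` or `FusedSGD` optimizes the MLP parameters while a separate `SparseAdam`/`SGD` optimizer handles the embeddings, with optional AMP at opt level O2.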
def main(argv):
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend,
                                                       use_gpu=use_gpu)
    device = FLAGS.base_device

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

        print("Command line flags:")
        pprint(FLAGS.flag_values_dict())

    print("Creating data loaders")

    FLAGS.set_default("test_batch_size",
                      FLAGS.test_batch_size // world_size * world_size)

    categorical_feature_sizes = get_categorical_feature_sizes(FLAGS)
    world_categorical_feature_sizes = np.asarray(categorical_feature_sizes)
    device_mapping = get_device_mapping(categorical_feature_sizes,
                                        num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size,
                                              num_gpus=world_size)
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # sizes of embeddings for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[
        device_mapping['embedding'][rank]].tolist()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping[
        'bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(
        FLAGS, device_mapping=device_mapping)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=FLAGS.num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device)
    print(model)
    print(device_mapping)
    print(f"Batch sizes per gpu: {batch_sizes_per_gpu}")

    dist.setup_distributed_print(is_main_process())

    # DDP introduces a gradient average through allreduce(mean), which doesn't apply to bottom model.
    # Compensate it with further scaling lr
    scaled_lr = FLAGS.lr / FLAGS.loss_scale if FLAGS.amp else FLAGS.lr

    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = scaled_lr
    else:
        embedding_model_parallel_lr = scaled_lr / world_size
    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = scaled_lr
    else:
        MLP_model_parallel_lr = scaled_lr / world_size
    data_parallel_lr = scaled_lr

    if is_main_process():
        mlp_params = [{
            'params': list(model.top_model.parameters()),
            'lr': data_parallel_lr
        }, {
            'params': list(model.bottom_model.mlp.parameters()),
            'lr': MLP_model_parallel_lr
        }]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [{
            'params': list(model.top_model.parameters()),
            'lr': data_parallel_lr
        }]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params':
        list(model.bottom_model.embeddings.parameters()),
        'lr':
        embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict())

    checkpoint_loader = make_distributed_checkpoint_loader(
        device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    if FLAGS.amp:
        (model.top_model,
         model.bottom_model.mlp), mlp_optimizer = amp.initialize(
             [model.top_model, model.bottom_model.mlp],
             mlp_optimizer,
             opt_level="O2",
             loss_scale=1)

    if use_gpu:
        model.top_model = parallel.DistributedDataParallel(model.top_model)
    else:  # Use other backend for CPU
        model.top_model = torch.nn.parallel.DistributedDataParallel(
            model.top_model)

    if FLAGS.mode == 'test':
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process(
    ):
        logging.warning(
            "Saving checkpoint without --bottom_features_ordered flag will result in "
            "a device-order dependent model. Consider using --bottom_features_ordered "
            "if you plan to load the checkpoint in different device configurations."
        )

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    steps_per_epoch = len(data_loader_train)
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 1

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'loss', utils.SmoothedValue(window_size=1, fmt='{avg:.4f}'))
    metric_logger.add_meter(
        'step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)
    moving_loss_stream = torch.cuda.Stream()

    lr_scheduler = utils.LearningRateScheduler(
        optimizers=[mlp_optimizer, embedding_optimizer],
        base_lrs=[mlp_lrs, embedding_lrs],
        warmup_steps=FLAGS.warmup_steps,
        warmup_factor=FLAGS.warmup_factor,
        decay_start_step=FLAGS.decay_start_step,
        decay_steps=FLAGS.decay_steps,
        decay_power=FLAGS.decay_power,
        end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()
    stop_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(
                    f"Reached max global steps of {FLAGS.max_steps}. Stopping."
                )
                break

            lr_scheduler.step()

            if click.shape[0] != FLAGS.batch_size:  # last batch
                logging.error("The last batch with size %s is not supported",
                              click.shape[0])
            else:
                output = model(numerical_features, categorical_features,
                               batch_sizes_per_gpu).squeeze()

                loss = loss_fn(
                    output, click[batch_indices[rank]:batch_indices[rank + 1]])

                if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
                    model.zero_grad()
                else:
                    # We don't need to accumulate gradient. Set grad to None is faster than optimizer.zero_grad()
                    for param_group in itertools.chain(
                            embedding_optimizer.param_groups,
                            mlp_optimizer.param_groups):
                        for param in param_group['params']:
                            param.grad = None

                if FLAGS.amp:
                    loss *= FLAGS.loss_scale
                    with amp.scale_loss(loss, mlp_optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if FLAGS.Adam_MLP_optimizer:
                    scale_MLP_gradients(mlp_optimizer, world_size)
                mlp_optimizer.step()

                if FLAGS.Adam_embedding_optimizer:
                    scale_embeddings_gradients(embedding_optimizer, world_size)
                embedding_optimizer.step()

                moving_loss_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(moving_loss_stream):
                    moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                torch.cuda.current_stream().wait_stream(moving_loss_stream)
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs synchronize which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq /
                        (FLAGS.loss_scale if FLAGS.amp else 1),
                        lr=mlp_optimizer.param_groups[0]["lr"] *
                        (FLAGS.loss_scale if FLAGS.amp else 1))
                stop_time = time()

                eta_str = datetime.timedelta(
                    seconds=int(metric_logger.step_time.global_avg *
                                (steps_per_epoch - step)))
                metric_logger.print(
                    header=
                    f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}"
                )

                with torch.cuda.stream(moving_loss_stream):
                    moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(
                        f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                        f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. "
                        f"Average speed {global_step * FLAGS.batch_size / run_time_s:.1f} records/s."
                    )
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(
            f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. "
            f"Average speed {steps_per_epoch * FLAGS.batch_size / epoch_time_s:.1f} records/s."
        )

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path,
                                          epoch, step)

    results = {
        'best_auc': best_auc,
        'best_epoch': best_epoch,
        'average_train_throughput': avg_throughput
    }

    dllogger.log(data=results, step=tuple())
Example #9
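Multilingual language-model training script that uses `optimizers.FusedAdam` together with AMP when CUDA and `--fp16` are available, and plain `Adam` otherwise.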
def main():
    args = get_args()

    log.info(f'Parsed arguments: \n{pformat(args.__dict__)}')
    assert args.cond_type.lower() in ['none', 'platanios', 'oestling']

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    log.info('Using device {}.'.format(device))

    use_apex = False
    if torch.cuda.is_available() and args.fp16:
        log.info('Loading Nvidia Apex and using AMP')
        from apex import amp, optimizers
        use_apex = True
    else:
        log.info('Using FP32')
        amp = None

    log.info(f'Using time stamp {timestamp} to save models and logs.')

    if not args.no_seed:
        log.info(f'Setting random seed to {args.seed} for reproducibility.')
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    data = Corpus(args.datadir)

    data_splits = [
        {
            'split': 'train',
            'languages': args.dev_langs + args.target_langs,
            'invert_include': True,
        },
        {
            'split': 'valid',
            'languages': args.dev_langs,
        },
        {
            'split': 'test',
            'languages': args.target_langs,
        },
    ]

    if args.refine:
        data_splits.append({
            'split': 'train_100',
            'languages': args.target_langs,
            'ignore_missing': True,
        })

    data_splits = data.make_datasets(data_splits, force_rebuild=args.rebuild)
    train_set, val_set, test_set = data_splits['train'], data_splits[
        'valid'], data_splits['test']
    dictionary = data_splits['dictionary']

    train_language_distr = get_sampling_probabilities(train_set, 1.0)
    train_set = Dataset(train_set,
                        batchsize=args.batchsize,
                        bptt=args.bptt,
                        reset_on_iter=True,
                        language_probabilities=train_language_distr)
    val_set = Dataset(val_set,
                      make_config=True,
                      batchsize=args.valid_batchsize,
                      bptt=args.bptt,
                      eval=True)
    test_set = Dataset(test_set,
                       make_config=True,
                       batchsize=args.test_batchsize,
                       bptt=args.bptt,
                       eval=True)

    train_loader = DataLoader(train_set, num_workers=args.workers)
    val_loader = DataLoader(val_set, num_workers=args.workers)
    test_loader = DataLoader(test_set, num_workers=args.workers)

    if args.refine:
        refine_set = dict()
        for lang, lang_d in data_splits['train_100'].items():
            refine_set[lang] = Dataset({lang: lang_d},
                                       batchsize=args.valid_batchsize,
                                       bptt=args.bptt,
                                       make_config=True)

    n_token = len(dictionary.idx2tkn)

    # Load and preprocess matrix of typological features
    # TODO: implement this, the OEST
    # prior_matrix = load_prior(args.prior, corpus.dictionary.lang2idx)
    # n_components = min(50, *prior_matrix.shape)
    # pca = PCA(n_components=n_components, whiten=True)
    # prior_matrix = pca.fit_transform(prior_matrix)
    prior = None

    model = RNN(args.cond_type,
                prior,
                n_token,
                n_input=args.emsize,
                n_hidden=args.nhidden,
                n_layers=args.nlayers,
                dropout=args.dropouto,
                dropoute=args.dropoute,
                dropouth=args.dropouth,
                dropouti=args.dropouti,
                wdrop=args.wdrop,
                wdrop_layers=[0, 1, 2],
                tie_weights=True).to(device)

    if args.opt_level != 'O2':
        loss_function = SplitCrossEntropyLoss(args.emsize,
                                              splits=[]).to(device)
    else:
        loss_function = CrossEntropyLoss().to(
            device)  # Should be ok to use with a vocabulary of this small size

    if use_apex:
        optimizer = optimizers.FusedAdam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.wdecay)
    else:
        params = list(filter(lambda p: p.requires_grad,
                             model.parameters())) + list(
                                 loss_function.parameters())
        optimizer = Adam(params, lr=args.lr, weight_decay=args.wdecay)

    if use_apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)

    parameters = {
        'model': model,
        'optimizer': optimizer,
        'loss_function': loss_function,
        'use_apex': use_apex,
        'amp': amp if use_apex else None,
        'clip': args.clip,
        'alpha': args.alpha,
        'beta': args.beta,
        'bptt': args.bptt,
        'device': device,
        'prior': args.prior,
    }

    # Add backward hook for gradient clipping
    if args.clip:
        if use_apex:
            for p in amp.master_params(optimizer):
                p.register_hook(
                    lambda grad: torch.clamp(grad, -args.clip, args.clip))
        else:
            for p in model.parameters():
                p.register_hook(
                    lambda grad: torch.clamp(grad, -args.clip, args.clip))

    if args.prior == 'vi':
        prior = VIPrior(model, device=device)
        parameters['prior'] = prior

        def sample_weights(module: torch.nn.Module, input: torch.Tensor):
            prior.sample_weights(module)

        sample_weights_hook = model.register_forward_pre_hook(sample_weights)

    # Load model checkpoint if available
    start_epoch = 1
    if args.resume:
        if args.checkpoint is None:
            log.error(
                'No checkpoint passed. Specify it using the --checkpoint flag')
            checkpoint = None
        else:
            log.info('Loading the checkpoint at {}'.format(args.checkpoint))
            checkpoint = load_model(args.checkpoint, **parameters)

            start_epoch = checkpoint['epoch']

        if args.wdrop:
            for rnn in model.rnns:
                if isinstance(rnn, WeightDrop):
                    rnn.dropout = args.wdrop
                elif rnn.zoneout > 0:
                    rnn.zoneout = args.wdrop

    saved_models = list()

    result_str = '| Language {} | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'

    def test():
        log.info('=' * 89)
        log.info('Running test set (zero-shot results)...')
        test_loss, avg_loss = evaluate(test_loader, **parameters)
        log.info('Test set finished | test loss {} | test bpc {}'.format(
            test_loss, test_loss / math.log(2)))

        for lang, avg_l_loss in avg_loss.items():
            langstr = dictionary.idx2lang[lang]
            log.info(
                result_str.format(langstr, avg_l_loss, math.exp(avg_l_loss),
                                  avg_l_loss / math.log(2)))

        log.info('=' * 89)

    if args.train:
        f = 1.
        stored_loss = 1e32
        epochs_no_improve = 0

        val_losses = list()

        # calculate specific language lr
        data_spec_count = sum([len(ds) for l, ds in train_set.data.items()])
        data_spec_avg = data_spec_count / len(train_set.data.items())
        data_spec_lrweights = dict([(l, data_spec_avg / len(ds))
                                    for l, ds in train_set.data.items()])

        # estimate total number of steps
        total_steps = sum(
            [len(ds) // args.bptt
             for l, ds in train_set.data.items()]) * args.no_epochs
        steps = 0

        try:
            pbar = tqdm.trange(start_epoch,
                               args.no_epochs + 1,
                               position=1,
                               dynamic_ncols=True)
            for epoch in pbar:

                steps = train(train_loader,
                              lr_weights=data_spec_lrweights,
                              **parameters,
                              total_steps=total_steps,
                              steps=steps,
                              scaling=args.scaling,
                              n_samples=args.n_samples,
                              tb_writer=tb_writer)

                val_loss, _ = evaluate(val_loader, **parameters)
                pbar.set_description('Epoch {} | Val loss {}'.format(
                    epoch, val_loss))

                # Save model
                if args.prior == 'vi':
                    sample_weights_hook.remove()

                filename = path.join(
                    args.checkpoint_dir, '{}_epoch{}{}_{}.pth'.format(
                        timestamp, epoch, '_with_apex' if use_apex else '',
                        args.prior))
                torch.save(make_checkpoint(epoch + 1, **parameters), filename)
                saved_models.append(filename)

                if args.prior == 'vi':
                    sample_weights_hook = model.register_forward_pre_hook(
                        sample_weights)

                # Early stopping
                if val_loss < stored_loss:
                    epochs_no_improve = 0
                    stored_loss = val_loss
                else:
                    epochs_no_improve += 1

                if epochs_no_improve == args.patience:
                    log.info('Early stopping at epoch {}'.format(epoch))
                    break

                val_losses.append(val_loss)

                # Reduce lr every 1/3 total epochs
                if epoch - 1 > f / 3 * args.no_epochs:
                    log.info('Epoch {}/{}. Dividing LR by 10'.format(
                        epoch, args.no_epochs))
                    for g in optimizer.param_groups:
                        g['lr'] = g['lr'] / 10

                    f += 1.
            test()
        except KeyboardInterrupt:
            log.info('Registered KeyboardInterrupt. Stopping training.')
            log.info('Saving last model to disk')

            if args.prior == 'vi':
                sample_weights_hook.remove()

            torch.save(
                make_checkpoint(epoch, **parameters),
                path.join(
                    args.checkpoint_dir, '{}_epoch{}{}_{}.pth'.format(
                        timestamp, epoch, '_with_apex' if use_apex else '',
                        args.prior)))
            return
    elif args.test:
        test()

    # Only test on existing languages if there are no held out languages
    if not args.target_langs:
        exit(0)

    importance = 1e-5

    # If use UNIV, calculate informed prior, else use boring prior
    if args.prior == 'laplace':
        if not isinstance(
                prior,
                LaplacePrior):  # only calculate matrix if it is not supplied.
            log.info('Creating laplace approximation dataset')
            laplace_set = Dataset(data_splits['train'],
                                  batchsize=args.batchsize,
                                  bptt=100,
                                  reset_on_iter=True)
            laplace_loader = DataLoader(laplace_set, num_workers=args.workers)
            log.info('Creating Laplacian prior')
            prior = LaplacePrior(model,
                                 loss_function,
                                 laplace_loader,
                                 use_apex=use_apex,
                                 amp=amp,
                                 device=device)
            parameters['prior'] = prior

            torch.save(
                make_checkpoint('fisher_matrix', **parameters),
                path.join(
                    args.checkpoint_dir, '{}_fishers_matrix{}_{}.pth'.format(
                        timestamp, '_with_apex' if use_apex else '',
                        args.prior)))
        importance = 1e5

    elif args.prior == 'ninf':
        log.info('Creating non-informative Gaussian prior')
        parameters['prior'] = GaussianPrior()
    elif args.prior == 'vi':
        importance = 1e-5
    elif args.prior == 'hmc':
        raise NotImplementedError
    else:
        raise ValueError(
            f'Passed prior {args.prior} is not an implemented inference technique.'
        )

    best_model = saved_models[-1] if not len(
        saved_models) == 0 else args.checkpoint

    # Remove sampling hook from model
    if args.prior == 'vi':
        sample_weights_hook.remove()

    # Refine on 100 samples on each target
    if args.refine:
        # reset learning rate
        optimizer.param_groups[0]['lr'] = args.lr
        loss = 0

        results = dict()

        # Create individual tests sets
        test_sets = dict()
        for lang, lang_d in data_splits['test'].items():
            test_sets[lang] = DataLoader(Dataset({lang: lang_d},
                                                 make_config=True,
                                                 batchsize=args.test_batchsize,
                                                 bptt=args.bptt,
                                                 eval=True),
                                         num_workers=args.workers)

        for lang, lang_data in tqdm.tqdm(refine_set.items()):
            final_loss = False
            refine_dataloader = DataLoader(lang_data, num_workers=args.workers)
            load_model(best_model, **parameters)

            log.info(f'Refining for language {dictionary.idx2lang[lang]}')
            for epoch in range(1, args.refine_epochs + 1):
                refine(refine_dataloader, **parameters, importance=importance)
                if epoch % 5 == 0:
                    final_loss = True
                    loss, avg_loss = evaluate(test_sets[lang],
                                              model,
                                              loss_function,
                                              only_l=lang,
                                              report_all=True,
                                              device=device)

                    for lang, avg_l_loss in avg_loss.items():
                        langstr = dictionary.idx2lang[lang]
                        log.debug(
                            result_str.format(langstr, avg_l_loss,
                                              math.exp(avg_l_loss),
                                              avg_l_loss / math.log(2)))

            if not final_loss:
                loss, avg_loss = evaluate(test_sets[lang],
                                          model,
                                          loss_function,
                                          only_l=lang,
                                          report_all=True,
                                          device=device)

            for lang, avg_l_loss in avg_loss.items():
                langstr = dictionary.idx2lang[lang]
                log.info(
                    result_str.format(langstr, avg_l_loss,
                                      math.exp(avg_l_loss),
                                      avg_l_loss / math.log(2)))
                results[lang] = avg_l_loss

        log.info('=' * 89)
        log.info('FINAL FEW SHOT RESULTS: ')
        log.info('=' * 89)
        for lang, avg_l_loss in results.items():
            langstr = dictionary.idx2lang[lang]
            log.info(
                result_str.format(langstr, avg_l_loss, math.exp(avg_l_loss),
                                  avg_l_loss / math.log(2)))
        log.info('=' * 89)
Ejemplo n.º 10
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--model_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="")
    parser.add_argument("--my_config", default=None, type=str, required=True)
    parser.add_argument("--feature_path",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    ## Other parameters
    parser.add_argument("--train_pattern",
                        default=None,
                        type=str,
                        help="training data path.")
    parser.add_argument("--valid_pattern",
                        default=None,
                        type=str,
                        help="validation data path.")
    parser.add_argument("--test_pattern",
                        default=None,
                        type=str,
                        help="test data path.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--report_steps",
                        default=100,
                        type=int,
                        help="report steps when training.")
    parser.add_argument("--train_batch_size",
                        default=4,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--learning_rate2",
                        default=0,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_steps", default=100, type=int)
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--train_size',
                        type=int,
                        default=10000,
                        help="Use how many train data")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--frame_name',
                        type=str,
                        default='elgeish/cs224n-squad2.0-albert-large-v2')
    parser.add_argument('--DataName', type=str, default="SST")
    parser.add_argument("--adam_epsilon",
                        default=1e-6,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument(
        '--run_og',
        action='store_true',
        help="Run the whole model on a single device (with optional DDP) "
        "instead of splitting it across two GPUs.")
    args = parser.parse_args()
    #print(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        #        torch.cuda.set_device(args.local_rank)
        #        print(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = 1

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_pattern:
            raise ValueError(
                "If `do_train` is True, then `train_pattern` must be specified."
            )

    if args.do_predict:
        if not args.test_pattern:
            raise ValueError(
                "If `do_predict` is True, then `test_pattern` must be specified."
            )

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Prepare model
    my_config = Config(args.my_config)
    my_config.num_edge_types = sum(EdgePosition.max_edge_types)
    #    my_config.forward_edges = [EdgeType.TOKEN_TO_SENTENCE,
    #                               EdgeType.SENTENCE_TO_PARAGRAPH,
    #                               EdgeType.PARAGRAPH_TO_DOCUMENT]
    #print(my_config)
    if args.do_train:

        model = NqModel(my_config=my_config, args=args)
        #model_dict = model.state_dict()
        #pretrained_model_dict = torch.load(pretrained_model_file, map_location=lambda storage, loc: storage)
        #pretrained_model_dict = {k: v for k, v in pretrained_model_dict.items() if k in model_dict.keys()}
        #model_dict.update(pretrained_model_dict)
        #model.load_state_dict(model_dict)
#    else:
#        pretrained_config_file = os.path.join(args.model_dir, CONFIG_NAME)
##        bert_config = BertConfig(pretrained_config_file)
#        model = NqModel( my_config=my_config)
#        pretrained_model_file = os.path.join(args.model_dir, WEIGHTS_NAME)
#        model.load_state_dict(torch.load(pretrained_model_file))

# if args.fp16:
#     model.half()
    global run_og
    run_og = args.run_og
    if args.run_og:
        if n_gpu:
            model.cuda()
        if args.local_rank != -1:
            #            model = torch.nn.parallel.DistributedDataParallel(model,find_unused_parameters=True)
            model = torch.nn.parallel.DistributedDataParallel(model)

    else:
        model.bert.to("cuda:0")
        model.encoder.to("cuda:1")
        model.tok_outputs.to("cuda:0")
        model.tok_dense.to("cuda:0")
        model.dropout.to("cuda:0")

#    if args.local_rank != -1:
#        try:
#            from apex.parallel import DistributedDataParallel as DDP
#        except ImportError:
#            raise ImportError(
#                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
#
#        model = DDP(model)
#    elif n_gpu > 1:
#        model = torch.nn.DataParallel(model)

    num_train_features = None
    num_train_optimization_steps = None

    #train_dataset = None
    #train_features = None

    #    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
    #    model.load_state_dict(torch.load(output_model_file))

    if args.do_train:
        num_train_features = 0
        for data_path in glob(args.train_pattern):
            train_dataset = NqDataset(args, data_path, is_training=True)
            train_features = train_dataset.features
            num_train_features += len(train_dataset.features)
        print(num_train_features, args.train_batch_size,
              args.gradient_accumulation_steps)
        num_train_optimization_steps = int(
            (num_train_features / args.train_batch_size) /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to drop the unused pooler, whose None grads would break apex
    # (note: the filter below is currently a no-op and keeps all parameters)
    param_optimizer = [n for n in param_optimizer]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in param_optimizer
            if (not any(nd in n for nd in no_decay)) and 'bert' in n
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in param_optimizer
            if any(nd in n for nd in no_decay) and 'bert' in n
        ],
        'weight_decay':
        0.00
    }, {
        'params': [
            p for n, p in param_optimizer
            if (not any(nd in n for nd in no_decay)) and 'bert' not in n
        ],
        'weight_decay':
        args.weight_decay,
        'lr':
        args.learning_rate2 if args.learning_rate2 != 0 else args.learning_rate
    }, {
        'params': [
            p for n, p in param_optimizer
            if any(nd in n for nd in no_decay) and 'bert' not in n
        ],
        'weight_decay':
        0.00,
        'lr':
        args.learning_rate2 if args.learning_rate2 != 0 else args.learning_rate
    }]
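    # Four parameter groups: BERT weights with and without weight decay at the main
    # learning rate, and the remaining weights with and without weight decay at
    # --learning_rate2 (or the main rate when learning_rate2 is 0).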

    if args.fp16:
        # optimizer = apex_optim.FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        optimizer = apex_optim.FusedAdam(optimizer_grouped_parameters,
                                         lr=args.learning_rate)
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
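        # opt_level O2 casts the model to FP16 and keeps FP32 master weights in the
        # optimizer; the loss must then go through amp.scale_loss, as done in the
        # training loop below.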

    else:

        #        optimizer = BertAdam(optimizer_grouped_parameters,
        #                             lr=args.learning_rate,
        #                             warmup=args.warmup_proportion,
        #                             t_total=num_train_optimization_steps)

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)


#        optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate,momentum=0.9)

    if args.warmup_steps > 0:
        args.warmup_proportion = min(
            args.warmup_proportion,
            args.warmup_steps / num_train_optimization_steps)
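    # Effective warmup length: min(warmup_proportion * total steps, warmup_steps)
    # when both are set; otherwise whichever one is non-zero is used.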
    scheduler = WarmupLinearSchedule(
        optimizer,
        warmup_steps=int(args.warmup_proportion * num_train_optimization_steps)
        if args.warmup_proportion > 0 else args.warmup_steps,
        t_total=num_train_optimization_steps)
    # scheduler = WarmupConstantSchedule(optimizer,
    #                              warmup_steps=int(args.warmup_proportion * num_train_optimization_steps)
    #                              if args.warmup_proportion > 0 else args.warmup_steps)
    #print(get_lr(optimizer))
    # logger.info("Get lr:{} {}".format(get_lr(optimizer,0),get_lr(optimizer,-1)))
    # exit(0)

    global_step = 0
    last_acc = 87.0
    albert_toker = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num split examples = %d", num_train_features)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        tr_loss, report_loss = 0.0, 0.0
        nb_tr_examples = 0
        model.zero_grad()
        optimizer.zero_grad()
        ErrorSelect = open("./Err_for_5ABLS_" + time.ctime() + ".txt", 'w+')
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            logging.info("Loggin TEST!")
            for data_path in glob(args.train_pattern):
                #logging.info("Reading data from {}.".format(data_path))
                model.train()
                train_dataset = NqDataset(args, data_path, is_training=True)
                train_features = train_dataset.features
                #logging.info("Data Load Done!")
                if args.local_rank == -1:
                    train_sampler = RandomSampler(train_features)
                else:
                    train_sampler = DistributedSampler(train_features)
                train_dataloader = DataLoader(train_features,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              collate_fn=batcher(
                                                  device, is_training=True),
                                              num_workers=0,
                                              pin_memory=True)
                train_features = train_dataset.features
                logging.info("Data ready {} ".format(len(train_features)))

                for step, batch in enumerate(train_dataloader):
                    loss = model(batch.input_ids.cuda(non_blocking=True),
                                 batch.input_mask.cuda(non_blocking=True),
                                 batch.segment_ids.cuda(non_blocking=True),
                                 batch.st_mask.cuda(non_blocking=True),
                                 (batch.edges_src.cuda(non_blocking=True),
                                  batch.edges_tgt.cuda(non_blocking=True),
                                  batch.edges_type.cuda(non_blocking=True),
                                  batch.edges_pos.cuda(non_blocking=True)),
                                 batch.label.cuda(non_blocking=True),
                                 batch.all_sen.cuda(non_blocking=True))
                    # print("Out!")
                    # sleep(3)
                    # model.report_scores(batch.input_ids)

                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps
                    if args.local_rank != -1:
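                        # keep every parameter in the autograd graph via a
                        # zero-weighted term so DDP's gradient allreduce does not
                        # stall on unused parameters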
                        loss = loss + 0 * sum(
                            [x.sum() for x in model.parameters()])
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),5.0)

                    if (step + 1) % args.gradient_accumulation_steps == 0:

                        #                        torch.cuda.empty_cache()
                        # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),1.0)

                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        global_step += 1

                    tr_loss += loss.item()
                    nb_tr_examples += 1

                    if (step + 1) % args.gradient_accumulation_steps == 0 and (
                            global_step + 1) % args.report_steps == 0 and (
                                args.local_rank == -1
                                or torch.distributed.get_rank() == 0):
                        # lr_this_step = get_lr(optimizer)
                        logging.info(
                            "Epoch={} iter={} lr1={:.12f} lr2={:.12f} train_ave_loss={:.6f} ."
                            .format(
                                # _, global_step, lr_this_step, tr_loss / nb_tr_examples))
                                _,
                                global_step,
                                get_lr(optimizer, 0),
                                get_lr(optimizer, -1),
                                (tr_loss - report_loss) / args.report_steps))
                        report_loss = tr_loss

            model.eval()
            model.zero_grad()
            model.ACC = model.ALL = 0
            train_dataset = NqDataset(args, "test.json", is_training=True)
            train_features = train_dataset.features
            #logging.info("Data Load Done!")

            if args.local_rank == -1:
                train_sampler = RandomSampler(train_features)
            else:
                train_sampler = DistributedSampler(train_features)
            if args.local_rank == -1:
                train_dataloader = DataLoader(train_features,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              collate_fn=batcher(
                                                  device, is_training=True),
                                              num_workers=0,
                                              pin_memory=True)
            else:
                train_dataloader = DataLoader(train_features,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size,
                                              collate_fn=batcher(
                                                  device, is_training=True),
                                              num_workers=0,
                                              drop_last=True)

            train_features = train_dataset.features
            logging.info("Data ready {} ".format(len(train_features)))
            tglobal_step = 0
            ttr_loss = 0
            optimizer.zero_grad()
            logging.info("***** Running evalating *****")
            Err_cnt = 0
            Len_cnt = 0

            with torch.no_grad():
                for step, batch in enumerate(train_dataloader):
                    tglobal_step += 1
                    tmp_acc = model.ACC
                    loss = model(batch.input_ids.cuda(non_blocking=True),
                                 batch.input_mask.cuda(non_blocking=True),
                                 batch.segment_ids.cuda(non_blocking=True),
                                 batch.st_mask.cuda(non_blocking=True),
                                 (batch.edges_src.cuda(non_blocking=True),
                                  batch.edges_tgt.cuda(non_blocking=True),
                                  batch.edges_type.cuda(non_blocking=True),
                                  batch.edges_pos.cuda(non_blocking=True)),
                                 batch.label.cuda(non_blocking=True),
                                 batch.all_sen.cuda(non_blocking=True))
                    ttr_loss += loss.item()
                    if model.ACC == tmp_acc and _ != 0:
                        # WrOut = "Model Select:\n"
                        # for i in albert_toker.convert_ids_to_tokens(batch.input_ids[0][model.model_choice]):
                        #     if i !='<pad>':
                        #         WrOut+=str(i)
                        #     else: break
                        # Len_cnt+=len(WrOut)-14

                        # ErrorSelect.write(WrOut)
                        # WrOut = "\nTrue answer:\n"
                        # for i in albert_toker.convert_ids_to_tokens(batch.input_ids[0][model.ground_answer]):
                        #     if i !='<pad>':
                        #         WrOut+=str(i)
                        #     else: break
                        ErrorSelect.write(model.report_scores(batch.input_ids))
                        # ErrorSelect.write("\n")
                        Err_cnt += 1
            model.do_report = False
            logging.info("ACC:{}% LOSS:{}".format(model.ACC / model.ALL * 100,
                                                  ttr_loss / tgobal_step))
            if _ != 0:
                logging.info(
                    "Error count:{} Average Wrong QA lengths:{}".format(
                        Err_cnt, Len_cnt / Err_cnt))

            model.zero_grad()
            optimizer.zero_grad()
            model.encoder.TopNet[0].improveit()  # advance the K scheduler
            logging.info("Next K use:{}".format(model.encoder.TopNet[0].k))
Ejemplo n.º 11
0
def train(params, args, world_rank):
    logging.info('rank %d, begin data loader init' % world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank)
    test_data_loader = get_data_loader_distributed_test(params, world_rank)
    logging.info('rank %d, data loader initialized' % world_rank)
    model = UNet.UNet(params).cuda()
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
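    # FusedAdam is apex's drop-in replacement for torch.optim.Adam with a fused
    # CUDA implementation of the update step.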
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DistributedDataParallel(model)

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path,
                                map_location='cuda:{}'.format(args.local_rank))
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    device = torch.cuda.current_device()
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        tr_time = 0.
        log_time = 0.

        for i, data in enumerate(train_data_loader, 0):
            iters += 1
            adjust_LR(optimizer, params, iters)
            inp, tar = map(lambda x: x.to(device), data)
            tr_start = time.time()
            b_size = inp.size(0)

            model.zero_grad()
            gen = model(inp)
            loss = UNet.loss_func(gen, tar, params)

            loss.backward()  # full precision (no AMP)

            # automatic mixed precision:
            #with amp.scale_loss(loss, optimizer) as scaled_loss:
            #  scaled_loss.backward()

            optimizer.step()

            tr_end = time.time()
            tr_time += tr_end - tr_start

        # Output training stats
        if world_rank == 0:
            log_start = time.time()
            gens = []
            tars = []
            with torch.no_grad():
                for i, data in enumerate(test_data_loader, 0):
                    if i >= 50:
                        break
                    inp, tar = map(lambda x: x.to(device), data)
                    gen = model(inp)
                    gens.append(gen.detach().cpu().numpy())
                    tars.append(tar.detach().cpu().numpy())
            gens = np.concatenate(gens, axis=0)
            tars = np.concatenate(tars, axis=0)

            # Scalars
            args.tboard_writer.add_scalar('G_loss', loss.item(), iters)

            # Plots
            fig = plot_gens_tars(gens, tars)
            #fig, chi, L1score = meanL1(gens, tars)
            #args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
            #args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
            #args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
            #args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
            #args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
            #args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
            #args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
            #
            #fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
            for figiter in range(5):
                figtag = 'test' + str(figiter)
                args.tboard_writer.add_figure(tag=figtag,
                                              figure=fig[figiter],
                                              close=True)
            #log_end = time.time()
            #log_time += log_end - log_start

            # Save checkpoint
            torch.save(
                {
                    'iters': iters,
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info('train step time={}, logging time={}'.format(
                tr_time, log_time))
Ejemplo n.º 12
0
def train(params, args, world_rank, local_rank):

    #logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params,
                           num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples,
                                            auto_reset=True)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params).to(device)

    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model,
                    device_ids=[device.index],
                    output_device=device.index)

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # amp stuff
    if args.enable_amp:
        gscaler = amp.GradScaler()
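        # native PyTorch AMP: autocast runs the forward pass in mixed precision,
        # while GradScaler scales the loss before backward and unscales / adjusts
        # the scale factor around optimizer.step (see the branch below).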

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    with torch.autograd.profiler.emit_nvtx():
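        # emit_nvtx plus the explicit nvtx.range_push/range_pop calls below produce
        # named ranges in Nsight Systems / nvprof timelines for each step.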
        for epoch in range(startEpoch, startEpoch + params.num_epochs):

            if args.global_timing:
                dist.barrier()

            start = time.time()
            epoch_step = 0
            tr_time = 0.
            fw_time = 0.
            bw_time = 0.
            log_time = 0.

            model.train()
            for data in train_data_loader:
                torch.cuda.nvtx.range_push("cosmo3D:step {}".format(iters))
                tr_start = time.time()
                adjust_LR(optimizer, params, iters)

                # fetch data
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                if not args.io_only:
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:forward {}".format(iters))
                    # fw pass
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push(
                        "cosmo3D:backward {}".format(iters))
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                iters += 1
                epoch_step += 1

                # step done
                tr_end = time.time()
                tr_time += tr_end - tr_start
                torch.cuda.nvtx.range_pop()

            # epoch done
            if args.global_timing:
                dist.barrier()

            end = time.time()
            epoch_time = end - start
            step_time = epoch_time / float(epoch_step)
            tr_time /= float(epoch_step)
            fw_time /= float(epoch_step)
            bw_time /= float(epoch_step)
            io_time = max([step_time - fw_time - bw_time, 0])
            iters_per_sec = 1. / step_time
            fw_per_sec = 1. / tr_time

            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, epoch_time))
                logging.info(
                    'train step time = {} ({} steps), logging time = {}'.
                    format(tr_time, epoch_step, log_time))
                logging.info('train samples/sec = {} fw steps/sec = {}'.format(
                    iters_per_sec, fw_per_sec))
Ejemplo n.º 13
0
    def __init__(self,
                 config,
                 batch_slices,
                 seq_slices,
                 distributed_init_method,
                 world_size,
                 data_parallel_size,
                 model_parallel_size,
                 pipeline_parallel_size,
                 rank,
                 local_rank,
                 mixed_precision=False,
                 use_mpi=False,
                 init_process_group=False,
                 checkpoint_gradients=False):
        self.config = config
        self.batch_slices = batch_slices
        self.seq_slices = seq_slices
        torch.cuda.set_device(local_rank)
        if init_process_group:
            dist.init_process_group(
                backend='nccl',
                init_method=distributed_init_method,
                world_size=world_size,
                rank=rank,
            )
        dist.all_reduce(torch.zeros(1).cuda())
        mpu.initialize_model_parallel(model_parallel_size,
                                      pipeline_parallel_size)
        set_random_seed(0)
        mpu.model_parallel_cuda_manual_seed(0)
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.data_parallel_size = data_parallel_size
        self.model_parallel_size = model_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.pipeline_parallel_group_rank = mpu.get_pipeline_parallel_group_rank(
        )
        self.data_parallel_group = mpu.get_data_parallel_group()
        self.model_parallel_group = mpu.get_model_parallel_group()
        self.pipeline_parallel_pred_group = mpu.get_pipeline_parallel_pred_group(
        )
        self.pipeline_parallel_succ_group = mpu.get_pipeline_parallel_succ_group(
        )
        self.model_parallel_src_rank = mpu.get_model_parallel_src_rank()
        self.model_parallel_dst_rank = mpu.get_model_parallel_dst_rank()
        self.model_parallel_next_src_rank = (
            self.model_parallel_src_rank + self.model_parallel_size if
            self.pipeline_parallel_group_rank < self.pipeline_parallel_size - 1
            else None)
        self.model_parallel_prev_dst_rank = (
            self.model_parallel_dst_rank - self.model_parallel_size
            if self.pipeline_parallel_group_rank > 0 else None)

        self.n_layers = (config.n_layers // pipeline_parallel_size +
                         int(rank < config.n_layers % pipeline_parallel_size))
        self.config = config
        self.mixed_precision = mixed_precision
        self.checkpoint_gradients = checkpoint_gradients

        self.layers = []
        for _ in range(self.n_layers):
            l = ModelParallelTransformerLayer(
                self.config.embedding_dim,
                self.config.ffn_embedding_dim,
                self.config.num_attention_heads,
                device="cuda",
                checkpoint_gradients=self.checkpoint_gradients)
            self.layers.append(l.half() if self.mixed_precision else l)

        self.all_parameters = []
        for layer in self.layers:
            self.all_parameters.extend(layer.parameters())
        self.n_params = len(self.all_parameters)

        if self.mixed_precision:
            self.master_parameters = [
                p.clone().detach().float() for p in self.all_parameters
            ]
            for p in self.master_parameters:
                p.requires_grad_()
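            # manual mixed precision: the layers run in FP16 while the optimizer
            # updates FP32 master copies of the parameters (the master weights are
            # presumably synced back to the FP16 model elsewhere).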
            self.optimizer = optimizers.FusedAdam(self.master_parameters,
                                                  lr=1e-10)
        else:
            self.optimizer = torch.optim.Adam(self.all_parameters, lr=1e-10)
Ejemplo n.º 14
0
def main(argv):
    torch.manual_seed(FLAGS.seed)

    utils.init_logging(log_path=FLAGS.log_path)

    use_gpu = "cpu" not in FLAGS.base_device.lower()
    rank, world_size, gpu = dist.init_distributed_mode(backend=FLAGS.backend, use_gpu=use_gpu)
    device = FLAGS.base_device

    feature_spec = load_feature_spec(FLAGS)

    cat_feature_count = len(get_embedding_sizes(feature_spec, None))
    validate_flags(cat_feature_count)

    if is_main_process():
        dllogger.log(data=FLAGS.flag_values_dict(), step='PARAMETER')

    FLAGS.set_default("test_batch_size", FLAGS.test_batch_size // world_size * world_size)

    feature_spec = load_feature_spec(FLAGS)
    world_embedding_sizes = get_embedding_sizes(feature_spec, max_table_size=FLAGS.max_table_size)
    world_categorical_feature_sizes = np.asarray(world_embedding_sizes)
    device_mapping = get_device_mapping(world_embedding_sizes, num_gpus=world_size)

    batch_sizes_per_gpu = get_gpu_batch_sizes(FLAGS.batch_size, num_gpus=world_size)
    # prefix sums of per-GPU batch sizes: rank r handles click[batch_indices[r]:batch_indices[r + 1]]
    batch_indices = tuple(np.cumsum([0] + list(batch_sizes_per_gpu)))

    # Embedding sizes for each GPU
    categorical_feature_sizes = world_categorical_feature_sizes[device_mapping['embedding'][rank]].tolist()
    num_numerical_features = feature_spec.get_number_of_numerical_features()

    bottom_mlp_sizes = FLAGS.bottom_mlp_sizes if rank == device_mapping['bottom_mlp'] else None

    data_loader_train, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping,
                                                           feature_spec=feature_spec)

    model = DistributedDlrm(
        vectors_per_gpu=device_mapping['vectors_per_gpu'],
        embedding_device_mapping=device_mapping['embedding'],
        embedding_type=FLAGS.embedding_type,
        embedding_dim=FLAGS.embedding_dim,
        world_num_categorical_features=len(world_categorical_feature_sizes),
        categorical_feature_sizes=categorical_feature_sizes,
        num_numerical_features=num_numerical_features,
        hash_indices=FLAGS.hash_indices,
        bottom_mlp_sizes=bottom_mlp_sizes,
        top_mlp_sizes=FLAGS.top_mlp_sizes,
        interaction_op=FLAGS.interaction_op,
        fp16=FLAGS.amp,
        use_cpp_mlp=FLAGS.optimized_mlp,
        bottom_features_ordered=FLAGS.bottom_features_ordered,
        device=device
    )

    dist.setup_distributed_print(is_main_process())

    # DDP averages gradients via allreduce(mean), which does not apply to the
    # model-parallel bottom model; compensate by scaling its LR down by world_size.
    if FLAGS.Adam_embedding_optimizer:
        embedding_model_parallel_lr = FLAGS.lr
    else:
        embedding_model_parallel_lr = FLAGS.lr / world_size

    if FLAGS.Adam_MLP_optimizer:
        MLP_model_parallel_lr = FLAGS.lr
    else:
        MLP_model_parallel_lr = FLAGS.lr / world_size

    data_parallel_lr = FLAGS.lr

    if is_main_process():
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr},
            {'params': list(model.bottom_model.mlp.parameters()), 'lr': MLP_model_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr, MLP_model_parallel_lr]
    else:
        mlp_params = [
            {'params': list(model.top_model.parameters()), 'lr': data_parallel_lr}
        ]
        mlp_lrs = [data_parallel_lr]

    if FLAGS.Adam_MLP_optimizer:
        mlp_optimizer = apex_optim.FusedAdam(mlp_params)
    else:
        mlp_optimizer = apex_optim.FusedSGD(mlp_params)

    embedding_params = [{
        'params': list(model.bottom_model.embeddings.parameters()),
        'lr': embedding_model_parallel_lr
    }]
    embedding_lrs = [embedding_model_parallel_lr]

    if FLAGS.Adam_embedding_optimizer:
        embedding_optimizer = torch.optim.SparseAdam(embedding_params)
    else:
        embedding_optimizer = torch.optim.SGD(embedding_params)
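    # The dense MLPs use apex fused optimizers, while the embedding tables use plain
    # torch.optim.SparseAdam / SGD, which can handle the sparse gradients embeddings
    # may produce.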

    checkpoint_writer = make_distributed_checkpoint_writer(
        device_mapping=device_mapping,
        rank=rank,
        is_main_process=is_main_process(),
        config=FLAGS.flag_values_dict()
    )

    checkpoint_loader = make_distributed_checkpoint_loader(device_mapping=device_mapping, rank=rank)

    if FLAGS.load_checkpoint_path:
        checkpoint_loader.load_checkpoint(model, FLAGS.load_checkpoint_path)
        model.to(device)

    scaler = torch.cuda.amp.GradScaler(enabled=FLAGS.amp, growth_interval=int(1e9))
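    # growth_interval=int(1e9) effectively freezes the loss scale: it can still be
    # reduced on overflow, but essentially never grows back during a run.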

    def parallelize(model):
        if world_size <= 1:
            return model

        if use_gpu:
            model.top_model = parallel.DistributedDataParallel(model.top_model)
        else:  # Use other backend for CPU
            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
        return model

    if FLAGS.mode == 'test':
        model = parallelize(model)
        auc = dist_evaluate(model, data_loader_test)

        results = {'auc': auc}
        dllogger.log(data=results, step=tuple())

        if auc is not None:
            print(f"Finished testing. Test auc {auc:.4f}")
        return
    elif FLAGS.mode == 'inference_benchmark':
        if world_size > 1:
            raise ValueError('Inference benchmark only supports singleGPU mode.')

        results = {}

        if FLAGS.amp:
            # can use pure FP16 for inference
            model = model.half()

        for batch_size in FLAGS.inference_benchmark_batch_sizes:
            FLAGS.test_batch_size = batch_size
            _, data_loader_test = get_data_loaders(FLAGS, device_mapping=device_mapping, feature_spec=feature_spec)

            latencies = inference_benchmark(model=model, data_loader=data_loader_test,
                                            num_batches=FLAGS.inference_benchmark_steps,
                                            cuda_graphs=FLAGS.cuda_graphs)

            # drop the first 10 as a warmup
            latencies = latencies[10:]

            mean_latency = np.mean(latencies)
            mean_inference_throughput = batch_size / mean_latency
            subresult = {f'mean_inference_latency_batch_{batch_size}': mean_latency,
                         f'mean_inference_throughput_batch_{batch_size}': mean_inference_throughput}
            results.update(subresult)
        dllogger.log(data=results, step=tuple())
        return

    if FLAGS.save_checkpoint_path and not FLAGS.bottom_features_ordered and is_main_process():
        logging.warning("Saving checkpoint without --bottom_features_ordered flag will result in "
                        "a device-order dependent model. Consider using --bottom_features_ordered "
                        "if you plan to load the checkpoint in different device configurations.")

    loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")

    # Print per 16384 * 2000 samples by default
    default_print_freq = 16384 * 2000 // FLAGS.batch_size
    print_freq = default_print_freq if FLAGS.print_freq is None else FLAGS.print_freq

    # last one will be dropped in the training loop
    steps_per_epoch = len(data_loader_train) - 1
    test_freq = FLAGS.test_freq if FLAGS.test_freq is not None else steps_per_epoch - 2

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{avg:.8f}'))
    metric_logger.add_meter('step_time', utils.SmoothedValue(window_size=1, fmt='{avg:.6f}'))
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    # Accumulating loss on GPU to avoid memcpyD2H every step
    moving_loss = torch.zeros(1, device=device)

    lr_scheduler = utils.LearningRateScheduler(optimizers=[mlp_optimizer, embedding_optimizer],
                                               base_lrs=[mlp_lrs, embedding_lrs],
                                               warmup_steps=FLAGS.warmup_steps,
                                               warmup_factor=FLAGS.warmup_factor,
                                               decay_start_step=FLAGS.decay_start_step,
                                               decay_steps=FLAGS.decay_steps,
                                               decay_power=FLAGS.decay_power,
                                               end_lr_factor=FLAGS.decay_end_lr / FLAGS.lr)

    def zero_grad(model):
        if FLAGS.Adam_embedding_optimizer or FLAGS.Adam_MLP_optimizer:
            model.zero_grad()
        else:
            # No gradient accumulation is needed; setting grads to None is faster than optimizer.zero_grad()
            for param_group in itertools.chain(embedding_optimizer.param_groups, mlp_optimizer.param_groups):
                for param in param_group['params']:
                    param.grad = None

    def forward_backward(model, *args):

        numerical_features, categorical_features, click = args
        with torch.cuda.amp.autocast(enabled=FLAGS.amp):
            output = model(numerical_features, categorical_features, batch_sizes_per_gpu).squeeze()
            loss = loss_fn(output, click[batch_indices[rank]: batch_indices[rank + 1]])

        scaler.scale(loss).backward()

        return loss

    def weight_update():
        if not FLAGS.freeze_mlps:
            if FLAGS.Adam_MLP_optimizer:
                scale_MLP_gradients(mlp_optimizer, world_size)
            scaler.step(mlp_optimizer)

        if not FLAGS.freeze_embeddings:
            if FLAGS.Adam_embedding_optimizer:
                scale_embeddings_gradients(embedding_optimizer, world_size)
            scaler.unscale_(embedding_optimizer)
            embedding_optimizer.step()

        scaler.update()
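        # MLP gradients go through scaler.step (which unscales them and skips the
        # update on overflow); embedding gradients are unscaled explicitly and the
        # embedding optimizer is stepped directly.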

    trainer = CudaGraphWrapper(model, forward_backward, parallelize, zero_grad,
                               cuda_graphs=FLAGS.cuda_graphs)

    data_stream = torch.cuda.Stream()
    timer = utils.StepTimer()

    best_auc = 0
    best_epoch = 0
    start_time = time()

    for epoch in range(FLAGS.epochs):
        epoch_start_time = time()

        batch_iter = prefetcher(iter(data_loader_train), data_stream)

        for step in range(len(data_loader_train)):
            timer.click()

            numerical_features, categorical_features, click = next(batch_iter)
            torch.cuda.synchronize()

            global_step = steps_per_epoch * epoch + step

            if FLAGS.max_steps and global_step > FLAGS.max_steps:
                print(f"Reached max global steps of {FLAGS.max_steps}. Stopping.")
                break

            # One of the batches will be smaller because the dataset size
            # isn't necessarily a multiple of the batch size. #TODO isn't dropping here a change of behavior
            if click.shape[0] != FLAGS.batch_size:
                continue

            lr_scheduler.step()
            loss = trainer.train_step(numerical_features, categorical_features, click)

            # need to wait for the gradients before the weight update
            torch.cuda.current_stream().wait_stream(trainer.stream)
            weight_update()
            moving_loss += loss

            if timer.measured is None:
                # first iteration, no step time etc. to print
                continue

            if step == 0:
                print(f"Started epoch {epoch}...")
            elif step % print_freq == 0:
                # Averaging across a print_freq period to reduce the error.
                # An accurate timing needs synchronize which would slow things down.

                if global_step < FLAGS.benchmark_warmup_steps:
                    metric_logger.update(
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])
                else:
                    metric_logger.update(
                        step_time=timer.measured,
                        loss=moving_loss.item() / print_freq,
                        lr=mlp_optimizer.param_groups[0]["lr"])

                eta_str = datetime.timedelta(seconds=int(metric_logger.step_time.global_avg * (steps_per_epoch - step)))
                metric_logger.print(header=f"Epoch:[{epoch}/{FLAGS.epochs}] [{step}/{steps_per_epoch}]  eta: {eta_str}")

                moving_loss = 0.

            if global_step % test_freq == 0 and global_step > 0 and global_step / steps_per_epoch >= FLAGS.test_after:
                auc = dist_evaluate(trainer.model, data_loader_test)

                if auc is None:
                    continue

                print(f"Epoch {epoch} step {step}. auc {auc:.6f}")
                stop_time = time()

                if auc > best_auc:
                    best_auc = auc
                    best_epoch = epoch + ((step + 1) / steps_per_epoch)

                if FLAGS.auc_threshold and auc >= FLAGS.auc_threshold:
                    run_time_s = int(stop_time - start_time)
                    print(f"Hit target accuracy AUC {FLAGS.auc_threshold} at epoch "
                          f"{global_step / steps_per_epoch:.2f} in {run_time_s}s. ")
                    sys.exit()

        epoch_stop_time = time()
        epoch_time_s = epoch_stop_time - epoch_start_time
        print(f"Finished epoch {epoch} in {datetime.timedelta(seconds=int(epoch_time_s))}. ")

    avg_throughput = FLAGS.batch_size / metric_logger.step_time.avg

    if FLAGS.save_checkpoint_path:
        checkpoint_writer.save_checkpoint(model, FLAGS.save_checkpoint_path, epoch, step)

    results = {'best_auc': best_auc,
               'best_epoch': best_epoch,
               'average_train_throughput': avg_throughput}

    if is_main_process():
        dllogger.log(data=results, step=tuple())
Ejemplo n.º 15
0
if rank == 0:
    print('Creating model')
# Create model, freeze layers and change last layer
model, params = create_model(params)

model = apex.parallel.convert_syncbn_model(model)
model.cuda()
params_to_update = model.get_optim_policies()

if rank == 0:
    print('Creating optimizer')
# Create optimizer and learning rate schedules
if params['use_lr_scheduler'] == 1:
    params['max_lr'] = params['max_lr'] * 10
optimizer = apex_optim.FusedAdam(params_to_update,
                                 lr=params['max_lr'],
                                 weight_decay=params['weight_decay'])
model, optimizer = amp.initialize(model,
                                  optimizer,
                                  opt_level='O1',
                                  verbosity=0)
model = DDP(model, delay_allreduce=True)
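# apex's DistributedDataParallel with delay_allreduce=True postpones all gradient
# allreduces until the end of the backward pass instead of overlapping them with it.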

# Learning rate scheme
if params['use_lr_scheduler']:
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=params["n_epochs"] // 3,
                                    gamma=0.1)
else:
    scheduler = None
Ejemplo n.º 16
0
def train(params, args, world_rank, local_rank):

    #logging info
    logging.info('rank {:d}, begin data loader init (local rank {:d})'.format(
        world_rank, local_rank))

    # set device
    device = torch.device("cuda:{}".format(local_rank))

    # data loader
    pipe = dl.DaliPipeline(params,
                           num_threads=params.num_data_workers,
                           device_id=device.index)
    pipe.build()
    train_data_loader = DALIGenericIterator([pipe], ['inp', 'tar'],
                                            params.Nsamples,
                                            auto_reset=True)
    logging.info('rank %d, data loader initialized' % world_rank)

    model = UNet.UNet(params)
    model.to(device)
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DDP(model, device_ids=[local_rank])

    # loss
    criterion = UNet.CosmoLoss(params.LAMBDA_2)

    # amp stuff
    if args.enable_amp:
        gscaler = amp.GradScaler()

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        nsteps = 0
        fw_time = 0.
        bw_time = 0.
        log_time = 0.

        model.train()
        step_time = time.time()
        #for i, data in enumerate(train_data_loader, 0):
        with torch.autograd.profiler.emit_nvtx():
            for data in train_data_loader:
                iters += 1

                #adjust_LR(optimizer, params, iters)
                inp = data[0]["inp"]
                tar = data[0]["tar"]

                if not args.io_only:
                    torch.cuda.nvtx.range_push("cosmo3D:forward")
                    # fw pass
                    fw_time -= time.time()
                    optimizer.zero_grad()
                    with amp.autocast(args.enable_amp):
                        gen = model(inp)
                        loss = criterion(gen, tar)
                    fw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                    # bw pass
                    torch.cuda.nvtx.range_push("cosmo3D:backward")
                    bw_time -= time.time()
                    if args.enable_amp:
                        gscaler.scale(loss).backward()
                        gscaler.step(optimizer)
                        gscaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    bw_time += time.time()
                    torch.cuda.nvtx.range_pop()

                nsteps += 1

            # epoch done
            dist.barrier()
            step_time = (time.time() - step_time) / float(nsteps)
            fw_time /= float(nsteps)
            bw_time /= float(nsteps)
            io_time = max([step_time - fw_time - bw_time, 0])
            iters_per_sec = 1. / step_time

            end = time.time()
            if world_rank == 0:
                logging.info('Time taken for epoch {} is {} sec'.format(
                    epoch + 1, end - start))
                logging.info(
                    'total time / step = {}, fw time / step = {}, bw time / step = {}, exposed io time / step = {}, iters/s = {}, logging time = {}'
                    .format(step_time, fw_time, bw_time, io_time,
                            iters_per_sec, log_time))

        ## Output training stats
        #model.eval()
        #if world_rank==0:
        #  log_start = time.time()
        #  gens = []
        #  tars = []
        #  with torch.no_grad():
        #    for i, data in enumerate(train_data_loader, 0):
        #      if i>=16:
        #        break
        #      #inp, tar = map(lambda x: x.to(device), data)
        #      inp, tar = data
        #      gen = model(inp)
        #      gens.append(gen.detach().cpu().numpy())
        #      tars.append(tar.detach().cpu().numpy())
        #  gens = np.concatenate(gens, axis=0)
        #  tars = np.concatenate(tars, axis=0)
        #
        #  # Scalars
        #  args.tboard_writer.add_scalar('G_loss', loss.item(), iters)
        #
        #  # Plots
        #  fig, chi, L1score = meanL1(gens, tars)
        #  args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
        #  args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
        #  args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
        #  args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
        #  args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
        #  args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
        #  args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
        #
        #  fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
        #  args.tboard_writer.add_figure('genimg', fig, iters, close=True)
        #  log_end = time.time()
        #  log_time += log_end - log_start

        #  # Save checkpoint
        #  torch.save({'iters': iters, 'epoch':epoch, 'model_state': model.state_dict(),
        #              'optimizer_state_dict': optimizer.state_dict()}, params.checkpoint_path)

        #end = time.time()
        #if world_rank==0:
        #    logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, end-start))
        #    logging.info('total time / step = {}, fw time / step = {}, bw time / step = {}, exposed io time / step = {}, iters/s = {}, logging time = {}'
        #                 .format(step_time, fw_time, bw_time, io_time, iters_per_sec, log_time))

        # finalize
        dist.barrier()
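
Example #16 relies on torch.cuda.amp for mixed precision. A minimal, self-contained sketch of the same autocast + GradScaler pattern (the toy model, random data, and the enable_amp flag below are placeholders standing in for args.enable_amp):

import torch
from torch.cuda import amp

enable_amp = torch.cuda.is_available()               # stands in for args.enable_amp
device = torch.device("cuda" if enable_amp else "cpu")
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
scaler = amp.GradScaler(enabled=enable_amp)          # gscaler in the example above

inp = torch.randn(4, 16, device=device)
tar = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with amp.autocast(enabled=enable_amp):               # forward pass in mixed precision
    loss = criterion(model(inp), tar)
scaler.scale(loss).backward()                        # scaled backward, as in the AMP branch
scaler.step(optimizer)                               # unscales grads; skips the step on inf/NaN
scaler.update()                                      # adjusts the loss scale for the next step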
Example #17
0
def run():
    with open(args.cfg_path) as f:
        cfg = json.load(f)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_ids
    # num_GPU = len(args.device_ids.split(','))
    batch_size_train = cfg['train_batch_size']
    batch_size_valid = cfg['test_batch_size']
    num_workers = args.num_workers

    model = EfficientNet.from_pretrained(cfg['model'], num_classes=2)
    # model = densenet.densenet121(
    #     pretrained=False, num_classes=2, drop_rate=0.2)

    model = apex.parallel.convert_syncbn_model(model)

    # model = DataParallel(model, device_ids=None)
    model = model.to(device)
    loss_fn = nn.SmoothL1Loss().to(device)
    # loss_fn = nn.CrossEntropyLoss().to(device)
    # loss_fn = [nn.CrossEntropyLoss().to(device), nn.SmoothL1Loss().to(device)]
    if cfg['optimizer'] == 'SGD':
        optimizer = optim.SGD(model.parameters(),
                              lr=cfg['lr'],
                              momentum=cfg['momentum'],
                              weight_decay=1e-4)

    elif cfg['optimizer'] == 'Adam':
        optimizer = optimizers.FusedAdam(model.parameters(),
                                         lr=cfg['lr'],
                                         betas=(0.9, 0.999),
                                         weight_decay=1e-4)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level="O1",
    )

    if args.resume:
        model, epoch = load_checkpoint(args, model, optimizer, amp)
        if args.start_epoch < epoch:
            args.start_epoch = epoch

    if args.distributed:
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    dataset_valid = DegreesData(cfg['test_data_path'],
                                cfg['image_size'],
                                sample=False)

    eval_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_valid)

    dataloader_valid = DataLoader(dataset_valid,
                                  sampler=eval_sampler,
                                  batch_size=batch_size_valid,
                                  num_workers=num_workers,
                                  drop_last=True,
                                  shuffle=False)

    summary_train = {'epoch': 0, 'step': 0}
    summary_valid = {'loss': float('inf'), 'step': 0, 'acc': 0}
    summary_writer = None

    if args.local_rank == 0:
        summary_writer = SummaryWriter(log_path)

    loss_valid_best = float('inf')
    lr = cfg['lr']

    for epoch in range(args.start_epoch, args.end_epoch):
        lr = adjust_learning_rate(optimizer, epoch, cfg, args)

        dataset_train = DegreesData(cfg['train_data_path'],
                                    cfg['image_size'],
                                    istraining=True)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset_train)
        dataloader_train = DataLoader(dataset_train,
                                      sampler=train_sampler,
                                      batch_size=batch_size_train,
                                      num_workers=num_workers,
                                      drop_last=True,
                                      shuffle=(train_sampler is None))
        summary_train = train_epoch(epoch, summary_train, summary_writer,
                                    model, loss_fn, optimizer,
                                    dataloader_train, cfg)
        if args.local_rank == 0:
            if epoch % 10 == 0:
                torch.save(
                    {
                        'epoch': summary_train['epoch'],
                        'step': summary_train['step'],
                        'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'amp': amp.state_dict()
                    }, (ckpt_path_save + '/' + str(epoch) + '.ckpt'))

        for param_group in optimizer.param_groups:
            lr = param_group['lr']
            if args.local_rank == 0:
                print('Learning_rate:', lr)
            break
        # summary_writer.add_scalar(
        #   'ROC',summary_train['tp']*1.0 / summary_train['Pos'],summary_train['fp']*1.0 / summary_train['Neg'])
        if epoch % 1 == 0:
            summary_valid = valid_epoch(summary_valid, summary_writer, epoch,
                                        model, loss_fn, dataloader_valid, cfg)
            if args.local_rank == 0:
                summary_writer.add_scalar('valid/loss', summary_valid['loss'],
                                          epoch)
                summary_writer.add_scalar('valid/acc', summary_valid['acc'],
                                          epoch)
                summary_writer.add_scalar('valid/R2', summary_valid['r2'],
                                          epoch)

        if args.local_rank == 0:
            if summary_valid['loss'] < loss_valid_best:
                loss_valid_best = summary_valid['loss']
                torch.save(
                    {
                        'epoch': summary_train['epoch'],
                        'step': summary_train['step'],
                        'state_dict': model.module.state_dict()
                    }, os.path.join(ckpt_path_save, 'best.ckpt'))
            summary_writer.flush()
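
Example #17 stores amp.state_dict() alongside the model and optimizer in each checkpoint. A minimal sketch of saving and restoring that apex amp state (assumes NVIDIA apex and a CUDA device; the toy model, optimizer, and 'tmp.ckpt' path are placeholders):

import torch
from apex import amp  # requires NVIDIA apex

# Placeholder model/optimizer; in run() above these are the EfficientNet and its optimizer.
model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Save amp state together with the model and optimizer, as in the epoch checkpoints above.
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'amp': amp.state_dict()}, 'tmp.ckpt')

# Restore: amp.initialize must already have been called before amp.load_state_dict.
checkpoint = torch.load('tmp.ckpt', map_location='cuda')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])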
Example #18
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--model_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="")
    parser.add_argument("--my_config", default=None, type=str, required=True)
    parser.add_argument("--feature_path",
                        default=None,
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )

    ## Other parameters
    parser.add_argument("--train_pattern",
                        default=None,
                        type=str,
                        help="training data path.")
    parser.add_argument("--valid_pattern",
                        default=None,
                        type=str,
                        help="validation data path.")
    parser.add_argument("--test_pattern",
                        default=None,
                        type=str,
                        help="test data path.")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--report_steps",
                        default=100,
                        type=int,
                        help="report steps when training.")
    parser.add_argument("--train_batch_size",
                        default=4,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_steps", default=100, type=int)
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--train_size',
                        type=int,
                        default=10000,
                        help="Use how many train data")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--frame_name',
                        type=str,
                        default='elgeish/cs224n-squad2.0-albert-large-v2')
    parser.add_argument('--DataName', type=str, default="SST")
    parser.add_argument("--adam_epsilon",
                        default=1e-6,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument(
        '--run_og',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    args = parser.parse_args()
    #print(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        #        torch.cuda.set_device(args.local_rank)
        #        print(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = 1

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_pattern:
            raise ValueError(
                "If `do_train` is True, then `train_pattern` must be specified."
            )

    if args.do_predict:
        if not args.test_pattern:
            raise ValueError(
                "If `do_predict` is True, then `test_pattern` must be specified."
            )

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Prepare model
    my_config = Config(args.my_config)
    my_config.num_edge_types = sum(EdgePosition.max_edge_types)
    #    my_config.forward_edges = [EdgeType.TOKEN_TO_SENTENCE,
    #                               EdgeType.SENTENCE_TO_PARAGRAPH,
    #                               EdgeType.PARAGRAPH_TO_DOCUMENT]
    #print(my_config)
    if args.do_train:

        model = NqModel(my_config=my_config, args=args)
        #model_dict = model.state_dict()
        #pretrained_model_dict = torch.load(pretrained_model_file, map_location=lambda storage, loc: storage)
        #pretrained_model_dict = {k: v for k, v in pretrained_model_dict.items() if k in model_dict.keys()}
        #model_dict.update(pretrained_model_dict)
        #model.load_state_dict(model_dict)
#    else:
#        pretrained_config_file = os.path.join(args.model_dir, CONFIG_NAME)
##        bert_config = BertConfig(pretrained_config_file)
#        model = NqModel( my_config=my_config)
#        pretrained_model_file = os.path.join(args.model_dir, WEIGHTS_NAME)
#        model.load_state_dict(torch.load(pretrained_model_file))

    if args.fp16:
        model.half()
    global run_og
    run_og = args.run_og
    if args.run_og:
        if n_gpu:
            model.cuda()
        if args.local_rank != -1:
            #            model = torch.nn.parallel.DistributedDataParallel(model,find_unused_parameters=True)
            model = torch.nn.parallel.DistributedDataParallel(model)

    else:
        model.bert.to("cuda:0")
        model.encoder.to("cuda:1")
        model.tok_outputs.to("cuda:0")
        model.tok_dense.to("cuda:0")
        model.dropout.to("cuda:0")


#    if args.local_rank != -1:
#        try:
#            from apex.parallel import DistributedDataParallel as DDP
#        except ImportError:
#            raise ImportError(
#                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
#
#        model = DDP(model)
#    elif n_gpu > 1:
#        model = torch.nn.DataParallel(model)

    num_train_features = None
    num_train_optimization_steps = None

    #train_dataset = None
    #train_features = None

    #    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
    #    model.load_state_dict(torch.load(output_model_file))

    prefix = "cached_{0}_{1}_{2}_{3}".format(str(args.max_seq_length),
                                             str(args.doc_stride),
                                             str(args.max_query_length),
                                             "RaceA")
    prefix = os.path.join("features", prefix)
    cached_path = os.path.join(prefix, "train.pkl")
    with open(cached_path, "rb") as reader:
        RaceFeatures = pickle.load(reader)
    num_race = len(RaceFeatures)
    RaceFeatures = convert_features_to_tensors(RaceFeatures, "multi-choice")

    if args.do_train:
        num_train_features = 0
        for data_path in glob(args.train_pattern):
            train_dataset = NqDataset(args, data_path, is_training=True)
            train_features = train_dataset.features
            num_train_features += len(train_dataset.features)
        num_train_features += num_train_features

        print(num_train_features, args.train_batch_size,
              args.gradient_accumulation_steps)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    #    print(param_optimizer)
    #print([i for i,j in model.named_parameters()])

    # hack to remove pooler, which is not used
    # thus it produces None grads that break apex
    param_optimizer = [n for n in param_optimizer]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.00
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:

        #        optimizer = BertAdam(optimizer_grouped_parameters,
        #                             lr=args.learning_rate,
        #                             warmup=args.warmup_proportion,
        #                             t_total=num_train_optimization_steps)

        feature_cnt = num_train_features * 2 + num_race
        sampling_prob = [
            num_train_features * 2 / feature_cnt, num_race / feature_cnt
        ]
        feature_cnt -= num_train_features
        num_train_optimization_steps = int(
            (feature_cnt / args.train_batch_size) /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.warmup_steps > 0:
            args.warmup_proportion = min(
                args.warmup_proportion,
                args.warmup_steps / num_train_optimization_steps)
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )
        optimizer = apex_optim.FusedAdam(optimizer_grouped_parameters,
                                         lr=args.learning_rate,
                                         eps=args.adam_epsilon)
        # optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        #        optimizer = SGD(optimizer_grouped_parameters, lr=args.learning_rate,momentum=0.9)

        scheduler = WarmupLinearSchedule(
            optimizer,
            warmup_steps=int(args.warmup_proportion *
                             num_train_optimization_steps)
            if args.warmup_proportion > 0 else args.warmup_steps,
            t_total=num_train_optimization_steps)
        # scheduler = WarmupConstantSchedule(optimizer,
        #                              warmup_steps=int(args.warmup_proportion * num_train_optimization_steps)
        #                              if args.warmup_proportion > 0 else args.warmup_steps)

    global_step = 0
    last_acc = 89.5
    albert_toker = AlbertTokenizer.from_pretrained('albert-xxlarge-v2')
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
    # torch.autograd.set_detect_anomaly(True)

    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num split examples = %d", num_train_features)
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        tr_loss, report_loss = 0.0, 0.0
        nb_tr_examples = 0
        model.zero_grad()
        optimizer.zero_grad()
        # Err_test = False
        ErrorSelect = open("./Err.txt", 'w+')

        train_dataset = NqDataset(args, data_path, is_training=True)
        train_features = train_dataset.features
        #logging.info("Data Load Done!")
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_features)
        else:
            train_sampler = DistributedSampler(train_features)
        train_dataloader = DataLoader(train_features,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      collate_fn=batcher(device,
                                                         is_training=True),
                                      num_workers=0)

        train_dataloader = InfiniteDataLoader(train_dataloader)

        train_features = train_dataset.features
        logging.info("Dream_Data ready {} ".format(len(train_features)))
        RACE_train_dataloader = DataLoader(RaceFeatures,
                                           sampler=RandomSampler(RaceFeatures),
                                           batch_size=args.train_batch_size)
        logging.info("RACE_Data ready {} ".format(len(RaceFeatures)))
        RACE_train_dataloader = InfiniteDataLoader(RACE_train_dataloader)
        Epoch_cnt = 0
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            logging.info("Loggin TEST!")
            for data_path in glob(args.train_pattern):
                model.train()

                logging.info("Step Pre Epoch: {} ".format(
                    int(feature_cnt) // args.gradient_accumulation_steps))
                for step, _ in enumerate(range(int(feature_cnt))):
                    # if not Err_test:
                    #     WrOut = ""
                    #     for i in albert_toker.convert_ids_to_tokens(batch.input_ids[0][0]):
                    #         WrOut+=str(i)
                    #     ErrorSelect.write(WrOut)
                    #     Err_test = True

                    task_id = np.argmax(np.random.multinomial(
                        1, sampling_prob))
                    if task_id == 0:
                        batch = train_dataloader.get_next()
                        loss = model(batch.input_ids, batch.input_mask,
                                     batch.segment_ids, batch.st_mask,
                                     (batch.edges_src, batch.edges_tgt,
                                      batch.edges_type, batch.edges_pos),
                                     batch.label, batch.all_sen)
                    else:
                        batch = RACE_train_dataloader.get_next()
                        batch = tuple(t.to(device) for t in batch)
                        inputs = {
                            'input_idss': batch[0],
                            'attention_masks': batch[1],
                            'token_type_idss':
                            batch[2],  # XLM doesn't use segment_ids
                            'labels': batch[3],
                            'all_sens': batch[4]
                        }

                        loss = model(**inputs)

                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()

                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 50.0)

                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        #                        gc.collect() s
                        #                        torch.cuda.empty_cache()
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), 1.0)
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        global_step += 1

                    tr_loss += loss.item()
                    nb_tr_examples += 1

                    if (step + 1) % args.gradient_accumulation_steps == 0 and (
                            global_step + 1) % args.report_steps == 0 and (
                                args.local_rank == -1
                                or torch.distributed.get_rank() == 0):
                        lr_this_step = get_lr(optimizer)
                        logging.info(
                            "Epoch={} iter={} lr={:.12f} train_ave_loss={:.6f} ."
                            .format(
                                # _, global_step, lr_this_step, tr_loss / nb_tr_examples))
                                Epoch_cnt,
                                global_step,
                                lr_this_step,
                                (tr_loss - report_loss) / args.report_steps))
                        report_loss = tr_loss

            model.eval()
            model.zero_grad()
            model.ACC = model.ALL = 0
            test_dataset = NqDataset(args, "test.json", is_training=True)
            test_features = test_dataset.features
            #logging.info("Data Load Done!")

            if args.local_rank == -1:
                train_sampler = RandomSampler(test_features)
            else:
                train_sampler = DistributedSampler(test_features)
            if args.local_rank == -1:
                test_dataloader = DataLoader(test_features,
                                             sampler=train_sampler,
                                             batch_size=args.train_batch_size,
                                             collate_fn=batcher(
                                                 device, is_training=True),
                                             num_workers=0)
            else:
                test_dataloader = DataLoader(test_features,
                                             sampler=train_sampler,
                                             batch_size=args.train_batch_size,
                                             collate_fn=batcher(
                                                 device, is_training=True),
                                             num_workers=0,
                                             drop_last=True)

            test_features = test_dataset.features
            logging.info("Data ready {} ".format(len(test_features)))
            tgobal_step = 0
            ttr_loss = 0
            optimizer.zero_grad()
            logging.info("***** Running evalating *****")

            with torch.no_grad():
                for step, batch in enumerate(test_dataloader):
                    tgobal_step += 1
                    tmp_acc = model.ACC
                    loss = model(batch.input_ids, batch.input_mask,
                                 batch.segment_ids, batch.st_mask,
                                 (batch.edges_src, batch.edges_tgt,
                                  batch.edges_type, batch.edges_pos),
                                 batch.label, batch.all_sen)
                    if model.ACC == tmp_acc:
                        WrOut = ""
                        for i in albert_toker.convert_ids_to_tokens(
                                batch.input_ids[0][0]):
                            WrOut += str(i)
                        ErrorSelect.write(WrOut)
                    ttr_loss += loss.item()
            logging.info("ACC:{}% LOSS:{}".format(model.ACC / model.ALL * 100,
                                                  ttr_loss / tgobal_step))
            model.zero_grad()
            optimizer.zero_grad()

            if model.ACC / model.ALL * 100 > last_acc:
                logging.info("Save Model")
                last_acc = model.ACC / model.ALL * 100
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself

                # If we save using the predefined names, we can load using `from_pretrained`
                output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
                torch.save(model_to_save.state_dict(), output_model_file)
            Epoch_cnt += 1
        ErrorSelect.close()
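
Example #18 combines apex amp.scale_loss with gradient accumulation and gradient clipping on amp.master_params. A minimal sketch of that update pattern in isolation (assumes NVIDIA apex and a CUDA device; the toy model, random data, and accumulation window are placeholders, and the real loop above also steps its WarmupLinearSchedule after each optimizer step):

import torch
from apex import amp
from apex import optimizers as apex_optim  # requires NVIDIA apex

model = torch.nn.Linear(16, 1).cuda()                     # placeholder model
optimizer = apex_optim.FusedAdam(model.parameters(), lr=2e-5, eps=1e-6)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

gradient_accumulation_steps = 4
optimizer.zero_grad()
for step in range(16):
    inp = torch.randn(4, 16, device="cuda")
    tar = torch.randn(4, 1, device="cuda")
    loss = torch.nn.functional.mse_loss(model(inp).float(), tar)  # loss computed in fp32
    loss = loss / gradient_accumulation_steps             # average over the accumulation window
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()                            # gradients accumulate across micro-batches
    if (step + 1) % gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
        optimizer.step()
        optimizer.zero_grad()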