コード例 #1
0
ファイル: train.py プロジェクト: CookiePPP/codedump
def getOptimizer(model, hparams):
    """Build the optimizer selected by ``hparams`` for ``model``.

    Args:
        model: object exposing ``parameters()`` (the parameters to optimize).
        hparams: object with the attributes ``optimizer_fused`` (truthy to use
            the Apex fused implementations), ``optimizer`` (``"Adam"`` or
            ``"LAMB"``, case-sensitive) and ``learning_rate``.

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: if ``hparams.optimizer`` is not a supported name,
            or if ``"LAMB"`` is requested without the fused implementation.
        ImportError: if the fused path is requested and ``apex`` is missing.
    """
    if hparams.optimizer_fused:
        # Imported lazily so that apex is only required on the fused path.
        from apex import optimizers as apexopt
        if hparams.optimizer == "Adam":
            return apexopt.FusedAdam(model.parameters(),
                                     lr=hparams.learning_rate)
        if hparams.optimizer == "LAMB":
            return apexopt.FusedLAMB(model.parameters(),
                                     lr=hparams.learning_rate)
    else:
        if hparams.optimizer == "Adam":
            return torch.optim.Adam(model.parameters(),
                                    lr=hparams.learning_rate)
        if hparams.optimizer == "LAMB":
            raise NotImplementedError  # PyTorch doesn't currently include LAMB optimizer.

    # Previously an unknown name fell through to ``return optimizer`` and
    # failed with an UnboundLocalError; fail with a descriptive error instead.
    raise NotImplementedError(
        f"Unsupported optimizer {hparams.optimizer!r}; expected 'Adam' or 'LAMB'")
コード例 #2
0
ファイル: train.py プロジェクト: Marcus-Arcadius/cookietts
def get_optimizer(model, optimizer: str, fp16_run=True, optimizer_fused=True, learning_rate=1e-4, max_grad_norm=200):
    """Create the requested optimizer and optionally wrap model/optimizer with Apex AMP.

    Args:
        model: object exposing ``parameters()``.
        optimizer: optimizer name, matched case-insensitively against
            ``"adam"`` and ``"lamb"``.
        fp16_run: when true, ``apex.amp`` is imported and
            ``amp.initialize(..., opt_level='O1')`` is applied.
        optimizer_fused: when true, use the Apex fused optimizers; otherwise
            ``torch.optim.Adam`` or ``lamb.Lamb``.
        learning_rate: learning rate passed to the optimizer.
        max_grad_norm: only forwarded to ``apexopt.FusedLAMB``; the other
            optimizers ignore it.

    Returns:
        ``(model, optimizer)`` -- the AMP-initialized pair when ``fp16_run`` is
        true, otherwise the input model and the new optimizer.

    Raises:
        NotImplementedError: if ``optimizer`` is not a supported name.
        ImportError: if a requested optional package (apex, lamb) is missing.

    Side effects:
        Rebinds the module-level name ``amp`` to ``apex.amp`` when
        ``fp16_run`` is true, or to ``None`` otherwise.
    """
    global amp

    # Keep the name separate from the optimizer object. The original reused one
    # variable, so an unknown name silently leaked out as a plain string.
    name = optimizer.lower()
    if optimizer_fused:
        from apex import optimizers as apexopt
        if name == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif name == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=max_grad_norm)
        else:
            raise NotImplementedError(f"Unsupported optimizer {name!r}; expected 'adam' or 'lamb'")
    else:
        if name == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif name == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
        else:
            raise NotImplementedError(f"Unsupported optimizer {name!r}; expected 'adam' or 'lamb'")

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    return model, optimizer
コード例 #3
0
def main(pargs):
    """Run distributed DeepLabv3+ training with periodic validation and checkpointing.

    The function initializes the communicator and the ``mll.mlperf_logger``
    logger, optionally configures WandB on rank 0, builds the network,
    optimizer, optional Apex AMP wrapper and LR schedule, optionally restores
    a checkpoint, and then trains until either ``pargs.max_epochs`` epochs
    have run or the validation IoU reaches ``pargs.target_iou``.

    Args:
        pargs: argument namespace. Attributes read here include
            ``wireup_method``, ``output_dir``, ``run_tag``, ``data_dir_prefix``,
            ``channels``, ``optimizer``, ``start_lr``, ``adam_eps``,
            ``weight_decay``, ``amp_opt_level``, ``lr_schedule``,
            ``lr_warmup_steps``, ``lr_warmup_factor``, ``checkpoint``,
            ``local_batch_size``, ``max_inter_threads``,
            ``max_validation_steps``, ``logging_frequency``,
            ``validation_frequency``, ``save_frequency``, ``target_iou`` and
            ``max_epochs``. ``pargs.logging_frequency`` is clamped to >= 1
            in place.

    Raises:
        NotImplementedError: if ``pargs.optimizer`` is not "Adam", "AdamW" or
            (with apex available) "LAMB".
        Exception: if LR warmup is requested but the warmup scheduler package
            was not importable (``have_warmup_scheduler`` is false).

    Side effects:
        May set the module-level ``have_wandb`` flag to False; creates output
        directories and writes checkpoint files on rank 0.
    """

    # this should be global
    global have_wandb

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set up logging
    # clamp so that the `step % logging_frequency` check below never divides by zero
    pargs.logging_frequency = max([pargs.logging_frequency, 1])
    log_file = os.path.normpath(
        os.path.join(pargs.output_dir, "logs", pargs.run_tag + ".log"))
    logger = mll.mlperf_logger(log_file, "deepcam", "Umbrella Corp.")
    logger.log_start(key="init_start", sync=True)
    logger.log_event(key="cache_clear")

    #set seed
    seed = 333
    logger.log_event(key="seed", value=seed)

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)

        # TEST: allowed? Valuable?
        #torch.backends.cudnn.benchark = True
    else:
        device = torch.device("cpu")

    #visualize?
    # NOTE: only used to decide whether the plot directory is created; the
    # visualization code further down is commented out.
    visualize = (pargs.training_visualization_frequency >
                 0) or (pargs.validation_visualization_frequency > 0)

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if not pargs.enable_wandb:
        have_wandb = False
    if have_wandb and (comm_rank == 0):
        # get wandb api token
        # The first line of the file is split on whitespace into "<login> <token>".
        certfile = os.path.join(pargs.wandb_certdir, ".wandbirc")
        try:
            with open(certfile) as f:
                token = f.readlines()[0].replace("\n", "").split()
                wblogin = token[0]
                wbtoken = token[1]
        # NOTE(review): only IOError is handled; an empty file or a first line
        # with fewer than two fields raises an uncaught IndexError.
        except IOError:
            print("Error, cannot open WandB certificate {}.".format(certfile))
            have_wandb = False

        if have_wandb:
            # log in: that call can be blocking, it should be quick
            sp.call(["wandb", "login", wbtoken])

            #init db and get config
            resume_flag = pargs.run_tag if pargs.resume_logging else False
            wandb.init(entity=wblogin,
                       project='deepcam',
                       name=pargs.run_tag,
                       id=pargs.run_tag,
                       resume=resume_flag)
            config = wandb.config

            #set general parameters
            config.root_dir = root_dir
            config.output_dir = pargs.output_dir
            config.max_epochs = pargs.max_epochs
            config.local_batch_size = pargs.local_batch_size
            config.num_workers = comm_size
            config.channels = pargs.channels
            config.optimizer = pargs.optimizer
            config.start_lr = pargs.start_lr
            config.adam_eps = pargs.adam_eps
            config.weight_decay = pargs.weight_decay
            config.model_prefix = pargs.model_prefix
            config.amp_opt_level = pargs.amp_opt_level
            config.loss_weight_pow = pargs.loss_weight_pow
            config.lr_warmup_steps = pargs.lr_warmup_steps
            config.lr_warmup_factor = pargs.lr_warmup_factor

            # lr schedule if applicable
            if pargs.lr_schedule:
                for key in pargs.lr_schedule:
                    config.update(
                        {"lr_schedule_" + key: pargs.lr_schedule[key]},
                        allow_val_change=True)

    # Logging hyperparameters
    logger.log_event(key="global_batch_size",
                     value=(pargs.local_batch_size * comm_size))
    logger.log_event(key="opt_name", value=pargs.optimizer)
    logger.log_event(key="opt_base_learning_rate",
                     value=pargs.start_lr * pargs.lr_warmup_factor)
    logger.log_event(key="opt_learning_rate_warmup_steps",
                     value=pargs.lr_warmup_steps)
    logger.log_event(key="opt_learning_rate_warmup_factor",
                     value=pargs.lr_warmup_factor)
    logger.log_event(key="opt_epsilon", value=pargs.adam_eps)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    # presumably per-class frequencies raised to loss_pow -- TODO confirm origin
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        # also reached for "LAMB" when apex is not available
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        # loaded into the DDP-wrapped module, matching how checkpoints are
        # saved below (net.state_dict() of the wrapped module)
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler
    # NOTE: `scheduler` is only bound when pargs.lr_schedule is truthy; every
    # later use is guarded by the same condition.
    if pargs.lr_schedule:
        # uses the locally loaded start_step (before the broadcast below)
        scheduler_after = ph.get_lr_schedule(pargs.start_lr,
                                             pargs.lr_schedule,
                                             optimizer,
                                             last_step=start_step)

        # LR warmup
        if pargs.lr_warmup_steps > 0:
            if have_warmup_scheduler:
                scheduler = GradualWarmupScheduler(
                    optimizer,
                    multiplier=pargs.lr_warmup_factor,
                    total_epoch=pargs.lr_warmup_steps,
                    after_scheduler=scheduler_after)
            # Throw an error if the package is not found
            else:
                raise Exception(
                    f'Requested {pargs.lr_warmup_steps} LR warmup steps '
                    'but warmup scheduler not found. Install it from '
                    'https://github.com/ildoonet/pytorch-gradual-warmup-lr')
        else:
            scheduler = scheduler_after

    # broadcast the step/epoch counters from rank 0 so all ranks agree
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)

    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    #unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               allow_uneven_distribution=False,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True,
        drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(
                                        root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    allow_uneven_distribution=True,
                                    shuffle=(pargs.max_validation_steps
                                             is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    # use batch size = 1 here to make sure that we do not drop a sample
    validation_loader = DataLoader(
        validation_set,
        1,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True,
        drop_last=True)

    # log size of datasets
    logger.log_event(key="train_samples", value=train_set.global_size)
    if pargs.max_validation_steps is not None:
        # NOTE(review): this uses local_batch_size although the validation
        # loader above runs with batch size 1 -- confirm the intended count.
        val_size = min([
            validation_set.global_size,
            pargs.max_validation_steps * pargs.local_batch_size * comm_size
        ])
    else:
        val_size = validation_set.global_size
    logger.log_event(key="eval_samples", value=val_size)

    # do sanity check
    # truncating validation is flagged in the log as an invalid submission
    if pargs.max_validation_steps is not None:
        logger.log_event(key="invalid_submission")

    #for visualization
    #if visualize:
    #    viz = vizc.CamVisualizer()

    # Train network
    if have_wandb and (comm_rank == 0):
        wandb.watch(net)

    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr(
    )[0]
    stop_training = False
    net.train()

    # start training
    logger.log_end(key="init_stop", sync=True)
    logger.log_start(key="run_start", sync=True)

    # training loop
    while True:

        # start epoch
        logger.log_start(key="epoch_start",
                         metadata={
                             'epoch_num': epoch + 1,
                             'step_num': step
                         },
                         sync=True)

        # epoch loop
        for inputs, label, filename in train_loader:

            # send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute the local loss (cross-rank averaging happens only when logging)
            loss = criterion(outputs,
                             label,
                             weight=class_weights,
                             fpw_1=fpw_1,
                             fpw_2=fpw_2)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # step counter
            step += 1

            if pargs.lr_schedule:
                # record the LR before advancing the schedule
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #visualize if requested
            #if (step % pargs.training_visualization_frequency == 0) and (comm_rank == 0):
            #    # Compute predictions
            #    predictions = torch.max(outputs, 1)[1]
            #
            #    # extract sample id and data tensors
            #    sample_idx = np.random.randint(low=0, high=label.shape[0])
            #    plot_input = inputs.detach()[sample_idx, 0,...].cpu().numpy()
            #    plot_prediction = predictions.detach()[sample_idx,...].cpu().numpy()
            #    plot_label = label.detach()[sample_idx,...].cpu().numpy()
            #
            #    # create filenames
            #    outputfile = os.path.basename(filename[sample_idx]).replace("data-", "training-").replace(".h5", ".png")
            #    outputfile = os.path.join(plot_dir, outputfile)
            #
            #    # plot
            #    viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
            #
            #    #log if requested
            #    if have_wandb:
            #        img = Image.open(outputfile)
            #        wandb.log({"train_examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step = step)

            #log if requested
            if (step % pargs.logging_frequency == 0):

                # sum the loss onto rank 0 (dist.reduce, not all_reduce), then
                # divide by the number of ranks
                loss_avg = loss.detach()
                dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_train = loss_avg.item() / float(comm_size)

                # Compute score
                predictions = torch.max(outputs, 1)[1]
                iou = utils.compute_score(predictions,
                                          label,
                                          device_id=device,
                                          num_classes=3)
                iou_avg = iou.detach()
                dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)
                iou_avg_train = iou_avg.item() / float(comm_size)

                logger.log_event(key="learning_rate",
                                 value=current_lr,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="train_accuracy",
                                 value=iou_avg_train,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="train_loss",
                                 value=loss_avg_train,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })

                if have_wandb and (comm_rank == 0):
                    wandb.log(
                        {"train_loss": loss_avg.item() / float(comm_size)},
                        step=step)
                    wandb.log(
                        {"train_accuracy": iou_avg.item() / float(comm_size)},
                        step=step)
                    wandb.log({"learning_rate": current_lr}, step=step)
                    wandb.log({"epoch": epoch + 1}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                logger.log_start(key="eval_start",
                                 metadata={'epoch_num': epoch + 1})

                #eval
                net.eval()

                # per-rank accumulators; summed across ranks after the loop
                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute the local loss and accumulate it
                        loss_val = criterion(outputs_val,
                                             label_val,
                                             weight=class_weights,
                                             fpw_1=fpw_1,
                                             fpw_2=fpw_2)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val,
                                                      label_val,
                                                      device_id=device,
                                                      num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        #if (step_val % pargs.validation_visualization_frequency == 0) and (not visualized) and (comm_rank == 0):
                        #    #extract sample id and data tensors
                        #    sample_idx = np.random.randint(low=0, high=label_val.shape[0])
                        #    plot_input = inputs_val.detach()[sample_idx, 0,...].cpu().numpy()
                        #    plot_prediction = predictions_val.detach()[sample_idx,...].cpu().numpy()
                        #    plot_label = label_val.detach()[sample_idx,...].cpu().numpy()
                        #
                        #    #create filenames
                        #    outputfile = os.path.basename(filename[sample_idx]).replace("data-", "validation-").replace(".h5", ".png")
                        #    outputfile = os.path.join(plot_dir, outputfile)
                        #
                        #    #plot
                        #    viz.plot(filename[sample_idx], outputfile, plot_input, plot_prediction, plot_label)
                        #    visualized = True
                        #
                        #    #log if requested
                        #    if have_wandb:
                        #        img = Image.open(outputfile)
                        #        wandb.log({"eval_examples": [wandb.Image(img, caption="Prediction vs. Ground Truth")]}, step = step)

                        #increase eval step counter
                        step_val += 1

                        # NOTE(review): step_val is incremented before this `>`
                        # comparison, so up to max_validation_steps + 1 batches
                        # are evaluated -- confirm whether `>=` was intended.
                        if (pargs.max_validation_steps is not None
                            ) and step_val > pargs.max_validation_steps:
                            break

                # average the validation loss
                # all_reduce gives every rank the same averages, so the
                # stop_training decision below is consistent across ranks
                dist.all_reduce(count_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(loss_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(iou_sum_val, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                logger.log_event(key="eval_accuracy",
                                 value=iou_avg_val,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="eval_loss",
                                 value=loss_avg_val,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })

                # log in wandb
                if have_wandb and (comm_rank == 0):
                    wandb.log({"eval_loss": loss_avg_val}, step=step)
                    wandb.log({"eval_accuracy": iou_avg_val}, step=step)

                # target reached: finish this iteration (including a possible
                # checkpoint save), then leave both loops
                if (iou_avg_val >= pargs.target_iou):
                    logger.log_event(key="target_accuracy_reached",
                                     value=pargs.target_iou,
                                     metadata={
                                         'epoch_num': epoch + 1,
                                         'step_num': step
                                     })
                    stop_training = True

                # set to train
                net.train()

                logger.log_end(key="eval_stop",
                               metadata={'epoch_num': epoch + 1})

            #save model if desired
            if (pargs.save_frequency > 0) and (step % pargs.save_frequency
                                               == 0):
                logger.log_start(key="save_start",
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 },
                                 sync=True)
                # only rank 0 writes the checkpoint file
                if comm_rank == 0:
                    checkpoint = {
                        'step': step,
                        'epoch': epoch,
                        'model': net.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    if have_apex:
                        checkpoint['amp'] = amp.state_dict()
                    torch.save(
                        checkpoint,
                        os.path.join(
                            output_dir, pargs.model_prefix + "_step_" +
                            str(step) + ".cpt"))
                logger.log_end(key="save_stop",
                               metadata={
                                   'epoch_num': epoch + 1,
                                   'step_num': step
                               },
                               sync=True)

            # Stop training?
            if stop_training:
                break

        # log the epoch
        logger.log_end(key="epoch_stop",
                       metadata={
                           'epoch_num': epoch + 1,
                           'step_num': step
                       },
                       sync=True)
        epoch += 1

        # are we done?
        if epoch >= pargs.max_epochs or stop_training:
            break

    # run done
    logger.log_end(key="run_stop", sync=True, metadata={'status': 'success'})
コード例 #4
0
ファイル: train.py プロジェクト: Marcus-Arcadius/cookietts
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          sigma, loss_empthasis, iters_per_checkpoint, batch_size, seed, fp16_run,
          checkpoint_path, with_tensorboard, logdirname, datedlogdir, warm_start=False, optimizer='ADAM', start_zero=False):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======
    
    global WaveGlow
    global WaveGlowLoss
    
    ax = True # this is **really** bad coding practice :D
    if ax:
        from efficient_model_ax import WaveGlow
        from efficient_loss import WaveGlowLoss
    else:
        if waveglow_config["yoyo"]: # efficient_mode # TODO: Add to Config File
            from efficient_model import WaveGlow
            from efficient_loss import WaveGlowLoss
        else:
            from glow import WaveGlow, WaveGlowLoss
    
    criterion = WaveGlowLoss(sigma, loss_empthasis)
    model = WaveGlow(**waveglow_config).cuda()
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======
    STFTs = [STFT.TacotronSTFT(filter_length=window,
                                 hop_length=data_config['hop_length'],
                                 win_length=window,
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=160,
                                 mel_fmin=0, mel_fmax=16000) for window in data_config['validation_windows']]
    
    loader_STFT = STFT.TacotronSTFT(filter_length=data_config['filter_length'],
                                 hop_length=data_config['hop_length'],
                                 win_length=data_config['win_length'],
                                 sampling_rate=data_config['sampling_rate'],
                                 n_mel_channels=data_config['n_mel_channels'] if 'n_mel_channels' in data_config.keys() else 160,
                                 mel_fmin=data_config['mel_fmin'], mel_fmax=data_config['mel_fmax'])
    
    #optimizer = "Adam"
    optimizer = optimizer.lower()
    optimizer_fused = bool( 0 ) # use Apex fused optimizer, should be identical to normal but slightly faster and only works on RTX cards
    if optimizer_fused:
        from apex import optimizers as apexopt
        if optimizer == "adam":
            optimizer = apexopt.FusedAdam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            optimizer = apexopt.FusedLAMB(model.parameters(), lr=learning_rate, max_grad_norm=200)
    else:
        if optimizer == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        elif optimizer == "lamb":
            from lamb import Lamb as optLAMB
            optimizer = optLAMB(model.parameters(), lr=learning_rate)
            #import torch_optimizer as optim
            #optimizer = optim.Lamb(model.parameters(), lr=learning_rate)
            #raise# PyTorch doesn't currently include LAMB optimizer.
    
    if fp16_run:
        global amp
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    else:
        amp = None
    
    ## LEARNING RATE SCHEDULER
    if True:
        from torch.optim.lr_scheduler import ReduceLROnPlateau
        min_lr = 1e-8
        factor = 0.1**(1/5) # amount to scale the LR by on Validation Loss plateau
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=factor, patience=20, cooldown=2, min_lr=min_lr, verbose=True, threshold=0.0001, threshold_mode='abs')
        print("ReduceLROnPlateau used as Learning Rate Scheduler.")
    else: scheduler=False
    
    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model,
                                                      optimizer, scheduler, fp16_run, warm_start=warm_start)
        iteration += 1  # next iteration is iteration + 1
    if start_zero:
        iteration = 0
    
    trainset = Mel2Samp(**data_config, check_files=True)
    speaker_lookup = trainset.speaker_ids
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        train_sampler = DistributedSampler(trainset, shuffle=True)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset, num_workers=3, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        if datedlogdir:
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            log_directory = os.path.join(output_directory, logdirname, timestr)
        else:
            log_directory = os.path.join(output_directory, logdirname)
        logger = SummaryWriter(log_directory)
    
    moving_average = int(min(len(train_loader), 100)) # average loss over entire Epoch
    rolling_sum = StreamingMovingAverage(moving_average)
    start_time = time.time()
    start_time_iter = time.time()
    start_time_dekaiter = time.time()
    model.train()
    
    # best (averaged) training loss
    if os.path.exists(os.path.join(output_directory, "best_model")+".txt"):
        best_model_loss = float(str(open(os.path.join(output_directory, "best_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_model_loss = -6.20
    
    # best (validation) MSE on inferred spectrogram.
    if os.path.exists(os.path.join(output_directory, "best_val_model")+".txt"):
        best_MSE = float(str(open(os.path.join(output_directory, "best_val_model")+".txt", "r", encoding="utf-8").read()).split("\n")[0])
    else:
        best_MSE = 9e9
    
    epoch_offset = max(0, int(iteration / len(train_loader)))
    
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("{:,} total parameters in model".format(pytorch_total_params))
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("{:,} trainable parameters.".format(pytorch_total_params))
    
    print(f"Segment Length: {data_config['segment_length']:,}\nBatch Size: {batch_size:,}\nNumber of GPUs: {num_gpus:,}\nSamples/Iter: {data_config['segment_length']*batch_size*num_gpus:,}")
    
    training = True
    while training:
        try:
            if rank == 0:
                epochs_iterator = tqdm(range(epoch_offset, epochs), initial=epoch_offset, total=epochs, smoothing=0.01, desc="Epoch", position=1, unit="epoch")
            else:
                epochs_iterator = range(epoch_offset, epochs)
            # ================ MAIN TRAINING LOOP! ===================
            for epoch in epochs_iterator:
                print(f"Epoch: {epoch}")
                if num_gpus > 1:
                    train_sampler.set_epoch(epoch)
                
                if rank == 0:
                    iters_iterator = tqdm(enumerate(train_loader), desc=" Iter", smoothing=0, total=len(train_loader), position=0, unit="iter", leave=True)
                else:
                    iters_iterator = enumerate(train_loader)
                for i, batch in iters_iterator:
                    # run external code every iter, allows the run to be adjusted without restarts
                    if (i==0 or iteration % param_interval == 0):
                        try:
                            with open("run_every_epoch.py") as f:
                                internal_text = str(f.read())
                                if len(internal_text) > 0:
                                    #code = compile(internal_text, "run_every_epoch.py", 'exec')
                                    ldict = {'iteration': iteration, 'seconds_elapsed': time.time()-start_time}
                                    exec(internal_text, globals(), ldict)
                                else:
                                    print("No Custom code found, continuing without changes.")
                        except Exception as ex:
                            print(f"Custom code FAILED to run!\n{ex}")
                        globals().update(ldict)
                        locals().update(ldict)
                        if show_live_params:
                            print(internal_text)
                    if not iteration % 50: # check actual learning rate every 20 iters (because I sometimes see learning_rate variable go out-of-sync with real LR)
                        learning_rate = optimizer.param_groups[0]['lr']
                    # Learning Rate Schedule
                    if custom_lr:
                        old_lr = learning_rate
                        if iteration < warmup_start:
                            learning_rate = warmup_start_lr
                        elif iteration < warmup_end:
                            learning_rate = (iteration-warmup_start)*((A_+C_)-warmup_start_lr)/(warmup_end-warmup_start) + warmup_start_lr # learning rate increases from warmup_start_lr to A_ linearly over (warmup_end-warmup_start) iterations.
                        else:
                            if iteration < decay_start:
                                learning_rate = A_ + C_
                            else:
                                iteration_adjusted = iteration - decay_start
                                learning_rate = (A_*(e**(-iteration_adjusted/B_))) + C_
                        assert learning_rate > -1e-8, "Negative Learning Rate."
                        if old_lr != learning_rate:
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = learning_rate
                    else:
                        scheduler.patience = scheduler_patience
                        scheduler.cooldown = scheduler_cooldown
                        if override_scheduler_last_lr:
                            scheduler._last_lr = override_scheduler_last_lr
                        if override_scheduler_best:
                            scheduler.best = override_scheduler_best
                        if override_scheduler_last_lr or override_scheduler_best:
                            print("scheduler._last_lr =", scheduler._last_lr, "scheduler.best =", scheduler.best, "  |", end='')
                    model.zero_grad()
                    mel, audio, speaker_ids = batch
                    mel = torch.autograd.Variable(mel.cuda(non_blocking=True))
                    audio = torch.autograd.Variable(audio.cuda(non_blocking=True))
                    speaker_ids = speaker_ids.cuda(non_blocking=True).long().squeeze(1)
                    outputs = model(mel, audio, speaker_ids)
                    
                    loss = criterion(outputs)
                    if num_gpus > 1:
                        reduced_loss = reduce_tensor(loss.data, num_gpus).item()
                    else:
                        reduced_loss = loss.item()
                    
                    if fp16_run:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    
                    if (reduced_loss > LossExplosionThreshold) or (math.isnan(reduced_loss)):
                        model.zero_grad()
                        raise LossExplosion(f"\nLOSS EXPLOSION EXCEPTION ON RANK {rank}: Loss reached {reduced_loss} during iteration {iteration}.\n\n\n")
                    
                    if use_grad_clip:
                        if fp16_run:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                amp.master_params(optimizer), grad_clip_thresh)
                        else:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                model.parameters(), grad_clip_thresh)
                        if type(grad_norm) == torch.Tensor:
                            grad_norm = grad_norm.item()
                        is_overflow = math.isinf(grad_norm) or math.isnan(grad_norm)
                    else: is_overflow = False; grad_norm=0.00001
                    
                    optimizer.step()
                    if not is_overflow and rank == 0:
                        # get current Loss Scale of first optimizer
                        loss_scale = amp._amp_state.loss_scalers[0]._loss_scale if fp16_run else 32768
                        
                        if with_tensorboard:
                            if (iteration % 100000 == 0):
                                # plot distribution of parameters
                                for tag, value in model.named_parameters():
                                    tag = tag.replace('.', '/')
                                    logger.add_histogram(tag, value.data.cpu().numpy(), iteration)
                            logger.add_scalar('training_loss', reduced_loss, iteration)
                            logger.add_scalar('training_loss_samples', reduced_loss, iteration*batch_size)
                            if (iteration % 20 == 0):
                                logger.add_scalar('learning.rate', learning_rate, iteration)
                            if (iteration % 10 == 0):
                                logger.add_scalar('duration', ((time.time() - start_time_dekaiter)/10), iteration)
                        
                        average_loss = rolling_sum.process(reduced_loss)
                        if (iteration % 10 == 0):
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective)  {:.2f}s/iter {:.4f}s/item".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), (time.time() - start_time_dekaiter)/10, ((time.time() - start_time_dekaiter)/10)/(batch_size*num_gpus)))
                            start_time_dekaiter = time.time()
                        else:
                            tqdm.write("{} {}:  {:.3f} {:.3f}  {:.3f} {:08.3F} {:.8f}LR ({:.8f} Effective) {}LS".format(time.strftime("%H:%M:%S"), iteration, reduced_loss, average_loss, best_MSE, round(grad_norm,3), learning_rate, min((grad_clip_thresh/grad_norm)*learning_rate,learning_rate), loss_scale))
                        start_time_iter = time.time()
                    
                    if rank == 0 and (len(rolling_sum.values) > moving_average-2):
                        if (average_loss+best_model_margin) < best_model_loss:
                            checkpoint_path = os.path.join(output_directory, "best_model")
                            try:
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            except KeyboardInterrupt: # Avoid corrupting the model.
                                save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                            checkpoint_path)
                            text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                            text_file.write(str(average_loss)+"\n"+str(iteration))
                            text_file.close()
                            best_model_loss = average_loss #Only save the model if X better than the current loss.
                    if rank == 0 and iteration > 0 and ((iteration % iters_per_checkpoint == 0) or (os.path.exists(save_file_check_path))):
                        checkpoint_path = f"{output_directory}/waveglow_{iteration}"
                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                        checkpoint_path)
                        if (os.path.exists(save_file_check_path)):
                            os.remove(save_file_check_path)
                    
                    if (iteration % validation_interval == 0):
                        if rank == 0:
                            MSE, MAE = validate(model, loader_STFT, STFTs, logger, iteration, data_config['validation_files'], speaker_lookup, sigma, output_directory, data_config)
                            if scheduler:
                                MSE = torch.tensor(MSE, device='cuda')
                                if num_gpus > 1:
                                    broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                                if MSE < best_MSE:
                                    checkpoint_path = os.path.join(output_directory, "best_val_model")
                                    try:
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    except KeyboardInterrupt: # Avoid corrupting the model.
                                        save_checkpoint(model, optimizer, learning_rate, iteration, amp, scheduler, speaker_lookup,
                                                    checkpoint_path)
                                    text_file = open((f"{checkpoint_path}.txt"), "w", encoding="utf-8")
                                    text_file.write(str(MSE.item())+"\n"+str(iteration))
                                    text_file.close()
                                    best_MSE = MSE.item() #Only save the model if X better than the current loss.
                        else:
                            if scheduler:
                                MSE = torch.zeros(1, device='cuda')
                                broadcast(MSE, 0)
                                scheduler.step(MSE.item())
                        learning_rate = optimizer.param_groups[0]['lr'] #check actual learning rate (because I sometimes see learning_rate variable go out-of-sync with real LR)
                    iteration += 1
            training = False # exit the While loop
        
        except LossExplosion as ex: # print Exception and continue from checkpoint. (turns out it takes < 4 seconds to restart like this, f*****g awesome)
            print(ex) # print Loss
            checkpoint_path = os.path.join(output_directory, "best_model")
            assert os.path.exists(checkpoint_path), "best_val_model must exist for automatic restarts"
            
            # clearing VRAM for load checkpoint
            audio = mel = speaker_ids = loss = None
            torch.cuda.empty_cache()
            
            model.eval()
            model, optimizer, iteration, scheduler = load_checkpoint(checkpoint_path, model, optimizer, scheduler, fp16_run)
            learning_rate = optimizer.param_groups[0]['lr']
            epoch_offset = max(0, int(iteration / len(train_loader)))
            model.train()
            iteration += 1
            pass # and continue training.
コード例 #5
0
def main(pargs):
    """Run a bounded warmup-plus-profiling training loop.

    Initialises the distributed communicator, builds the DeepLabv3+ network,
    the loss, the optimizer (wrapped by apex AMP when ``have_apex`` is set),
    an optional LR scheduler and the training data loader, then performs
    ``pargs.num_warmup_steps + pargs.num_profile_steps`` optimisation steps.
    The forward, backward and optimizer phases of every step run inside
    ``Profile`` context managers.

    Args:
        pargs: Argument object (presumably the parsed command line). The
            attributes read directly here are ``wireup_method``,
            ``data_dir_prefix``, ``output_dir``, ``channels``,
            ``loss_weight_pow``, ``optimizer``, ``start_lr``, ``adam_eps``,
            ``weight_decay``, ``amp_opt_level``, ``lr_schedule``,
            ``local_batch_size``, ``max_inter_threads``,
            ``num_warmup_steps`` and ``num_profile_steps``. The object is
            also handed to ``Profile``.

    Raises:
        NotImplementedError: If ``pargs.optimizer`` is not "Adam" or "AdamW",
            or is "LAMB" while apex is unavailable.
        RuntimeError: If the training loader yields no batches while steps
            are still outstanding (the loop could otherwise never end).
    """

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    #set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    if comm_rank == 0:
        # exist_ok avoids the check-then-create race of isdir() + makedirs()
        os.makedirs(output_dir, exist_ok=True)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer (every branch either assigns or raises)
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #select scheduler
    if pargs.lr_schedule:
        scheduler = ph.get_lr_schedule(pargs.start_lr,
                                       pargs.lr_schedule,
                                       optimizer,
                                       last_step=0)

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min(pargs.max_inter_threads, pargs.local_batch_size),
        drop_last=True)

    printr(
        '{:14.4f} REPORT: starting warmup'.format(
            dt.datetime.now().timestamp()), 0)

    # total number of optimisation steps to run before stopping
    total_steps = pargs.num_warmup_steps + pargs.num_profile_steps
    step = 0
    # initial LR: taken from the scheduler when one is configured
    if pargs.lr_schedule:
        current_lr = scheduler.get_last_lr()[0]
    else:
        current_lr = pargs.start_lr
    net.train()
    while True:

        # becomes True as soon as the loader hands out a batch in this pass
        saw_batch = False

        #for inputs_raw, labels, source in train_loader:
        for inputs, label, _filename in train_loader:
            saw_batch = True

            # Print status
            if step == pargs.num_warmup_steps:
                printr(
                    '{:14.4f} REPORT: starting profiling'.format(
                        dt.datetime.now().timestamp()), 0)

            # Forward pass
            with Profile(pargs, "Forward", step):

                #send data to device
                inputs = inputs.to(device)
                label = label.to(device)

                # Compute output
                outputs = net.forward(inputs)

                # Compute loss
                loss = criterion(outputs,
                                 label,
                                 weight=class_weights,
                                 fpw_1=fpw_1,
                                 fpw_2=fpw_2)

            # allreduce for loss
            # detach() shares storage with `loss` and dist.reduce writes its
            # result in place, so reduce a copy: the tensor that backward()
            # starts from must not be overwritten on the destination rank.
            loss_avg = loss.detach().clone()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

            # Compute score
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions,
                                      label,
                                      device_id=device,
                                      num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            with Profile(pargs, "Backward", step):

                # reset grads
                optimizer.zero_grad()

                # compute grads
                if have_apex:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            # weight update
            with Profile(pargs, "Optimizer", step):
                # update weights
                optimizer.step()

            # advance the scheduler (record the LR used for this step first)
            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #step counter
            step += 1

            #are we done?
            if step >= total_steps:
                break

        #need to check here too: the break above only leaves the for loop
        if step >= total_steps:
            break

        # An empty loader can never advance `step`; fail loudly instead of
        # spinning in this while loop forever.
        if not saw_batch:
            raise RuntimeError(
                "train_loader produced no batches; cannot reach the "
                "requested number of warmup/profiling steps")

    printr(
        '{:14.4f} REPORT: finishing profiling'.format(
            dt.datetime.now().timestamp()), 0)
コード例 #6
0
def main(pargs):

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    #set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    #visualize?
    visualize = (pargs.training_visualization_frequency >
                 0) or (pargs.validation_visualization_frequency > 0)

    #set up directories
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        # get wandb api token
        with open(os.path.join(pargs.wandb_certdir, ".wandbirc")) as f:
            token = f.readlines()[0].replace("\n", "").split()
            wblogin = token[0]
            wbtoken = token[1]
        # log in: that call can be blocking, it should be quick
        sp.call(["wandb", "login", wbtoken])

        #init db and get config
        resume_flag = pargs.run_tag if pargs.resume_logging else False
        wandb.init(entity=wblogin,
                   project='deepcam',
                   name=pargs.run_tag,
                   id=pargs.run_tag,
                   resume=resume_flag)
        config = wandb.config

        #set general parameters
        config.root_dir = root_dir
        config.output_dir = pargs.output_dir
        config.max_epochs = pargs.max_epochs
        config.local_batch_size = pargs.local_batch_size
        config.num_workers = comm_size
        config.channels = pargs.channels
        config.optimizer = pargs.optimizer
        config.start_lr = pargs.start_lr
        config.adam_eps = pargs.adam_eps
        config.weight_decay = pargs.weight_decay
        config.model_prefix = pargs.model_prefix
        config.amp_opt_level = pargs.amp_opt_level
        config.loss_weight_pow = pargs.loss_weight_pow
        config.lr_warmup_steps = pargs.lr_warmup_steps
        config.lr_warmup_factor = pargs.lr_warmup_factor

        # lr schedule if applicable
        if pargs.lr_schedule:
            for key in pargs.lr_schedule:
                config.update({"lr_schedule_" + key: pargs.lr_schedule[key]},
                              allow_val_change=True)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #if (comm_rank == 0) and (pargs.checkpoint):
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr,
                                             pargs.lr_schedule,
                                             optimizer,
                                             last_step=start_step)

        if pargs.lr_warmup_steps > 0:
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=pargs.lr_warmup_factor,
                total_epoch=pargs.lr_warmup_steps,
                after_scheduler=scheduler_after)
        else:
            scheduler = scheduler_after

    #broadcast model and optimizer state
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)

    ##broadcast model and optimizer state
    #hvd.broadcast_parameters(net.state_dict(), root_rank = 0)
    #hvd.broadcast_optimizer_state(optimizer, root_rank = 0)

    #unpack the bcasted tensor
    start_step = steptens.cpu().numpy()[0]
    start_epoch = steptens.cpu().numpy()[1]

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(
                                        root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    shuffle=(pargs.max_validation_steps
                                             is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    validation_loader = DataLoader(
        validation_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    #for visualization
    if visualize:
        viz = vizc.CamVisualizer()

    # Train network
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        wandb.watch(net)

    printr(
        '{:14.4f} REPORT: starting training'.format(
            dt.datetime.now().timestamp()), 0)
    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr(
    )[0]
    net.train()
    while True:

        printr(
            '{:14.4f} REPORT: starting epoch {}'.format(
                dt.datetime.now().timestamp(), epoch), 0)

        #for inputs_raw, labels, source in train_loader:
        for inputs, label, filename in train_loader:

            #send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss and average across nodes
            loss = criterion(outputs,
                             label,
                             weight=class_weights,
                             fpw_1=fpw_1,
                             fpw_2=fpw_2)

            # allreduce for loss
            loss_avg = loss.detach()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

            # Compute score
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions,
                                      label,
                                      device_id=device,
                                      num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            #step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #print some metrics
            printr(
                '{:14.4f} REPORT training: step {} loss {} iou {} LR {}'.
                format(dt.datetime.now().timestamp(), step,
                       loss_avg.item() / float(comm_size),
                       iou_avg.item() / float(comm_size), current_lr), 0)

            #visualize if requested
            if (step % pargs.training_visualization_frequency
                    == 0) and (comm_rank == 0):
                #extract sample id and data tensors
                sample_idx = np.random.randint(low=0, high=label.shape[0])
                plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
                plot_prediction = predictions.detach()[sample_idx,
                                                       ...].cpu().numpy()
                plot_label = label.detach()[sample_idx, ...].cpu().numpy()

                #create filenames
                outputfile = os.path.basename(filename[sample_idx]).replace(
                    "data-", "training-").replace(".h5", ".png")
                outputfile = os.path.join(plot_dir, outputfile)

                #plot
                viz.plot(filename[sample_idx], outputfile, plot_input,
                         plot_prediction, plot_label)

                #log if requested
                if pargs.logging_frequency > 0:
                    img = Image.open(outputfile)
                    wandb.log(
                        {
                            "Training Examples": [
                                wandb.Image(
                                    img, caption="Prediction vs. Ground Truth")
                            ]
                        },
                        step=step)

            #log if requested
            if (pargs.logging_frequency > 0) and (
                    step % pargs.logging_frequency == 0) and (comm_rank == 0):
                wandb.log(
                    {"Training Loss": loss_avg.item() / float(comm_size)},
                    step=step)
                wandb.log({"Training IoU": iou_avg.item() / float(comm_size)},
                          step=step)
                wandb.log({"Current Learning Rate": current_lr}, step=step)

            # validation step if desired
            if (step % pargs.validation_frequency == 0):

                #eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # visualize at most one sample per validation pass
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute the loss and accumulate it locally; the
                        # cross-rank reduction happens after the loop
                        loss_val = criterion(outputs_val,
                                             label_val,
                                             weight=class_weights)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val,
                                                      label_val,
                                                      device_id=device,
                                                      num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize
                        if (step_val % pargs.validation_visualization_frequency
                                == 0) and (not visualized) and (comm_rank
                                                                == 0):
                            #extract sample id and data tensors
                            sample_idx = np.random.randint(
                                low=0, high=label_val.shape[0])
                            plot_input = inputs_val.detach()[
                                sample_idx, 0, ...].cpu().numpy()
                            plot_prediction = predictions_val.detach()[
                                sample_idx, ...].cpu().numpy()
                            plot_label = label_val.detach()[sample_idx,
                                                            ...].cpu().numpy()

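                            # create filenames
                            # NOTE(review): this indexes `filename` (the training
                            # batch) rather than `filename_val` from this loop --
                            # looks unintended; confirm and switch if so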
                            outputfile = os.path.basename(
                                filename[sample_idx]).replace(
                                    "data-",
                                    "validation-").replace(".h5", ".png")
                            outputfile = os.path.join(plot_dir, outputfile)

                            #plot
                            viz.plot(filename[sample_idx], outputfile,
                                     plot_input, plot_prediction, plot_label)
                            visualized = True

                            #log if requested
                            if pargs.logging_frequency > 0:
                                img = Image.open(outputfile)
                                wandb.log(
                                    {
                                        "Validation Examples": [
                                            wandb.Image(
                                                img,
                                                caption=
                                                "Prediction vs. Ground Truth")
                                        ]
                                    },
                                    step=step)

                        #increase eval step counter
                        step_val += 1

                        if (pargs.max_validation_steps is not None
                            ) and step_val > pargs.max_validation_steps:
                            break

                # sum the per-rank counts, losses and IoUs onto rank 0, then
                # average over the accumulated validation step count
                dist.reduce(count_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(iou_sum_val, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                printr(
                    '{:14.4f} REPORT validation: step {} loss {} iou {}'.
                    format(dt.datetime.now().timestamp(), step, loss_avg_val,
                           iou_avg_val), 0)

                # log in wandb
                if (pargs.logging_frequency > 0) and (comm_rank == 0):
                    wandb.log({"Validation Loss": loss_avg_val}, step=step)
                    wandb.log({"Validation IoU": iou_avg_val}, step=step)

                # set to train
                net.train()

            #save model if desired
            if (step % pargs.save_frequency == 0) and (comm_rank == 0):
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                if have_apex:
                    checkpoint['amp'] = amp.state_dict()
                torch.save(
                    checkpoint,
                    os.path.join(
                        output_dir,
                        pargs.model_prefix + "_step_" + str(step) + ".cpt"))

        #do some after-epoch prep, just for the books
        epoch += 1
        if comm_rank == 0:

            # Save the model
            checkpoint = {
                'step': step,
                'epoch': epoch,
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if have_apex:
                checkpoint['amp'] = amp.state_dict()
            torch.save(
                checkpoint,
                os.path.join(
                    output_dir,
                    pargs.model_prefix + "_epoch_" + str(epoch) + ".cpt"))

        #are we done?
        if epoch >= pargs.max_epochs:
            break

    printr(
        '{:14.4f} REPORT: finishing training'.format(
            dt.datetime.now().timestamp()), 0)