Example #1
0
def main(pargs):
    """Run a short profiled training loop (warmup steps + profiled steps).

    Initializes the distributed communicator, builds the DeepLabv3+ network,
    the optimizer and (optionally) an LR scheduler, and then iterates over
    the training set until ``pargs.num_warmup_steps +
    pargs.num_profile_steps`` optimizer steps have been executed. The
    forward, backward and optimizer phases of every step are wrapped in
    ``Profile`` contexts.

    Args:
        pargs: parsed command line arguments. Attributes read here:
            ``wireup_method``, ``data_dir_prefix``, ``output_dir``,
            ``channels``, ``loss_weight_pow``, ``optimizer``, ``start_lr``,
            ``adam_eps``, ``weight_decay``, ``amp_opt_level``,
            ``lr_schedule``, ``local_batch_size``, ``max_inter_threads``,
            ``num_warmup_steps`` and ``num_profile_steps``.

    Raises:
        NotImplementedError: if ``pargs.optimizer`` is not one of the
            supported optimizers ("Adam", "AdamW", or "LAMB" with apex).
        RuntimeError: if the training loader yields no batches (the loop
            could otherwise never reach the requested step count).
    """

    # init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # fixed seed
    seed = 333

    # device setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        # necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    # set up directories; only rank 0 creates the output directory
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    if comm_rank == 0:
        os.makedirs(output_dir, exist_ok=True)

    # define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    # select loss; the weights are hard-coded constants raised to loss_pow
    loss_pow = pargs.loss_weight_pow
    class_weights = [
        0.986267818390377**loss_pow,
        0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow,
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    # select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        # wrap model and optimizer into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    # make model distributed
    net = DDP(net)

    # select scheduler (only defined when a schedule was requested)
    if pargs.lr_schedule:
        scheduler = ph.get_lr_schedule(pargs.start_lr,
                                       pargs.lr_schedule,
                                       optimizer,
                                       last_step=0)

    # set up the training data feeder
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    printr(
        '{:14.4f} REPORT: starting warmup'.format(
            dt.datetime.now().timestamp()), 0)

    # total number of optimizer steps to run before stopping
    total_steps = pargs.num_warmup_steps + pargs.num_profile_steps
    step = 0
    done = False
    net.train()
    while not done:

        steps_before_pass = step
        for inputs, label, _ in train_loader:

            # print status once the warmup phase is over
            if step == pargs.num_warmup_steps:
                printr(
                    '{:14.4f} REPORT: starting profiling'.format(
                        dt.datetime.now().timestamp()), 0)

            # forward pass
            with Profile(pargs, "Forward", step):

                # send data to device
                inputs = inputs.to(device)
                label = label.to(device)

                # compute output; call the module (not .forward) so that
                # registered module hooks are honoured
                outputs = net(inputs)

                # compute loss
                loss = criterion(outputs,
                                 label,
                                 weight=class_weights,
                                 fpw_1=fpw_1,
                                 fpw_2=fpw_2)

            # reduce the loss onto rank 0. detach() shares storage with
            # `loss` and dist.reduce writes in place, so reduce a copy to
            # leave the tensor used for backward() untouched.
            loss_avg = loss.detach().clone()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

            # compute score and reduce it onto rank 0
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions,
                                      label,
                                      device_id=device,
                                      num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # backprop
            with Profile(pargs, "Backward", step):

                # reset grads
                optimizer.zero_grad()

                # compute grads
                if have_apex:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

            # weight update
            with Profile(pargs, "Optimizer", step):
                optimizer.step()

            # advance the scheduler
            if pargs.lr_schedule:
                scheduler.step()

            # step counter
            step += 1

            # are we done? (checked after the step, so at least one step
            # is always executed)
            if step >= total_steps:
                done = True
                break

        # an empty loader would make this loop spin forever
        if step == steps_before_pass:
            raise RuntimeError(
                "Training data loader yielded no batches; cannot reach the "
                "requested number of warmup/profile steps.")

    printr(
        '{:14.4f} REPORT: finishing profiling'.format(
            dt.datetime.now().timestamp()), 0)
def main(pargs):
    """Train the DeepLabv3+ segmentation network with distributed data parallelism.

    Sets up the distributed communicator, logging, optional WandB tracking,
    the network, loss, optimizer, AMP (if apex is available), optional
    checkpoint restart and LR schedule, and the training and validation data
    loaders. It then trains until either ``pargs.max_epochs`` epochs are
    done or the validation IoU reaches ``pargs.target_iou``. Training
    metrics are logged every ``pargs.logging_frequency`` steps, validation
    runs every ``pargs.validation_frequency`` steps and checkpoints are
    written by rank 0 every ``pargs.save_frequency`` steps.

    Args:
        pargs: parsed command line arguments. Attributes read here:
            ``wireup_method``, ``logging_frequency``, ``output_dir``,
            ``run_tag``, ``training_visualization_frequency``,
            ``validation_visualization_frequency``, ``data_dir_prefix``,
            ``enable_wandb``, ``wandb_certdir``, ``resume_logging``,
            ``max_epochs``, ``local_batch_size``, ``channels``,
            ``optimizer``, ``start_lr``, ``adam_eps``, ``weight_decay``,
            ``model_prefix``, ``amp_opt_level``, ``loss_weight_pow``,
            ``lr_warmup_steps``, ``lr_warmup_factor``, ``lr_schedule``,
            ``checkpoint``, ``max_inter_threads``, ``max_validation_steps``,
            ``validation_frequency``, ``target_iou`` and ``save_frequency``.
            ``pargs.logging_frequency`` is clamped in place to at least 1.

    Raises:
        NotImplementedError: if ``pargs.optimizer`` is not one of the
            supported optimizers ("Adam", "AdamW", or "LAMB" with apex).
        Exception: if LR warmup steps are requested but the warmup
            scheduler package is not available.
    """

    # have_wandb is a module-level flag that may be switched off below
    global have_wandb

    # init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    # set up logging; a frequency below 1 would break the modulo checks
    pargs.logging_frequency = max([pargs.logging_frequency, 1])
    log_file = os.path.normpath(
        os.path.join(pargs.output_dir, "logs", pargs.run_tag + ".log"))
    logger = mll.mlperf_logger(log_file, "deepcam", "Umbrella Corp.")
    logger.log_start(key="init_start", sync=True)
    logger.log_event(key="cache_clear")

    # set seed
    seed = 333
    logger.log_event(key="seed", value=seed)

    # device setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        # necessary for AMP to work
        torch.cuda.set_device(device)

        # TEST: allowed? Valuable?
        #torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")

    # visualize? (the plotting code itself is currently disabled, only the
    # plot directory is still created)
    visualize = (pargs.training_visualization_frequency >
                 0) or (pargs.validation_visualization_frequency > 0)

    # set up directories; only rank 0 creates them
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        if visualize:
            os.makedirs(plot_dir, exist_ok=True)

    # set up WandB (rank 0 only)
    if not pargs.enable_wandb:
        have_wandb = False
    if have_wandb and (comm_rank == 0):
        # get wandb api token: first line holds "<login> <token>"
        certfile = os.path.join(pargs.wandb_certdir, ".wandbirc")
        try:
            with open(certfile) as f:
                token = f.readlines()[0].replace("\n", "").split()
                wblogin = token[0]
                wbtoken = token[1]
        except IOError:
            print("Error, cannot open WandB certificate {}.".format(certfile))
            have_wandb = False

        if have_wandb:
            # log in: that call can be blocking, it should be quick
            sp.call(["wandb", "login", wbtoken])

            # init db and get config
            resume_flag = pargs.run_tag if pargs.resume_logging else False
            wandb.init(entity=wblogin,
                       project='deepcam',
                       name=pargs.run_tag,
                       id=pargs.run_tag,
                       resume=resume_flag)
            config = wandb.config

            # set general parameters
            config.root_dir = root_dir
            config.output_dir = pargs.output_dir
            config.max_epochs = pargs.max_epochs
            config.local_batch_size = pargs.local_batch_size
            config.num_workers = comm_size
            config.channels = pargs.channels
            config.optimizer = pargs.optimizer
            config.start_lr = pargs.start_lr
            config.adam_eps = pargs.adam_eps
            config.weight_decay = pargs.weight_decay
            config.model_prefix = pargs.model_prefix
            config.amp_opt_level = pargs.amp_opt_level
            config.loss_weight_pow = pargs.loss_weight_pow
            config.lr_warmup_steps = pargs.lr_warmup_steps
            config.lr_warmup_factor = pargs.lr_warmup_factor

            # lr schedule if applicable
            if pargs.lr_schedule:
                for key in pargs.lr_schedule:
                    config.update(
                        {"lr_schedule_" + key: pargs.lr_schedule[key]},
                        allow_val_change=True)

    # log hyperparameters
    logger.log_event(key="global_batch_size",
                     value=(pargs.local_batch_size * comm_size))
    logger.log_event(key="opt_name", value=pargs.optimizer)
    logger.log_event(key="opt_base_learning_rate",
                     value=pargs.start_lr * pargs.lr_warmup_factor)
    logger.log_event(key="opt_learning_rate_warmup_steps",
                     value=pargs.lr_warmup_steps)
    logger.log_event(key="opt_learning_rate_warmup_factor",
                     value=pargs.lr_warmup_factor)
    logger.log_event(key="opt_epsilon", value=pargs.adam_eps)

    # define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    # select loss; the weights are hard-coded constants raised to loss_pow
    loss_pow = pargs.loss_weight_pow
    class_weights = [
        0.986267818390377**loss_pow,
        0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow,
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    # select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        # wrap model and optimizer into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    # make model distributed
    net = DDP(net)

    # restart from checkpoint if desired; it is loaded on all ranks for now
    if pargs.checkpoint:
        # NOTE: torch.load unpickles the file, so only trusted checkpoints
        # should be passed here
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    # select scheduler (only defined when a schedule was requested)
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr,
                                             pargs.lr_schedule,
                                             optimizer,
                                             last_step=start_step)

        # LR warmup
        if pargs.lr_warmup_steps > 0:
            if have_warmup_scheduler:
                scheduler = GradualWarmupScheduler(
                    optimizer,
                    multiplier=pargs.lr_warmup_factor,
                    total_epoch=pargs.lr_warmup_steps,
                    after_scheduler=scheduler_after)
            # throw an error if the package is not found
            else:
                raise Exception(
                    f'Requested {pargs.lr_warmup_steps} LR warmup steps '
                    'but warmup scheduler not found. Install it from '
                    'https://github.com/ildoonet/pytorch-gradual-warmup-lr')
        else:
            scheduler = scheduler_after

    # broadcast the start step and epoch from rank 0 so all ranks agree
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)

    # unpack the broadcast tensor into plain Python numbers (rather than
    # NumPy scalars) so they are safe to use in log metadata and checkpoints
    start_step, start_epoch = steptens.cpu().tolist()

    # set up the data feeders
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               allow_uneven_distribution=False,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True,
        drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off
    # validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(
                                        root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    allow_uneven_distribution=True,
                                    shuffle=(pargs.max_validation_steps
                                             is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    # use batch size = 1 here to make sure that we do not drop a sample
    validation_loader = DataLoader(
        validation_set,
        1,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        pin_memory=True,
        drop_last=True)

    # log size of datasets
    logger.log_event(key="train_samples", value=train_set.global_size)
    if pargs.max_validation_steps is not None:
        # NOTE(review): the validation loader uses batch size 1, yet this
        # estimate multiplies by local_batch_size - confirm which is intended
        val_size = min([
            validation_set.global_size,
            pargs.max_validation_steps * pargs.local_batch_size * comm_size
        ])
    else:
        val_size = validation_set.global_size
    logger.log_event(key="eval_samples", value=val_size)

    # sanity check: a capped validation run is flagged in the log
    if pargs.max_validation_steps is not None:
        logger.log_event(key="invalid_submission")

    # train network
    if have_wandb and (comm_rank == 0):
        wandb.watch(net)

    step = start_step
    epoch = start_epoch
    if pargs.lr_schedule:
        current_lr = scheduler.get_last_lr()[0]
    else:
        current_lr = pargs.start_lr
    stop_training = False
    net.train()

    # start training
    logger.log_end(key="init_stop", sync=True)
    logger.log_start(key="run_start", sync=True)

    # training loop
    while True:

        # start epoch
        logger.log_start(key="epoch_start",
                         metadata={
                             'epoch_num': epoch + 1,
                             'step_num': step
                         },
                         sync=True)

        # epoch loop
        for inputs, label, _ in train_loader:

            # send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass; call the module (not .forward) so that
            # registered module hooks are honoured
            outputs = net(inputs)

            # compute loss
            loss = criterion(outputs,
                             label,
                             weight=class_weights,
                             fpw_1=fpw_1,
                             fpw_2=fpw_2)

            # backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            # step counter
            step += 1

            if pargs.lr_schedule:
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            # (per-step training visualization is currently disabled)

            # log if requested
            if step % pargs.logging_frequency == 0:

                # sum the loss onto rank 0 and average over ranks
                loss_avg = loss.detach()
                dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_train = loss_avg.item() / float(comm_size)

                # compute score, sum onto rank 0 and average over ranks
                predictions = torch.max(outputs, 1)[1]
                iou = utils.compute_score(predictions,
                                          label,
                                          device_id=device,
                                          num_classes=3)
                iou_avg = iou.detach()
                dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)
                iou_avg_train = iou_avg.item() / float(comm_size)

                logger.log_event(key="learning_rate",
                                 value=current_lr,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="train_accuracy",
                                 value=iou_avg_train,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="train_loss",
                                 value=loss_avg_train,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })

                if have_wandb and (comm_rank == 0):
                    # reuse the averages computed above instead of calling
                    # .item() on the tensors again
                    wandb.log({"train_loss": loss_avg_train}, step=step)
                    wandb.log({"train_accuracy": iou_avg_train}, step=step)
                    wandb.log({"learning_rate": current_lr}, step=step)
                    wandb.log({"epoch": epoch + 1}, step=step)

            # validation step if desired; a non-positive frequency disables
            # validation (and avoids a modulo-by-zero)
            if (pargs.validation_frequency > 0) and (
                    step % pargs.validation_frequency == 0):

                logger.log_start(key="eval_start",
                                 metadata={'epoch_num': epoch + 1})

                # eval
                net.eval()

                # per-rank accumulators, summed over all ranks afterwards
                count_sum_val = torch.zeros(1, device=device)
                loss_sum_val = torch.zeros(1, device=device)
                iou_sum_val = torch.zeros(1, device=device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation samples
                    step_val = 0
                    for inputs_val, label_val, _ in validation_loader:

                        # send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net(inputs_val)

                        # compute loss
                        loss_val = criterion(outputs_val,
                                             label_val,
                                             weight=class_weights,
                                             fpw_1=fpw_1,
                                             fpw_2=fpw_2)
                        loss_sum_val += loss_val

                        # increase counter
                        count_sum_val += 1.

                        # compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val,
                                                      label_val,
                                                      device_id=device,
                                                      num_classes=3)
                        iou_sum_val += iou_val

                        # (validation visualization is currently disabled)

                        # increase eval step counter
                        step_val += 1

                        # stop once exactly max_validation_steps steps ran
                        if (pargs.max_validation_steps is not None) and (
                                step_val >= pargs.max_validation_steps):
                            break

                # average the validation loss and score over all ranks
                dist.all_reduce(count_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(loss_sum_val, op=dist.ReduceOp.SUM)
                dist.all_reduce(iou_sum_val, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # log results
                logger.log_event(key="eval_accuracy",
                                 value=iou_avg_val,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })
                logger.log_event(key="eval_loss",
                                 value=loss_avg_val,
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 })

                # log in wandb
                if have_wandb and (comm_rank == 0):
                    wandb.log({"eval_loss": loss_avg_val}, step=step)
                    wandb.log({"eval_accuracy": iou_avg_val}, step=step)

                # the averages come from all_reduce, so every rank takes
                # the same stop decision
                if iou_avg_val >= pargs.target_iou:
                    logger.log_event(key="target_accuracy_reached",
                                     value=pargs.target_iou,
                                     metadata={
                                         'epoch_num': epoch + 1,
                                         'step_num': step
                                     })
                    stop_training = True

                # set back to train mode
                net.train()

                logger.log_end(key="eval_stop",
                               metadata={'epoch_num': epoch + 1})

            # save model if desired (written by rank 0 only)
            if (pargs.save_frequency > 0) and (step % pargs.save_frequency
                                               == 0):
                logger.log_start(key="save_start",
                                 metadata={
                                     'epoch_num': epoch + 1,
                                     'step_num': step
                                 },
                                 sync=True)
                if comm_rank == 0:
                    checkpoint = {
                        'step': step,
                        'epoch': epoch,
                        'model': net.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    if have_apex:
                        checkpoint['amp'] = amp.state_dict()
                    torch.save(
                        checkpoint,
                        os.path.join(
                            output_dir, pargs.model_prefix + "_step_" +
                            str(step) + ".cpt"))
                logger.log_end(key="save_stop",
                               metadata={
                                   'epoch_num': epoch + 1,
                                   'step_num': step
                               },
                               sync=True)

            # stop training?
            if stop_training:
                break

        # log the epoch
        logger.log_end(key="epoch_stop",
                       metadata={
                           'epoch_num': epoch + 1,
                           'step_num': step
                       },
                       sync=True)
        epoch += 1

        # are we done?
        if epoch >= pargs.max_epochs or stop_training:
            break

    # run done
    logger.log_end(key="run_stop", sync=True, metadata={'status': 'success'})
def main(pargs):
    """Run distributed training of the DeepLabv3+ segmentation network.

    The function wires up the communicator, seeds the RNGs, prepares the
    output directories and (on rank 0) the Weights & Biases run, builds the
    network, optimizer and optional learning-rate schedule, optionally
    restores a checkpoint, and then iterates over the training set with
    periodic visualization, logging, validation and checkpointing until
    ``pargs.max_epochs`` epochs have been completed.

    All frequency-style options (``training_visualization_frequency``,
    ``validation_visualization_frequency``, ``validation_frequency``,
    ``save_frequency``, ``logging_frequency``) are treated as "disabled"
    when they are not strictly positive.

    Args:
        pargs: Parsed command-line arguments. This function reads, among
            others, ``wireup_method``, ``data_dir_prefix``, ``output_dir``,
            ``channels``, ``optimizer``, ``start_lr``, ``adam_eps``,
            ``weight_decay``, ``amp_opt_level``, ``checkpoint``,
            ``lr_schedule``, ``lr_warmup_steps``, ``lr_warmup_factor``,
            ``local_batch_size``, ``max_inter_threads``,
            ``max_validation_steps``, ``max_epochs``, ``model_prefix``,
            ``run_tag``, ``resume_logging`` and ``wandb_certdir``.

    Raises:
        NotImplementedError: If ``pargs.optimizer`` is not one of "Adam",
            "AdamW" or (only when apex is available) "LAMB".
    """

    #init distributed training
    comm.init(pargs.wireup_method)
    comm_rank = comm.get_rank()
    comm_local_rank = comm.get_local_rank()
    comm_size = comm.get_size()

    #set seed
    seed = 333

    # Some setup
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        printr("Using GPUs", 0)
        device = torch.device("cuda", comm_local_rank)
        torch.cuda.manual_seed(seed)
        #necessary for AMP to work
        torch.cuda.set_device(device)
    else:
        printr("Using CPUs", 0)
        device = torch.device("cpu")

    #visualize?
    visualize = (pargs.training_visualization_frequency >
                 0) or (pargs.validation_visualization_frequency > 0)

    #set up directories (only rank 0 creates them)
    root_dir = os.path.join(pargs.data_dir_prefix)
    output_dir = pargs.output_dir
    plot_dir = os.path.join(output_dir, "plots")
    if comm_rank == 0:
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        if visualize and not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

    # Setup WandB
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        # get wandb api token: first line of the file is "<login> <token>"
        with open(os.path.join(pargs.wandb_certdir, ".wandbirc")) as f:
            token = f.readlines()[0].replace("\n", "").split()
            wblogin = token[0]
            wbtoken = token[1]
        # log in: that call can be blocking, it should be quick
        sp.call(["wandb", "login", wbtoken])

        #init db and get config
        resume_flag = pargs.run_tag if pargs.resume_logging else False
        wandb.init(entity=wblogin,
                   project='deepcam',
                   name=pargs.run_tag,
                   id=pargs.run_tag,
                   resume=resume_flag)
        config = wandb.config

        #set general parameters
        config.root_dir = root_dir
        config.output_dir = pargs.output_dir
        config.max_epochs = pargs.max_epochs
        config.local_batch_size = pargs.local_batch_size
        config.num_workers = comm_size
        config.channels = pargs.channels
        config.optimizer = pargs.optimizer
        config.start_lr = pargs.start_lr
        config.adam_eps = pargs.adam_eps
        config.weight_decay = pargs.weight_decay
        config.model_prefix = pargs.model_prefix
        config.amp_opt_level = pargs.amp_opt_level
        config.loss_weight_pow = pargs.loss_weight_pow
        config.lr_warmup_steps = pargs.lr_warmup_steps
        config.lr_warmup_factor = pargs.lr_warmup_factor

        # lr schedule if applicable
        if pargs.lr_schedule:
            for key in pargs.lr_schedule:
                config.update({"lr_schedule_" + key: pargs.lr_schedule[key]},
                              allow_val_change=True)

    # Define architecture
    n_input_channels = len(pargs.channels)
    n_output_channels = 3
    net = deeplab_xception.DeepLabv3_plus(n_input=n_input_channels,
                                          n_classes=n_output_channels,
                                          os=16,
                                          pretrained=False,
                                          rank=comm_rank)
    net.to(device)

    #select loss
    loss_pow = pargs.loss_weight_pow
    #some magic numbers
    # NOTE(review): presumably per-class frequencies of the 3 output classes
    # raised to loss_pow -- confirm against the dataset statistics.
    class_weights = [
        0.986267818390377**loss_pow, 0.0004578708870701058**loss_pow,
        0.01327431072255291**loss_pow
    ]
    fpw_1 = 2.61461122397522257612
    fpw_2 = 1.71641974795896018744
    criterion = losses.fp_loss

    #select optimizer
    optimizer = None
    if pargs.optimizer == "Adam":
        optimizer = optim.Adam(net.parameters(),
                               lr=pargs.start_lr,
                               eps=pargs.adam_eps,
                               weight_decay=pargs.weight_decay)
    elif pargs.optimizer == "AdamW":
        optimizer = optim.AdamW(net.parameters(),
                                lr=pargs.start_lr,
                                eps=pargs.adam_eps,
                                weight_decay=pargs.weight_decay)
    elif have_apex and (pargs.optimizer == "LAMB"):
        optimizer = aoptim.FusedLAMB(net.parameters(),
                                     lr=pargs.start_lr,
                                     eps=pargs.adam_eps,
                                     weight_decay=pargs.weight_decay)
    else:
        raise NotImplementedError("Error, optimizer {} not supported".format(
            pargs.optimizer))

    if have_apex:
        #wrap model and opt into amp
        net, optimizer = amp.initialize(net,
                                        optimizer,
                                        opt_level=pargs.amp_opt_level)

    #make model distributed
    net = DDP(net)

    #restart from checkpoint if desired
    #load it on all ranks for now
    if pargs.checkpoint:
        checkpoint = torch.load(pargs.checkpoint, map_location=device)
        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        net.load_state_dict(checkpoint['model'])
        if have_apex:
            amp.load_state_dict(checkpoint['amp'])
    else:
        start_step = 0
        start_epoch = 0

    #select scheduler (``scheduler`` is only defined when a schedule is set;
    #every later use is guarded by ``pargs.lr_schedule``)
    if pargs.lr_schedule:
        scheduler_after = ph.get_lr_schedule(pargs.start_lr,
                                             pargs.lr_schedule,
                                             optimizer,
                                             last_step=start_step)

        if pargs.lr_warmup_steps > 0:
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=pargs.lr_warmup_factor,
                total_epoch=pargs.lr_warmup_steps,
                after_scheduler=scheduler_after)
        else:
            scheduler = scheduler_after

    #broadcast the step/epoch counters from rank 0 so all ranks agree
    steptens = torch.tensor(np.array([start_step, start_epoch]),
                            requires_grad=False).to(device)
    dist.broadcast(steptens, src=0)

    #unpack the bcasted tensor; cast to plain ints so that the counters
    #(and hence the checkpoints written below) do not carry numpy scalars
    steptens_host = steptens.cpu().numpy()
    start_step = int(steptens_host[0])
    start_epoch = int(steptens_host[1])

    # Set up the data feeder
    # train
    train_dir = os.path.join(root_dir, "train")
    train_set = cam.CamDataset(train_dir,
                               statsfile=os.path.join(root_dir, 'stats.h5'),
                               channels=pargs.channels,
                               shuffle=True,
                               preprocess=True,
                               comm_size=comm_size,
                               comm_rank=comm_rank)
    train_loader = DataLoader(
        train_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    # validation: we only want to shuffle the set if we are cutting off validation after a certain number of steps
    validation_dir = os.path.join(root_dir, "validation")
    validation_set = cam.CamDataset(validation_dir,
                                    statsfile=os.path.join(
                                        root_dir, 'stats.h5'),
                                    channels=pargs.channels,
                                    shuffle=(pargs.max_validation_steps
                                             is not None),
                                    preprocess=True,
                                    comm_size=comm_size,
                                    comm_rank=comm_rank)
    validation_loader = DataLoader(
        validation_set,
        pargs.local_batch_size,
        num_workers=min([pargs.max_inter_threads, pargs.local_batch_size]),
        drop_last=True)

    #for visualization (``viz`` only exists when some visualization is on)
    if visualize:
        viz = vizc.CamVisualizer()

    # Train network
    if (pargs.logging_frequency > 0) and (comm_rank == 0):
        wandb.watch(net)

    printr(
        '{:14.4f} REPORT: starting training'.format(
            dt.datetime.now().timestamp()), 0)
    step = start_step
    epoch = start_epoch
    current_lr = pargs.start_lr if not pargs.lr_schedule else scheduler.get_last_lr(
    )[0]
    net.train()
    while True:

        printr(
            '{:14.4f} REPORT: starting epoch {}'.format(
                dt.datetime.now().timestamp(), epoch), 0)

        for inputs, label, filename in train_loader:

            #send to device
            inputs = inputs.to(device)
            label = label.to(device)

            # forward pass
            outputs = net.forward(inputs)

            # Compute loss
            loss = criterion(outputs,
                             label,
                             weight=class_weights,
                             fpw_1=fpw_1,
                             fpw_2=fpw_2)

            # sum the loss onto rank 0 (a reduce, not an allreduce: the
            # other ranks are not given the global sum); it is divided by
            # comm_size when reported
            loss_avg = loss.detach()
            dist.reduce(loss_avg, dst=0, op=dist.ReduceOp.SUM)

            # Compute score and sum it onto rank 0 as well
            predictions = torch.max(outputs, 1)[1]
            iou = utils.compute_score(predictions,
                                      label,
                                      device_id=device,
                                      num_classes=3)
            iou_avg = iou.detach()
            dist.reduce(iou_avg, dst=0, op=dist.ReduceOp.SUM)

            # Backprop
            optimizer.zero_grad()
            if have_apex:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

            #step counter
            step += 1

            if pargs.lr_schedule:
                # record the LR that was used for this step, then advance
                current_lr = scheduler.get_last_lr()[0]
                scheduler.step()

            #print some metrics
            printr(
                '{:14.4f} REPORT training: step {} loss {} iou {} LR {}'.
                format(dt.datetime.now().timestamp(), step,
                       loss_avg.item() / float(comm_size),
                       iou_avg.item() / float(comm_size), current_lr), 0)

            #visualize if requested; a non-positive frequency disables it
            #(and must not reach the modulo, which would divide by zero)
            if (pargs.training_visualization_frequency > 0) and (
                    step % pargs.training_visualization_frequency
                    == 0) and (comm_rank == 0):
                #extract sample id and data tensors
                sample_idx = np.random.randint(low=0, high=label.shape[0])
                plot_input = inputs.detach()[sample_idx, 0, ...].cpu().numpy()
                plot_prediction = predictions.detach()[sample_idx,
                                                       ...].cpu().numpy()
                plot_label = label.detach()[sample_idx, ...].cpu().numpy()

                #create filenames
                outputfile = os.path.basename(filename[sample_idx]).replace(
                    "data-", "training-").replace(".h5", ".png")
                outputfile = os.path.join(plot_dir, outputfile)

                #plot
                viz.plot(filename[sample_idx], outputfile, plot_input,
                         plot_prediction, plot_label)

                #log if requested
                if pargs.logging_frequency > 0:
                    img = Image.open(outputfile)
                    wandb.log(
                        {
                            "Training Examples": [
                                wandb.Image(
                                    img, caption="Prediction vs. Ground Truth")
                            ]
                        },
                        step=step)

            #log if requested
            if (pargs.logging_frequency > 0) and (
                    step % pargs.logging_frequency == 0) and (comm_rank == 0):
                wandb.log(
                    {"Training Loss": loss_avg.item() / float(comm_size)},
                    step=step)
                wandb.log({"Training IoU": iou_avg.item() / float(comm_size)},
                          step=step)
                wandb.log({"Current Learning Rate": current_lr}, step=step)

            # validation step if desired; the condition depends only on
            # pargs and step, so every rank enters the collective reduces
            # below together
            if (pargs.validation_frequency > 0) and (
                    step % pargs.validation_frequency == 0):

                #eval
                net.eval()

                count_sum_val = torch.Tensor([0.]).to(device)
                loss_sum_val = torch.Tensor([0.]).to(device)
                iou_sum_val = torch.Tensor([0.]).to(device)

                # disable gradients
                with torch.no_grad():

                    # iterate over validation sample
                    step_val = 0
                    # only print once per eval at most
                    visualized = False
                    for inputs_val, label_val, filename_val in validation_loader:

                        #send to device
                        inputs_val = inputs_val.to(device)
                        label_val = label_val.to(device)

                        # forward pass
                        outputs_val = net.forward(inputs_val)

                        # Compute loss
                        # NOTE(review): unlike the training call, fpw_1 /
                        # fpw_2 are not passed here, so whatever defaults
                        # losses.fp_loss defines apply -- confirm intended.
                        loss_val = criterion(outputs_val,
                                             label_val,
                                             weight=class_weights)
                        loss_sum_val += loss_val

                        #increase counter
                        count_sum_val += 1.

                        # Compute score
                        predictions_val = torch.max(outputs_val, 1)[1]
                        iou_val = utils.compute_score(predictions_val,
                                                      label_val,
                                                      device_id=device,
                                                      num_classes=3)
                        iou_sum_val += iou_val

                        # Visualize (at most once per evaluation, rank 0
                        # only, and only if the frequency is positive)
                        if (pargs.validation_visualization_frequency > 0
                            ) and (step_val %
                                   pargs.validation_visualization_frequency
                                   == 0) and (not visualized) and (comm_rank
                                                                   == 0):
                            #extract sample id and data tensors
                            sample_idx = np.random.randint(
                                low=0, high=label_val.shape[0])
                            plot_input = inputs_val.detach()[
                                sample_idx, 0, ...].cpu().numpy()
                            plot_prediction = predictions_val.detach()[
                                sample_idx, ...].cpu().numpy()
                            plot_label = label_val.detach()[sample_idx,
                                                            ...].cpu().numpy()

                            #create filenames from the *validation* batch's
                            #file list, matching the tensors plotted above
                            outputfile = os.path.basename(
                                filename_val[sample_idx]).replace(
                                    "data-",
                                    "validation-").replace(".h5", ".png")
                            outputfile = os.path.join(plot_dir, outputfile)

                            #plot
                            viz.plot(filename_val[sample_idx], outputfile,
                                     plot_input, plot_prediction, plot_label)
                            visualized = True

                            #log if requested
                            if pargs.logging_frequency > 0:
                                img = Image.open(outputfile)
                                wandb.log(
                                    {
                                        "Validation Examples": [
                                            wandb.Image(
                                                img,
                                                caption=
                                                "Prediction vs. Ground Truth")
                                        ]
                                    },
                                    step=step)

                        #increase eval step counter
                        step_val += 1

                        # NOTE(review): step_val is already incremented and
                        # compared with '>', so up to max_validation_steps + 1
                        # batches are evaluated -- confirm this is intended.
                        if (pargs.max_validation_steps is not None
                            ) and step_val > pargs.max_validation_steps:
                            break

                # sum the validation statistics onto rank 0 and average
                dist.reduce(count_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(loss_sum_val, dst=0, op=dist.ReduceOp.SUM)
                dist.reduce(iou_sum_val, dst=0, op=dist.ReduceOp.SUM)
                loss_avg_val = loss_sum_val.item() / count_sum_val.item()
                iou_avg_val = iou_sum_val.item() / count_sum_val.item()

                # print results
                printr(
                    '{:14.4f} REPORT validation: step {} loss {} iou {}'.
                    format(dt.datetime.now().timestamp(), step, loss_avg_val,
                           iou_avg_val), 0)

                # log in wandb
                if (pargs.logging_frequency > 0) and (comm_rank == 0):
                    wandb.log({"Validation Loss": loss_avg_val}, step=step)
                    wandb.log({"Validation IoU": iou_avg_val}, step=step)

                # set to train
                net.train()

            #save model if desired (rank 0 only; disabled when the
            #frequency is not positive)
            if (pargs.save_frequency > 0) and (
                    step % pargs.save_frequency == 0) and (comm_rank == 0):
                checkpoint = {
                    'step': step,
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                if have_apex:
                    checkpoint['amp'] = amp.state_dict()
                torch.save(
                    checkpoint,
                    os.path.join(
                        output_dir,
                        pargs.model_prefix + "_step_" + str(step) + ".cpt"))

        #do some after-epoch prep, just for the books
        epoch += 1
        if comm_rank == 0:

            # Save the model at the end of every epoch
            checkpoint = {
                'step': step,
                'epoch': epoch,
                'model': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if have_apex:
                checkpoint['amp'] = amp.state_dict()
            torch.save(
                checkpoint,
                os.path.join(
                    output_dir,
                    pargs.model_prefix + "_epoch_" + str(epoch) + ".cpt"))

        #are we done?
        if epoch >= pargs.max_epochs:
            break

    printr(
        '{:14.4f} REPORT: finishing training'.format(
            dt.datetime.now().timestamp()), 0)