Exemple #1
0
def get_lr_scheduler(optimizer, num_iterations_per_epoch, config):
    """Assemble the learning-rate schedule: warm-up, main phase, cool-down.

    The durations stored in ``config`` are multiplied by
    ``num_iterations_per_epoch`` to obtain iteration counts.

    Args:
        optimizer: Optimizer whose ``"lr"`` parameter is scheduled.
        num_iterations_per_epoch: Number of iterations in one epoch.
        config: Mapping providing ``lr_max_value``, ``warmup_duration``,
            ``num_epochs`` and ``cooldown_duration``.

    Returns:
        Whatever ``create_lr_scheduler_with_warmup`` returns for the
        concatenated main/cool-down schedule (history saving enabled).
    """
    peak_lr = config['lr_max_value']
    warmup_duration = config['warmup_duration'] * num_iterations_per_epoch
    total_iterations = config['num_epochs'] * num_iterations_per_epoch
    cooldown_iterations = (config['cooldown_duration'] *
                           num_iterations_per_epoch)

    # Iterations left for the main phase once warm-up and cool-down are
    # accounted for.
    main_iterations = total_iterations - warmup_duration - cooldown_iterations

    # Each cycle_size is twice the phase length; presumably only the
    # start -> end half of the linear cycle is meant to be played.
    main_phase = LinearCyclicalScheduler(optimizer,
                                         "lr",
                                         start_value=peak_lr,
                                         end_value=peak_lr * 0.4,
                                         cycle_size=main_iterations * 2)
    cooldown_phase = LinearCyclicalScheduler(
        optimizer,
        "lr",
        start_value=peak_lr * 0.2,
        end_value=peak_lr * 0.01,
        cycle_size=cooldown_iterations * 2)

    # Only the first phase gets an explicit duration; the cool-down phase is
    # the last scheduler in the list.
    decay_schedule = ConcatScheduler(
        schedulers=[main_phase, cooldown_phase],
        durations=[main_iterations])

    return create_lr_scheduler_with_warmup(
        decay_schedule,
        warmup_start_value=0.0,
        warmup_end_value=peak_lr,
        warmup_duration=warmup_duration,
        save_history=True,
    )
Exemple #2
0
def get_momentum_scheduler(optimizer, num_iterations_per_epoch, config):
    """Build a three-phase momentum schedule (ramp-up, hold, ramp-down).

    The durations stored in ``config`` are multiplied by
    ``num_iterations_per_epoch`` to obtain iteration counts.

    Args:
        optimizer: Optimizer whose ``"momentum"`` parameter is scheduled.
        num_iterations_per_epoch: Number of iterations in one epoch.
        config: Mapping providing ``warmup_duration``, ``num_epochs`` and
            ``cooldown_duration``.

    Returns:
        A ``ConcatScheduler`` chaining the three phases.
    """
    warmup_iterations = config['warmup_duration'] * num_iterations_per_epoch
    total_iterations = config['num_epochs'] * num_iterations_per_epoch
    cooldown_iterations = (config['cooldown_duration'] *
                           num_iterations_per_epoch)
    hold_iterations = total_iterations - warmup_iterations - cooldown_iterations

    def _linear(start, end, phase_iterations):
        # cycle_size is twice the phase length, matching how each phase is
        # concatenated below.
        return LinearCyclicalScheduler(optimizer,
                                       "momentum",
                                       start_value=start,
                                       end_value=end,
                                       cycle_size=phase_iterations * 2)

    ramp_up = _linear(0.0, 0.9, warmup_iterations)
    hold = _linear(0.9, 0.9, hold_iterations)
    ramp_down = _linear(0.9, 0.5, cooldown_iterations)

    # The final phase has no explicit duration: it is the last scheduler.
    return ConcatScheduler(
        schedulers=[ramp_up, hold, ramp_down],
        durations=[warmup_iterations, hold_iterations])
 def _init_scheduler(self):
     """Create ``self.scheduler`` from ``self.hparams.scheduler_name``.

     ``"none"`` sets ``self.scheduler`` to ``None``.  The ``warmup_with_*``
     names build a linear warm-up lasting one epoch followed by cosine
     annealing; the ``one_cycle_*`` names build the cosine annealing alone.
     Any other name leaves ``self.scheduler`` untouched.
     """
     name = self.hparams.scheduler_name
     if name == "none":
         self.scheduler = None
         return

     # name -> (warm-up start factor or None when there is no warm-up,
     #          factor applied to lr for the final cosine value)
     presets = {
         "warmup_with_cosine": (0.01, 0.001),
         "warmup_with_cosine_100": (0.01, 0.01),
         "warmup_with_cosine_10": (0.1, 0.1),
         "one_cycle_cosine_10": (None, 0.1),
         "one_cycle_cosine_100": (None, 0.01),
     }
     if name not in presets:
         # Unknown names are ignored, exactly like falling off an if/elif chain.
         return
     warmup_factor, final_factor = presets[name]

     from ignite.contrib.handlers import LinearCyclicalScheduler, CosineAnnealingScheduler, ConcatScheduler

     base_lr = self.hparams.lr
     run_params = self.hparams.run_params
     # A falsy configured epoch length falls back to the loader length.
     steps_per_epoch = run_params["epoch_length"] or len(self.train_loader)
     total_epochs = run_params["max_epochs"]

     warmup = None
     if warmup_factor is not None:
         warmup = LinearCyclicalScheduler(self.optimizer, "lr", start_value=base_lr*warmup_factor, end_value=base_lr, cycle_size=steps_per_epoch*2)
     cosine = CosineAnnealingScheduler(self.optimizer, "lr", start_value=base_lr, end_value=base_lr*final_factor, cycle_size=total_epochs*steps_per_epoch)

     if warmup is None:
         self.scheduler = cosine
     else:
         self.scheduler = ConcatScheduler(schedulers=[warmup, cosine], durations=[steps_per_epoch, ])
Exemple #4
0
def create_lr_scheduler(opt, args, name = None, num_steps=1):
    """Create the learning-rate scheduler selected by ``name``.

    Args:
        opt: Optimizer the scheduler will drive.
        args: Namespace of hyper-parameters.  Only the attributes needed by
            the selected scheduler are read:
            ``plateau`` -> gamma, patience, patience_factor, max_patience,
            min_lr, threshold; ``warmup`` -> lr_start, lr_warmup, lr;
            ``step`` -> gamma, step_size; ``multistep`` -> gamma, milestones;
            ``exponential`` -> gamma; ``linearcycle`` -> lr_start, lr, epochs.
        name: Scheduler name.  When ``None`` it is taken from
            ``args.lr_scheduler`` (lower-cased).
        num_steps: Multiplier used to size the ``warmup`` and ``linearcycle``
            schedules.

    Returns:
        The constructed scheduler object.

    Raises:
        ValueError: If ``name`` is not one of the supported schedulers.
    """
    if name is None:
        name = args.lr_scheduler.lower()

    if name == 'plateau':
        from .lr_scheduler import ReduceLROnPlateau1
        sched = ReduceLROnPlateau1(
            opt,
            factor = args.gamma,
            patience = args.patience,
            patience_factor = args.patience_factor,
            max_patience = args.max_patience,
            min_lr = args.min_lr,
            threshold = args.threshold,
            verbose = True
        )

    elif name == 'warmup':
        from .lr_scheduler import LinearLR
        # Start every parameter group at lr_start before the schedule runs.
        for param_group in opt.param_groups:
            param_group['lr'] = args.lr_start
        n = int(num_steps * args.lr_warmup)
        sched = LinearLR(opt, args.lr, n)

    elif name == 'step':
        from torch.optim.lr_scheduler import StepLR
        sched = StepLR(opt, step_size=args.step_size, gamma=args.gamma)

    elif name == 'multistep':
        from torch.optim.lr_scheduler import MultiStepLR
        sched = MultiStepLR(opt, milestones=args.milestones, gamma=args.gamma)

    elif name == 'exponential':
        from torch.optim.lr_scheduler import ExponentialLR
        sched = ExponentialLR(opt, gamma=args.gamma)

    elif name == 'linearcycle':
        from ignite.contrib.handlers import LinearCyclicalScheduler
        n = int(num_steps * args.epochs)
        sched = LinearCyclicalScheduler(opt, 'lr', args.lr_start, args.lr, n)

    else:
        # List every supported option (including 'warmup') and echo the
        # offending value so the failure is actionable.
        raise ValueError(
            'lr_scheduler must be one of plateau, warmup, step, multistep, '
            'exponential, linearcycle (got {!r})'.format(name)
        )

    return sched
Exemple #5
0
def make_slanted_triangular_lr_scheduler(optimizer,
                                         n_events,
                                         lr_max,
                                         frac=0.1,
                                         ratio=32):
    """Build a two-phase linear learning-rate schedule (rise, then fall).

    Args:
        optimizer: Optimizer whose ``'lr'`` parameter is scheduled.
        n_events: Total number of events covered by the schedule.
        lr_max: Learning rate reached at the end of the rising phase.
        frac: Fraction of ``n_events`` given to the rising phase
            (truncated to an integer number of events).
        ratio: The lowest learning rate is ``lr_max / ratio``.

    Returns:
        A ``ConcatScheduler`` chaining the rising and falling phases.
    """
    rise_events = int(n_events * frac)
    fall_events = n_events - rise_events

    def _phase(start, end, events):
        # cycle_size is twice the phase length.
        return LinearCyclicalScheduler(optimizer,
                                       'lr',
                                       start_value=start,
                                       end_value=end,
                                       cycle_size=events * 2)

    rising = _phase(lr_max / ratio, lr_max, rise_events)
    falling = _phase(lr_max, lr_max / ratio, fall_events)

    # Only the rising phase needs an explicit duration.
    return ConcatScheduler([rising, falling], durations=[rise_events])
def run(*options, cfg=None, local_rank=0, debug=False):
    """Run training and validation of model

    Notes:
        Options can be passed in via the options argument and loaded from the cfg file
        Options from default.py will be overridden by options loaded from cfg file
        Options passed in via options argument will override option loaded from cfg file

    Args:
        *options (str,int ,optional): Options used to overide what is loaded from the
                                      config. To see what options are available consult
                                      default.py
        cfg (str, optional): Location of config file to load. Defaults to None.
        local_rank (int, optional): Rank of this process; used to select the CUDA
                                    device, to shard the datasets, and to restrict
                                    logging/checkpointing to rank 0. Defaults to 0.
        debug (bool, optional): When True, train and validate on small subsets with
                                a shortened snapshot cycle. Defaults to False.
    """
    update_config(config, options=options, config_file=cfg)

    # we will write the model under outputs / config_file_name / model_dir
    config_file_name = "default_config" if not cfg else cfg.split(
        "/")[-1].split(".")[0]

    # Start logging
    load_log_configuration(config.LOG_CONFIG)
    logger = logging.getLogger(__name__)
    logger.debug(config.WORKERS)
    silence_other_ranks = True
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1

    if distributed:
        # FOR DISTRIBUTED: Set the device according to local_rank.
        torch.cuda.set_device(local_rank)

        # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will
        # provide environment variables, and requires that you use init_method=`env://`.
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

    epochs_per_cycle = config.TRAIN.END_EPOCH // config.TRAIN.SNAPSHOTS
    torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK

    # Seed every random number generator used below.
    torch.manual_seed(config.SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.SEED)
    np.random.seed(seed=config.SEED)

    # Setup Augmentations
    basic_aug = Compose([
        Normalize(mean=(config.TRAIN.MEAN, ),
                  std=(config.TRAIN.STD, ),
                  max_pixel_value=1),
        PadIfNeeded(
            min_height=config.TRAIN.PATCH_SIZE,
            min_width=config.TRAIN.PATCH_SIZE,
            border_mode=config.OPENCV_BORDER_CONSTANT,
            always_apply=True,
            mask_value=255,
        ),
        Resize(
            config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
            config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
            always_apply=True,
        ),
        PadIfNeeded(
            min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
            min_width=config.TRAIN.AUGMENTATIONS.PAD.WIDTH,
            border_mode=config.OPENCV_BORDER_CONSTANT,
            always_apply=True,
            mask_value=255,
        ),
    ])
    if config.TRAIN.AUGMENTATION:
        train_aug = Compose([basic_aug, HorizontalFlip(p=0.5)])
        val_aug = basic_aug
    else:
        train_aug = val_aug = basic_aug

    TrainPatchLoader = get_patch_loader(config)

    train_set = TrainPatchLoader(
        config.DATASET.ROOT,
        split="train",
        is_transform=True,
        stride=config.TRAIN.STRIDE,
        patch_size=config.TRAIN.PATCH_SIZE,
        augmentations=train_aug,
    )

    val_set = TrainPatchLoader(
        config.DATASET.ROOT,
        split="val",
        is_transform=True,
        stride=config.TRAIN.STRIDE,
        patch_size=config.TRAIN.PATCH_SIZE,
        augmentations=val_aug,
    )

    logger.info(f"Validation examples {len(val_set)}")
    n_classes = train_set.n_classes

    if debug:
        # Shrink both datasets so a debug run finishes quickly.
        val_set = data.Subset(val_set,
                              range(config.VALIDATION.BATCH_SIZE_PER_GPU))
        train_set = data.Subset(train_set,
                                range(config.TRAIN.BATCH_SIZE_PER_GPU * 2))

    logger.info(f"Training examples {len(train_set)}")
    logger.info(f"Validation examples {len(val_set)}")

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, num_replicas=world_size, rank=local_rank)

    train_loader = data.DataLoader(
        train_set,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        num_workers=config.WORKERS,
        sampler=train_sampler,
    )

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_set, num_replicas=world_size, rank=local_rank)

    val_loader = data.DataLoader(
        val_set,
        batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU,
        num_workers=config.WORKERS,
        sampler=val_sampler,
    )

    model = getattr(models, config.MODEL.NAME).get_seg_model(config)

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    model = model.to(device)  # Send to GPU

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.TRAIN.MAX_LR,
        momentum=config.TRAIN.MOMENTUM,
        weight_decay=config.TRAIN.WEIGHT_DECAY,
    )

    # weights are inversely proportional to the frequency of the classes in
    # the training set
    class_weights = torch.tensor(config.DATASET.CLASS_WEIGHTS,
                                 device=device,
                                 requires_grad=False)

    criterion = torch.nn.CrossEntropyLoss(weight=class_weights,
                                          ignore_index=255,
                                          reduction="mean")

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[device], find_unused_parameters=True)

    # Number of iterations in one snapshot cycle (shortened in debug mode).
    snapshot_duration = epochs_per_cycle * len(
        train_loader) if not debug else 2 * len(train_loader)

    warmup_duration = 5 * len(train_loader)

    # The warm-up cycle is twice warmup_duration and the scheduler is only
    # used for warmup_duration iterations in the ConcatScheduler below.
    warmup_scheduler = LinearCyclicalScheduler(
        optimizer,
        "lr",
        start_value=config.TRAIN.MAX_LR,
        end_value=config.TRAIN.MAX_LR * world_size,
        cycle_size=10 * len(train_loader),
    )
    cosine_scheduler = CosineAnnealingScheduler(
        optimizer,
        "lr",
        config.TRAIN.MAX_LR * world_size,
        config.TRAIN.MIN_LR * world_size,
        cycle_size=snapshot_duration,
    )

    scheduler = ConcatScheduler(
        schedulers=[warmup_scheduler, cosine_scheduler],
        durations=[warmup_duration])

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        prepare_batch,
                                        device=device)

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    # Set to update the epoch parameter of our distributed data sampler so that we get
    # different shuffles
    trainer.add_event_handler(Events.EPOCH_STARTED,
                              update_sampler_epoch(train_loader))

    # `and` (not `&`): the bitwise operator binds tighter than `!=`, which
    # made the old test `(True & local_rank) != 0` and left even-numbered
    # non-zero ranks unsilenced.
    if silence_other_ranks and local_rank != 0:
        logging.getLogger("ignite.engine.engine.Engine").setLevel(
            logging.WARNING)

    def _select_pred_and_mask(model_out_dict):
        return (model_out_dict["y_pred"].squeeze(),
                model_out_dict["mask"].squeeze())

    evaluator = create_supervised_evaluator(
        model,
        prepare_batch,
        metrics={
            "nll":
            Loss(criterion,
                 output_transform=_select_pred_and_mask,
                 device=device),
            "pixa":
            pixelwise_accuracy(n_classes,
                               output_transform=_select_pred_and_mask,
                               device=device),
            "cacc":
            class_accuracy(n_classes,
                           output_transform=_select_pred_and_mask,
                           device=device),
            "mca":
            mean_class_accuracy(n_classes,
                                output_transform=_select_pred_and_mask,
                                device=device),
            "ciou":
            class_iou(n_classes,
                      output_transform=_select_pred_and_mask,
                      device=device),
            "mIoU":
            mean_iou(n_classes,
                     output_transform=_select_pred_and_mask,
                     device=device),
        },
        device=device,
    )

    # Set the validation run to start on the epoch completion of the training run

    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              Evaluator(evaluator, val_loader))

    if local_rank == 0:  # Run only on master process

        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            logging_handlers.log_training_output(
                log_interval=config.TRAIN.BATCH_SIZE_PER_GPU),
        )
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  logging_handlers.log_lr(optimizer))

        try:
            output_dir = generate_path(
                config.OUTPUT_DIR,
                git_branch(),
                git_hash(),
                config_file_name,
                config.TRAIN.MODEL_DIR,
                current_datetime(),
            )
        except TypeError:
            # Fall back to a path without the git branch/hash components.
            output_dir = generate_path(
                config.OUTPUT_DIR,
                config_file_name,
                config.TRAIN.MODEL_DIR,
                current_datetime(),
            )

        summary_writer = create_summary_writer(
            log_dir=path.join(output_dir, config.LOG_DIR))
        logger.info(
            f"Logging Tensorboard to {path.join(output_dir, config.LOG_DIR)}")
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            tensorboard_handlers.log_lr(summary_writer, optimizer, "epoch"),
        )
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            tensorboard_handlers.log_training_output(summary_writer),
        )
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED,
            logging_handlers.log_metrics(
                "Validation results",
                metrics_dict={
                    "nll": "Avg loss :",
                    "mIoU": " Avg IoU :",
                    "pixa": "Pixelwise Accuracy :",
                    "mca": "Mean Class Accuracy :",
                },
            ),
        )
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED,
            tensorboard_handlers.log_metrics(
                summary_writer,
                trainer,
                "epoch",
                metrics_dict={
                    "mIoU": "Validation/IoU",
                    "nll": "Validation/Loss",
                    "mca": "Validation/MCA",
                },
            ),
        )

        def _select_max(pred_tensor):
            return pred_tensor.max(1)[1]

        def _tensor_to_numpy(pred_tensor):
            return pred_tensor.squeeze().cpu().numpy()

        transform_func = compose(np_to_tb, decode_segmap(n_classes=n_classes),
                                 _tensor_to_numpy)
        transform_pred = compose(transform_func, _select_max)
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED,
            create_image_writer(summary_writer, "Validation/Image", "image"),
        )
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED,
            create_image_writer(summary_writer,
                                "Validation/Mask",
                                "mask",
                                transform_func=transform_func),
        )
        evaluator.add_event_handler(
            Events.EPOCH_COMPLETED,
            create_image_writer(
                summary_writer,
                "Validation/Pred",
                "y_pred",
                transform_func=transform_pred,
            ),
        )

        def snapshot_function():
            # True on the iterations that close a snapshot cycle.
            return (trainer.state.iteration % snapshot_duration) == 0

        checkpoint_handler = SnapshotHandler(
            output_dir,
            config.MODEL.NAME,
            extract_metric_from("mIoU"),
            snapshot_function,
        )
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                    {"model": model})
        logger.info("Starting training")

    # Training must run on every rank, not only the master: the model is
    # wrapped in DistributedDataParallel and the data is sharded per rank, so
    # all processes have to take part in each step.
    if debug:
        epoch_length = config.TRAIN.BATCH_SIZE_PER_GPU * 2
    else:
        epoch_length = len(train_loader)
    trainer.run(train_loader,
                max_epochs=config.TRAIN.END_EPOCH,
                epoch_length=epoch_length,
                seed=config.SEED)
Exemple #7
0
def main():
  """Parse the command line, then build and run the training loop.

  Sets up (optionally distributed) devices, the model, optimizer, losses,
  data loaders and learning-rate scheduler, wires the ignite trainer and
  evaluator together with checkpointing/logging handlers, optionally resumes
  from a checkpoint, and finally runs training for ``--n_epochs`` epochs.
  """
  parser = argparse.ArgumentParser()

  # Required parameters
  parser.add_argument("--model", type=str, default='ffn', help="model's name")
  parser.add_argument("--mode", type=int, choices=[0, 1, 2], default=None)
  parser.add_argument("--SNRdb", type=float, default=None)
  parser.add_argument("--pilot_version", type=int, choices=[1, 2], default=1)
  parser.add_argument("--loss_type", type=str, default="BCELoss")
  parser.add_argument("--train_batch_size", type=int, default=128)
  parser.add_argument("--valid_batch_size", type=int, default=128)
  parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
  parser.add_argument("--max_norm", type=float, default=-1)
  parser.add_argument("--lr", type=float, default=1e-3)
  parser.add_argument("--noise_lambda", type=float, default=1.0)
  parser.add_argument("--lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default="linear")
  parser.add_argument("--reset_lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default=None)
  parser.add_argument("--reset_trainer", action='store_true')
  parser.add_argument("--modify_model", action='store_true')
  parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
  parser.add_argument("--eval_iter", type=int, default=10)
  parser.add_argument("--save_iter", type=int, default=10)
  parser.add_argument("--n_epochs", type=int, default=10)
  parser.add_argument("--flush_dataset", type=int, default=0)
  parser.add_argument("--no_cache", action='store_true')
  parser.add_argument("--with_pure_y", action='store_true') 
  parser.add_argument("--with_h", action='store_true') 
  parser.add_argument("--only_l1", action='store_true', help="Only loss 1")
  parser.add_argument("--interpolation", action='store_true', help="if interpolate between pure and reconstruction.") 
  parser.add_argument("--data_dir", type=str, default="data")
  parser.add_argument("--cache_dir", type=str, default="train_cache")
  parser.add_argument("--output_path", type=str, default="runs", help="model save")
  parser.add_argument("--resume_from", type=str, default=None, help="resume training.")
  parser.add_argument("--first_cache_index", type=int, default=0)
  parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                      help="Device (cuda or cpu)")
  parser.add_argument("--local_rank", type=int, default=-1,
                      help="Local rank for distributed training (-1: not distributed)")
  parser.add_argument("--seed", type=int, default=43)
  parser.add_argument("--debug", action='store_true')
  args = parser.parse_args()

  args.output_path = os.path.join(args.output_path, f'pilot_{args.pilot_version}')
  args.cache_dir = os.path.join(args.data_dir, args.cache_dir)
  # Setup CUDA, GPU & distributed training
  args.distributed = (args.local_rank != -1)
  if not args.distributed:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method='env://')
  args.n_gpu = torch.cuda.device_count() if not args.distributed else 1
  # The device chosen above replaces whatever --device was given.
  args.device = device

  # Set seed
  set_seed(args)
  logger = setup_logger("trainer", distributed_rank=args.local_rank)

  # Model construction
  model = getattr(models, args.model)(args)
  model = model.to(device)
  optimizer = AdamW(model.parameters(), lr = args.lr, weight_decay=args.wd)

  if args.loss_type == "MSELoss":
    criterion = nn.MSELoss(reduction='sum').to(device)
  else:
    # Look the loss up in torch.nn first, then in the project's auxiliary
    # module; fail with a clear message instead of calling None.
    loss_cls = getattr(nn, args.loss_type, getattr(auxiliary, args.loss_type, None))
    if loss_cls is None:
      raise ValueError("Unknown loss_type '{}'".format(args.loss_type))
    criterion = loss_cls().to(device)
  criterion2 = nn.MSELoss(reduction='sum').to(device)

  if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
    )

  train_dataset = SIGDataset(args, data_type="train")
  valid_dataset = SIGDataset(args, data_type="valid")
  train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
  valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
  train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True, shuffle=(not args.distributed))
  valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, pin_memory=True, shuffle=False)

  def _make_lr_scheduler(kind):
    # Build the scheduler for one of the --lr_scheduler choices; returns None
    # for any other value.
    if kind == "linear":
      return PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    if kind == "cycle":
      return LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
    if kind == "cosine":
      return CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))
    return None

  lr_scheduler = _make_lr_scheduler(args.lr_scheduler)

  # Training function and trainer
  def update(engine, batch):
    model.train()
    y, x_label, y_pure, H = train_dataset.prepare_batch(batch, device=args.device)

    # Every loss term is divided by gradient_accumulation_steps so that the
    # accumulated gradient matches a single larger step.
    if args.with_pure_y and args.with_h:
      x_pred, y_pure_pred, H_pred = model(y, pure=y_pure, H=H, opp=True)
      loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
      if args.loss_type == "MSELoss":
        loss_1 = loss_1 / x_pred.size(0)
      loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
      loss_noise_h = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
      if args.only_l1:
        loss = loss_1
      else:
        loss = loss_1 + loss_noise * args.noise_lambda + loss_noise_h
      output = (loss.item(), loss_1.item(), loss_noise.item(), loss_noise_h.item())
    elif args.with_pure_y:
      x_pred, y_pure_pred = model(y, pure=y_pure if args.interpolation else None, opp=True)
      loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
      loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
      loss = loss_1 + loss_noise * args.noise_lambda
      output = (loss.item(), loss_1.item(), loss_noise.item())
    elif args.with_h:
      x_pred, H_pred = model(y, opp=True)
      loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
      loss_noise = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
      loss = loss_1 + loss_noise * args.noise_lambda
      output = (loss.item(), loss_1.item(), loss_noise.item())
    else:
      x_pred = model(y)
      loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
      loss = loss_1
      output = (loss.item(), loss_1.item(), torch.zeros_like(loss_1).item())

    loss.backward()
    if args.max_norm > 0:
      torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
    # Step the optimizer only once every gradient_accumulation_steps iterations.
    if engine.state.iteration % args.gradient_accumulation_steps == 0:
      optimizer.step()
      optimizer.zero_grad()
    return output
  trainer = Engine(update)

  to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
  metric_names = ["loss", "l1", "ln"]
  if args.with_pure_y and args.with_h:
    metric_names.append("lnH")

  common.setup_common_training_handlers(
    trainer=trainer,
    train_sampler=train_loader.sampler,
    to_save=to_save,
    save_every_iters=len(train_loader) * args.save_iter,
    lr_scheduler=lr_scheduler,
    output_names=metric_names,
    with_pbars=False,
    clear_cuda_cache=False,
    output_path=args.output_path,
    n_saved=2,
  )

  resume_from = args.resume_from
  if resume_from is not None:
    checkpoint_fp = Path(resume_from)
    assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
    logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
    checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
    # Work on a copy: `to_save` was already handed to the checkpoint-saving
    # handlers above, so removing "trainer" from it in place would also drop
    # the trainer from the objects registered for saving.
    to_load = dict(to_save)
    if args.reset_trainer:
      to_load.pop("trainer")
    checkpoint_to_load = to_load if 'validation' not in resume_from else {"model": model}
    Checkpoint.load_objects(to_load=checkpoint_to_load, checkpoint=checkpoint)
    if args.reset_lr_scheduler is not None:
      # NOTE(review): the fresh scheduler is only bound to this local name; it
      # is neither attached to `trainer` nor placed in `to_save` anywhere in
      # this function -- confirm whether the reset is meant to take effect.
      lr_scheduler = _make_lr_scheduler(args.reset_lr_scheduler)

  # Evaluator outputs are ((predictions...), (targets...)); each metric picks
  # the matching pair by position.
  metrics = {
    "accuracy": Accuracy(lambda output: (torch.round(output[0][0]), output[1][0])), 
    "loss_1": Loss(criterion, output_transform=lambda output: (output[0][0], output[1][0])),
    "loss_noise": Loss(criterion2, output_transform=lambda output: (output[0][1], output[1][1]))
  }
  if args.with_pure_y and args.with_h:
    metrics["loss_noise_h"] = Loss(criterion2, output_transform=lambda output: (output[0][2], output[1][2]))

  def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
    model.eval()
    with torch.no_grad():
      x, y, x_pure, H = valid_dataset.prepare_batch(batch, device=args.device, non_blocking=True)
      if args.with_pure_y and args.with_h:
        y_pred, x_pure_pred, h_pred = model(x, opp=True)
        outputs = (y_pred, x_pure_pred, h_pred), (y, x_pure, H)
      elif args.with_pure_y:
        y_pred, x_pure_pred = model(x, opp=True)
        outputs = (y_pred, x_pure_pred), (y, x_pure)
      elif args.with_h:
        y_pred, h_pred = model(x, opp=True)
        outputs = (y_pred, h_pred), (y, H)
      else:
        y_pred = model(x)
        # No reconstruction head here: reuse the target as the "prediction"
        # so the loss_noise metric still receives a pair.
        x_pure_pred = x_pure
        outputs = (y_pred, x_pure_pred), (y, x_pure)       
      return outputs
  evaluator = Engine(_inference)
  for name, metric in metrics.items():
    metric.attach(evaluator, name)

  trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.eval_iter), lambda _: evaluator.run(valid_loader))

  if args.flush_dataset > 0:
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.n_epochs//args.flush_dataset), 
                  lambda _: train_loader.dataset.reset() if args.no_cache else train_loader.dataset.reload())

  # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
  if args.local_rank in [-1, 0]:
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=metric_names, output_transform=lambda _: {"lr": f"{optimizer.param_groups[0]['lr']:.2e}"})
    evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

    tb_logger = common.setup_tb_logging(args.output_path, trainer, optimizer, evaluators={'validation': evaluator}, log_every_iters=1)

  # Store 3 best models by validation accuracy:
  common.gen_save_best_models_by_val_score(
    save_handler=DiskSaver(args.output_path, require_empty=False),
    evaluator=evaluator,
    models={"model": model},
    metric_name="accuracy",
    n_saved=3,
    trainer=trainer,
    tag="validation"
  )

  # Run the training
  trainer.run(train_loader, max_epochs=args.n_epochs)

  if args.local_rank in [-1, 0]:
    tb_logger.close()
Exemple #8
0
def train():
    """Parse the command line, fine-tune a pretrained language model and finalize the run.

    The function builds the model/tokenizer, optimizer, data loaders, ignite
    trainer and evaluator, attaches metrics, checkpointing, early stopping and
    (optionally) MLflow tracking, runs the training and finally renames the last
    checkpoint to ``WEIGHTS_NAME`` inside ``args.logdir``.

    A ``KeyboardInterrupt`` raised during setup or training is caught so that the
    finalization step still runs (exactly once).

    Raises:
        SystemExit: if ``--logdir`` points to a non-empty folder and
            ``--overwrite_output_dir`` was not passed.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3."
    )
    parser.add_argument(
        "--logdir", type=str, default=None, help="If provided, the model will be output to this folder."
    )
    parser.add_argument("--dataset_cache", type=str, default="./dataset_cache", help="Path or url of the dataset cache")
    parser.add_argument("--use_mlflow", action="store_true", help="If true we enable mlflow")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")

    parser.add_argument(
        "--tracking_uri", type=str, default="http://localhost:5000", help="url for mlflow tracking server"
    )
    parser.add_argument("--num_candidates", type=int, default=5, help="Number of candidates for training")

    parser.add_argument("--experiment", type=str, help="experiment name for mlflow")

    parser.add_argument("--task_config", type=str, help="Path to the tokenization config file")
    parser.add_argument("--special_tokens_file", type=str, default=None, help="Path to the special tokens file")
    parser.add_argument(
        "--model_checkpoint", type=str, default="distilgpt2", help="Path, url or short name of the model"
    )
    parser.add_argument("--model_type", type=str, default=None, help="gpt or gpt2")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation")
    parser.add_argument(
        "--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps"
    )

    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--adam_epsilon", type=float, default=1e-6, help="Epsilon for the Adam optimizer")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--patience", type=int, default=1, help="patience parameter for early stopping")
    parser.add_argument("--n_epochs", type=int, default=10, help="Number of training epochs")
    parser.add_argument("--max_data", type=int, default=0, help="Number of data items (0 includes everything)")
    parser.add_argument(
        "--val_max_data", type=int, default=0, help="Number of validation data items (0 includes everything)"
    )
    parser.add_argument(
        "--eval_before_start", action="store_true", help="If true start with a first evaluation before training"
    )
    parser.add_argument(
        "--overwrite_output_dir",
        action="store_true",
        help="If true, and the logdir is explictly passed, it will be overwritten",
    )
    parser.add_argument("--ul", action="store_true", help="If true use unlikelihood sampling")
    parser.add_argument("--freeze", action="store_true", help="If true freeze layers")
    parser.add_argument("--smoothing", type=float, default=0.0, help="label smoothing epsilon")
    parser.add_argument("--ignore_cache", action="store_true", help="If true ignore the dataset cache")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)"
    )
    parser.add_argument(
        "--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)"
    )
    parser.add_argument(
        "--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)"
    )
    parser.add_argument("--warmup-steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    # custom training
    parser.add_argument("--sequence-tune-rate", type=float, default=0.5)
    parser.add_argument("--sequence-ngram-n", type=int, default=4)
    parser.add_argument(
        "--multitask", action="store_true", help="If true use multitask training with multiple choice loss"
    )
    parser.add_argument(
        "--retrain_base",
        type=str,
        default=None,
        help="JSON file with training parameters or MLflow run_id from which to get the parameters for retraining",
    )
    parser.add_argument(
        "--training_args_file",
        type=str,
        default=None,
        help="File with the training arguments generated by a previous run to use as parameters",
    )
    parser.add_argument("--scheduler", type=str, default="piecewiselinear", help="scheduler choice")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="optimizer choice")
    parser.add_argument(
        "--max_block_size", type=int, default=None, help="If set, data is truncated to fit this max size"
    )

    args = parser.parse_args()

    if args.retrain_base:
        try:
            logger.info(f"reading the arguments from {args.retrain_base}")
            with open(args.retrain_base) as base_args_file:
                model_training_args = json.load(base_args_file)
        except (OSError, ValueError):
            # Not a readable JSON file: fall back to the project loader
            # (the CLI help says the value may also be an MLflow run_id).
            model_training_args = load_training_args(args.retrain_base)

        # Names of the options given explicitly on the command line, normalised to the
        # argparse attribute name: "--warmup-steps=10" -> "warmup_steps".
        passed_args = {arg[2:].split("=", 1)[0].replace("-", "_") for arg in sys.argv if arg.startswith("--")}
        # this is set by pytorch
        passed_args.update(("ignore_cache", "local_rank"))

        for key, value in model_training_args.items():
            # we only update an argument if it's not passed explicitly;
            # falsy stored values (None, 0, "", False) never override the current value
            if key not in passed_args and value:
                setattr(args, key, value)
        logger.info(vars(args))

    if args.logdir is None:
        args.logdir = Path(f"runs/{get_curr_time()}")
    else:
        args.logdir = Path(args.logdir)
        if not is_empty_or_absent_dir(args.logdir) and not args.overwrite_output_dir:
            logger.error(f"Error: {args.logdir} is not empty and you did not pass --overwrite_output_dir as True")
            # Exit with a non-zero status so callers can detect the failure
            sys.exit(1)
        if args.local_rank in [-1, 0]:
            logger.info(f"deleting the existing folder {args.logdir}")
            try:
                rmtree(args.logdir)
            except OSError:
                # Best effort: the folder may simply not exist yet
                pass

    logger.info(f"outputting model to {args.logdir}")

    def finalize():
        """Rename the last checkpoint to WEIGHTS_NAME and close (or delete) the MLflow run.

        In distributed mode the auxiliary processes wait on a barrier until the
        main process has finished its work.
        """
        if args.local_rank not in [-1, 0]:
            # Auxiliary processes wait here until the main process is done finalizing
            torch.distributed.barrier()

        if args.local_rank in [-1, 0] and args.n_epochs > 0:
            try:
                # On the main process: rename the last checkpoint
                # (for easy re-loading with from_pretrained method)
                os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(args.logdir, WEIGHTS_NAME))

                if args.use_mlflow:
                    mlflow.log_artifact(args.logdir / WEIGHTS_NAME, "training")
                    logger.info("ending mlflow run")
                    logger.info(f"run_id: {mlflow.active_run().info.run_id}")
                    mlflow.end_run()

                    rmtree(args.logdir)

            except Exception:
                # Also reached when `checkpoint_handler` was never created or never saved anything
                logger.warning("No checkpoint to finalize the model. Deleting run", exc_info=True)
                # TODO: fix issue in mlflow trying to delete the experiment multiple times
                # Only touch MLflow when tracking is enabled and a run is still active
                if args.use_mlflow and mlflow.active_run() is not None:
                    mlflow.delete_run(mlflow.active_run().info.run_id)
                rmtree(args.logdir)

        # Release the auxiliary processes; this must not depend on n_epochs, otherwise
        # they would wait forever on the barrier above when nothing was trained.
        if args.local_rank == 0:
            torch.distributed.barrier()

    try:
        args.logdir.mkdir(parents=True, exist_ok=True)
        TRAINING_ARGS_FILE = args.logdir / "model_training_args.json"
        args_dict = deepcopy(vars(args))
        args_dict["logdir"] = str(args_dict["logdir"])
        # Close the file before it is uploaded as an artifact below
        with open(TRAINING_ARGS_FILE, "w") as training_args_file:
            json.dump(args_dict, training_args_file, indent=2)

        if args.use_mlflow:
            if args.local_rank in [-1, 0]:
                assert args.tracking_uri
                assert args.experiment
                mlflow.set_tracking_uri(args.tracking_uri)
                mlflow.set_experiment(args.experiment)
                mlflow.start_run()

                # Log parameters
                mlflow.log_params(vars(args))
                # Log training arguments into a file
                mlflow.log_artifact(TRAINING_ARGS_FILE, "training")

        # The validation maximum number of items shouldn't be more than the training (used during debugging)
        if args.val_max_data == 0 and args.max_data > 0:
            args.val_max_data = args.max_data

        # Logging is set to INFO (resp. WARN) for main (resp. auxiliary)
        # process. logger.info => log main process only, logger.warning => log all processes
        logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
        # This is a logger.warning: it will be printed by all distributed processes
        logger.warning("Running process %d", args.local_rank)

        # Initialize distributed training if needed
        args.distributed = args.local_rank != -1

        if args.distributed:
            torch.cuda.set_device(args.local_rank)
            args.device = torch.device("cuda", args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")

        logger.info(f"Reading the task configuration: {args.task_config}")
        copyfile(args.task_config, args.logdir / "task_config.json")
        task_config = load_task_config(args.task_config)

        logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")

        model_directory, is_local = get_model_directory(args.model_checkpoint)

        model, tokenizer = load_pretrained(
            model_directory,
            model_type=args.model_type,
            smoothing=args.smoothing,
            multitask=args.multitask,
            special_tokens_file=args.special_tokens_file,
            task_config=task_config,
            dataset_path=args.dataset_path,
        )

        special_tokens = read_special_tokens(
            task_config=task_config,
            special_tokens_file=args.special_tokens_file,
            dataset_path=args.dataset_path
        )
        logger.info(f"adding {len(special_tokens)}")
        tokenizer.add_tokens(special_tokens)

        model.resize_token_embeddings(len(tokenizer))

        model.to(args.device)

        if args.freeze:
            # Freeze the first half of the parameters of the model's first child
            # (at least one parameter is always frozen, as before).
            transformer = list(model.children())[0]
            transformer_params = list(transformer.parameters())
            for param in transformer_params[: max(1, len(transformer_params) // 2)]:
                param.requires_grad = False

        optimizer_name = args.optimizer.lower()
        if optimizer_name == "rmsprop":
            optimizer = RMSprop(model.parameters(), lr=args.lr)
        elif optimizer_name == "adam":
            optimizer = Adam(model.parameters(), lr=args.lr)
        elif optimizer_name == "adafactor":
            optimizer = Adafactor(model.parameters(), lr=args.lr, warmup_init=False)
        elif optimizer_name == "sgd":
            optimizer = SGD(model.parameters(), lr=args.lr)
        elif optimizer_name == "novograd":
            optimizer = Novograd(model.parameters(), lr=args.lr)
        else:
            # Any other value falls back to AdamW
            optimizer = AdamW(model.parameters(), lr=args.lr)

        # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
        if args.fp16:
            from apex import amp  # Apex is only required if we use fp16 training

            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)

        if args.distributed:
            model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

        logger.info("Prepare datasets")
        train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, task_config, tokenizer)

        def named_batch(batch, with_labels=True):
            """Map the positional tensors of a batch to the model's keyword argument names.

            This makes it easier to pass parameters to the model by their name, without caring about the order.
            The components in the batch are expected to be ordered as in MODEL_INPUTS, after skipping
            the inputs that are filtered out below.
            """
            named = {}
            position = 0
            for input_name in MODEL_INPUTS:

                if not with_labels and "labels" in input_name:
                    continue

                key = input_name
                if not args.multitask:
                    if "mc_" in input_name:
                        continue
                    # the field is called `lm_labels` in the DoubleHeads and `labels` in single head model
                    if input_name == "lm_labels":
                        key = "labels"

                named[key] = batch[position]
                position += 1
            return named

        # Training function and trainer
        def update(engine, batch):
            """Run one forward/backward pass; step the optimizer every `gradient_accumulation_steps` iterations."""
            model.train()

            n_batch = named_batch(tuple(input_tensor.to(args.device) for input_tensor in batch))

            outputs = model(**n_batch)

            lm_loss = outputs[0]
            mc_loss = outputs[1] if args.multitask else 0

            loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
            if engine.state.iteration % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
            return loss.item()

        trainer = Engine(update)

        # Evaluation function and evaluator (evaluator output is the input of the metrics)
        def inference(engine, batch):
            """Return (predictions, targets) for the metrics; the model is called without the label inputs."""
            model.eval()
            with torch.no_grad():
                n_batch = named_batch(tuple(input_tensor.to(args.device) for input_tensor in batch))
                outputs = model(**{key: n_batch[key] for key in n_batch if "labels" not in key})
                lm_logits = outputs[0]
                lm_labels = n_batch["lm_labels"] if args.multitask else n_batch["labels"]

                # Shift so that the logits at position t are compared with the label at position t + 1
                lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
                lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)

                if args.multitask:
                    mc_logits = outputs[1]
                    mc_labels = n_batch["mc_labels"]

                    return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
                return lm_logits_flat_shifted, lm_labels_flat_shifted

        evaluator = Engine(inference)

        def checkpointing_score_function(engine):
            """Checkpoint score: negated validation perplexity (also logged)."""
            val_metric = engine.state.metrics["average_ppl"]
            logger.info(val_metric)
            return -val_metric

        def score_function(engine):
            """Early-stopping score: negated validation perplexity."""
            val_ppl = engine.state.metrics["average_ppl"]
            return -val_ppl

        # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
        trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
        if args.n_epochs < 1:
            trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
        if args.eval_before_start:
            trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

        # Make sure distributed data samplers split the dataset nicely between the distributed processes
        if args.distributed:
            trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
            evaluator.add_event_handler(
                Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)
            )

        # NOTE(review): if CosineAnnealingLR / get_linear_schedule_with_warmup return plain
        # torch.optim.lr_scheduler objects they are not callable ignite handlers - confirm they are wrapped.
        scheduler = None
        scheduler_name = args.scheduler.lower()
        if scheduler_name == "piecewiselinear":
            # Linearly decrease the learning rate from lr to zero
            scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
        elif scheduler_name == "linearcyclical":
            scheduler = LinearCyclicalScheduler(optimizer, "lr", args.lr / 10, args.lr, len(train_loader))
        elif scheduler_name == "cosine":
            scheduler = CosineAnnealingLR(optimizer, args.n_epochs * len(train_loader), 1e-4)
        elif args.warmup_steps > 0:
            t_total = len(train_loader) // args.gradient_accumulation_steps * args.n_epochs
            scheduler = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, t_total)

        if scheduler is None:
            # Unknown scheduler name and no warmup requested: keep the learning rate constant
            logger.warning("No scheduler matches '%s'; training with a constant learning rate", args.scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

        # Prepare metrics - note how we compute distributed metrics
        RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
        if args.multitask:
            metrics = {
                "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
                "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])),
            }
            metrics.update(
                {
                    "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args),
                }
            )
        else:
            metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1, reduction="mean"))}
            metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)})

        metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])

        for name, metric in metrics.items():
            metric.attach(evaluator, name)

        # On the main process: add progress bar, tensorboard, checkpoints and save model,
        # configuration and tokenizer before we start to train

        if args.local_rank in [-1, 0]:
            pbar = ProgressBar(persist=True)
            pbar.attach(trainer, metric_names=["loss"])
            evaluator.add_event_handler(
                Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics))
            )

            checkpoint_handler = ModelCheckpoint(
                args.logdir,
                filename_prefix="checkpoint",
                score_function=checkpointing_score_function,
                create_dir=True,
                n_saved=2,
            )

            evaluator.add_event_handler(
                Events.COMPLETED, checkpoint_handler, {"mymodel": getattr(model, "module", model)}
            )  # "getattr" takes care of distributed encapsulation

            getattr(model, "module", model).config.to_json_file(os.path.join(args.logdir, CONFIG_NAME))
            tokenizer.save_pretrained(args.logdir)

            early_handler = EarlyStopping(patience=args.patience, score_function=score_function, trainer=trainer)
            evaluator.add_event_handler(Events.COMPLETED, early_handler)

        if args.use_mlflow and args.local_rank in [-1, 0]:

            class MLflowTracker:
                """Ignite event handlers that forward metrics and run status to MLflow."""

                def __init__(self):
                    # Counts completed evaluator runs; used as the step of the eval metrics
                    self.iteration = 1

                def eval_metric_logger(self, engine):
                    mlflow.log_metric("last_epoch", self.iteration)
                    for metric in engine.state.metrics:
                        mlflow.log_metric(f"eval_{metric}", engine.state.metrics[metric], step=self.iteration)
                    self.iteration += 1

                def train_metric_logger(self, engine):
                    for metric in engine.state.metrics:
                        mlflow.log_metric(f"train_{metric}", engine.state.metrics[metric], step=engine.state.epoch)

                def finish_experiment(self, engine):
                    mlflow.log_metric("finished", True)

                def start_experiment(self, engine):
                    # log the initial artifacts in the dir
                    mlflow.log_artifacts(args.logdir, "training")
                    mlflow.log_metric("finished", False)

            mlflow_tracker = MLflowTracker()
            trainer.add_event_handler(Events.STARTED, mlflow_tracker.start_experiment)
            # Log the train and validation metrics
            trainer.add_event_handler(Events.EPOCH_COMPLETED, mlflow_tracker.train_metric_logger)
            evaluator.add_event_handler(Events.COMPLETED, mlflow_tracker.eval_metric_logger)
            # Mark the run as finished
            trainer.add_event_handler(Events.COMPLETED, mlflow_tracker.finish_experiment)

        # Run the training
        trainer.run(train_loader, max_epochs=args.n_epochs)
    except KeyboardInterrupt:
        # Do not finalize here: finalization runs exactly once below. Running it twice made the
        # second attempt fail on the already-renamed checkpoint and delete the run.
        logger.info("training interrupted by the user")

    logger.info("training about to finish")
    finalize()
    logger.info("finalized training")
    def setup_training(self, base_model, classifier, setops_model):
        """Build the trainer, its evaluators, logging, checkpointing and LR schedule.

        Args:
            base_model: model whose parameters are always optimized.
            classifier: model whose parameters are optimized only when
                ``self.train_classifier`` is set.
            setops_model: model whose parameters are always optimized.

        Returns:
            A ``(trainer, train_loader)`` tuple.
        """

        def pick(output_key, target_key):
            """Build an output transform selecting one (prediction, target) pair from an engine output."""
            return lambda o: (o["outputs"][output_key], o["targets"][target_key])

        #
        # Create the train and test dataset.
        #
        train_loader, train_subset_loader, val_loader = self.setup_datasets()

        logging.info("Setup logging and controls.")

        #
        # Setup metrics plotters.
        #
        mlflow_logger = MlflowLogger()

        #
        # Setup the optimizer.
        #
        logging.info("Setup optimizers and losses.")

        parameters = list(base_model.parameters())
        parameters += list(setops_model.parameters())
        if self.train_classifier:
            parameters += list(classifier.parameters())

        if self.optimizer_cls == "SGD":
            optimizer = torch.optim.SGD(parameters,
                                        lr=self.lr1,
                                        momentum=0.9,
                                        weight_decay=self.weight_decay)
        else:
            # Any value other than "SGD" selects Adam
            optimizer = torch.optim.Adam(parameters,
                                         lr=self.lr1,
                                         weight_decay=self.weight_decay)

        if self.focal_loss:
            attr_loss = FocalLoss().cuda()
        else:
            attr_loss = torch.nn.MultiLabelSoftMarginLoss().cuda()

        # Any value other than "mse" selects the L1 loss
        if self.recon_loss == "mse":
            recon_loss = torch.nn.MSELoss()
        else:
            recon_loss = torch.nn.L1Loss()

        #
        # Setup the trainer object and its logging.
        #
        logging.info("Setup trainer")
        trainer = create_setops_trainer(base_model,
                                        classifier,
                                        setops_model,
                                        optimizer,
                                        criterion1=attr_loss,
                                        criterion2=recon_loss.cuda(),
                                        params_object=self,
                                        device=self.device)
        ProgressBar(bar_format=None).attach(trainer)

        mlflow_logger.attach(engine=trainer,
                             prefix="Train ",
                             plot_event=Events.ITERATION_COMPLETED,
                             update_period=LOG_INTERVAL,
                             output_transform=lambda x: x)

        #
        # Define the evaluation metrics.
        #
        logging.info("Setup evaluator")
        evaluation_losses = {
            'real class loss':
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), pick("real class a", "class a")) +
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), pick("real class b", "class b")),
            'fake class loss':
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), pick("fake class a", "class a")) +
                Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), pick("fake class b", "class b")),
            '{} fake loss'.format(self.recon_loss):
                (Loss(recon_loss.cuda(), pick("fake embed a", "embed a")) +
                 Loss(recon_loss.cuda(), pick("fake embed b", "embed b"))) / 2,
        }
        labels_list = train_loader.dataset.labels_list
        # `np.bool` was removed from NumPy (1.24); the builtin `bool` is the equivalent dtype
        mask = labels_list_to_1hot(labels_list, labels_list).astype(bool)
        evaluation_accuracies = {
            'real class acc':
                (MultiLabelSoftMarginIOUaccuracy(pick("real class a", "class a")) +
                 MultiLabelSoftMarginIOUaccuracy(pick("real class b", "class b"))) / 2,
            'fake class acc':
                (MultiLabelSoftMarginIOUaccuracy(pick("fake class a", "class a")) +
                 MultiLabelSoftMarginIOUaccuracy(pick("fake class b", "class b"))) / 2,
            'S class acc':
                (MultiLabelSoftMarginIOUaccuracy(pick("a_S_b class", "a_S_b class")) +
                 MultiLabelSoftMarginIOUaccuracy(pick("b_S_a class", "b_S_a class"))) / 2,
            # Both "I" outputs are scored against the same "a_I_b class" target
            'I class acc':
                (MultiLabelSoftMarginIOUaccuracy(pick("a_I_b class", "a_I_b class")) +
                 MultiLabelSoftMarginIOUaccuracy(pick("b_I_a class", "a_I_b class"))) / 2,
            # Both "U" outputs are scored against the same "a_U_b class" target
            'U class acc':
                (MultiLabelSoftMarginIOUaccuracy(pick("a_U_b class", "a_U_b class")) +
                 MultiLabelSoftMarginIOUaccuracy(pick("b_U_a class", "a_U_b class"))) / 2,
            'MSE fake acc':
                (EWMeanSquaredError(pick("fake embed a", "embed a")) +
                 EWMeanSquaredError(pick("fake embed b", "embed b"))) / 2,
            'real mAP':
                mAP(mask=mask, output_transform=pick("real class a", "class a")),
            'fake mAP':
                mAP(mask=mask, output_transform=pick("fake class a", "class a")),
            'S mAP':
                mAP(mask=mask, output_transform=pick("a_S_b class", "a_S_b class")),
            'I mAP':
                mAP(mask=mask, output_transform=pick("a_I_b class", "a_I_b class")),
            'U mAP':
                mAP(mask=mask, output_transform=pick("a_U_b class", "a_U_b class")),
        }

        #
        # Setup the training evaluator object and its logging.
        # It only receives the accuracy metrics, not the losses.
        #
        train_evaluator = create_setops_evaluator(
            base_model,
            classifier,
            setops_model,
            metrics=evaluation_accuracies.copy(),
            device=self.device)

        mlflow_logger.attach(engine=train_evaluator,
                             prefix="Train Eval ",
                             plot_event=Events.EPOCH_COMPLETED,
                             metric_names=list(evaluation_accuracies.keys()))
        ProgressBar(bar_format=None).attach(train_evaluator)

        #
        # Setup the evaluator object and its logging.
        #
        evaluator = create_setops_evaluator(base_model,
                                            classifier,
                                            setops_model,
                                            metrics={
                                                **evaluation_losses,
                                                **evaluation_accuracies
                                            },
                                            device=self.device)

        mlflow_logger.attach(engine=evaluator,
                             prefix="Eval ",
                             plot_event=Events.EPOCH_COMPLETED,
                             metric_names=list({
                                 **evaluation_losses,
                                 **evaluation_accuracies
                             }.keys()))
        ProgressBar(bar_format=None).attach(evaluator)

        #
        # Checkpoint of the model
        #
        self.setup_checkpoint(base_model, classifier, setops_model, evaluator)

        logging.info("Setup schedulers.")

        #
        # Learning rate schedule: a linear ramp from lr1 to lr2 (the first half of a
        # cycle, i.e. `warmup_epochs` epochs), followed by reduce-on-plateau.
        #
        one_cycle_size = len(train_loader) * self.warmup_epochs * 2

        scheduler_1 = LinearCyclicalScheduler(optimizer,
                                              "lr",
                                              start_value=self.lr1,
                                              end_value=self.lr2,
                                              cycle_size=one_cycle_size)
        scheduler_2 = ReduceLROnPlateau(optimizer,
                                        factor=0.5,
                                        patience=4 * len(train_loader),
                                        cooldown=len(train_loader),
                                        output_transform=lambda x: x["main"])
        lr_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2],
                                       durations=[one_cycle_size // 2],
                                       save_history=True)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

        #
        # Evaluation
        #
        @trainer.on(Events.EPOCH_COMPLETED)
        def epoch_completed(engine):
            #
            # Re-randomize the indices of the training dataset.
            #
            train_loader.dataset.calc_indices()

            #
            # Run the evaluator on a subset of the training dataset.
            #
            logging.info("Evaluation on a subset of the training data.")
            train_evaluator.run(train_subset_loader)

            #
            # Run the evaluator on the validation set.
            #
            logging.info("Evaluation on the eval data.")
            evaluator.run(val_loader)

        return trainer, train_loader
Exemple #10
0
def run(*options,
        cfg=None,
        local_rank=0,
        debug=False,
        input=None,
        distributed=False):
    """Run training and validation of model

    Notes:
        Options can be passed in via the options argument and loaded from the cfg file
        Options from default.py will be overridden by options loaded from cfg file
        Options passed in via options argument will override option loaded from cfg file

    Args:
        *options (str,int ,optional): Options used to override what is loaded from the
                                      config. To see what options are available consult
                                      default.py
        cfg (str, optional): Location of config file to load. Defaults to None.
        local_rank (int): Rank of this process. Used to select the CUDA device, as the
            rank of the distributed samplers, and to restrict logging, checkpointing
            and tensorboard output to rank 0.
        debug (bool): Places scripts in debug/test mode and only executes a few iterations
        input (str, optional): Location of data if Azure ML run,
            for local runs input is config.DATASET.ROOT
        distributed (bool): Accepted for backward compatibility; the value passed in is
            overwritten by ``WORLD_SIZE > 1`` read from the environment.
    """

    # if AML training pipeline supplies us with input
    # NOTE(review): `data_dir` is never read afterwards and `output_dir` is
    # unconditionally reassigned by generate_path() further down, so this
    # branch currently has no effect -- confirm the intended behaviour.
    if input is not None:
        data_dir = input
        output_dir = data_dir + config.OUTPUT_DIR

    # Start logging
    load_log_configuration(config.LOG_CONFIG)
    logger = logging.getLogger(__name__)
    logger.debug(config.WORKERS)

    # Configuration:
    update_config(config, options=options, config_file=cfg)
    silence_other_ranks = True

    # Base name of the config file (no directory, no extension); used to name
    # the output directory and the debug dumps.
    config_file_name = "default_config" if not cfg else cfg.split(
        "/")[-1].split(".")[0]

    # Distributed mode is decided by the environment, not by the argument.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1

    if distributed:
        # FOR DISTRIBUTED: Set the device according to local_rank.
        torch.cuda.set_device(local_rank)

        # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will
        # provide environment variables, and requires that you use init_method=`env://`.
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        logging.info("Started train.py using distributed mode.")
    else:
        logging.info("Started train.py using local mode.")

    # Set CUDNN benchmark mode:
    torch.backends.cudnn.benchmark = config.CUDNN.BENCHMARK

    # Fix random seeds:
    torch.manual_seed(config.SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.SEED)
    np.random.seed(seed=config.SEED)

    # Augmentation:
    basic_aug = Compose([
        Normalize(mean=(config.TRAIN.MEAN, ),
                  std=(config.TRAIN.STD, ),
                  max_pixel_value=1),
        PadIfNeeded(
            min_height=config.TRAIN.PATCH_SIZE,
            min_width=config.TRAIN.PATCH_SIZE,
            border_mode=config.OPENCV_BORDER_CONSTANT,
            always_apply=True,
            mask_value=255,
            value=0,
        ),
        Resize(
            config.TRAIN.AUGMENTATIONS.RESIZE.HEIGHT,
            config.TRAIN.AUGMENTATIONS.RESIZE.WIDTH,
            always_apply=True,
        ),
        PadIfNeeded(
            min_height=config.TRAIN.AUGMENTATIONS.PAD.HEIGHT,
            min_width=config.TRAIN.AUGMENTATIONS.PAD.WIDTH,
            border_mode=config.OPENCV_BORDER_CONSTANT,
            always_apply=True,
            mask_value=255,
        ),
    ])
    # Only the training set gets the random horizontal flip; validation
    # always uses the deterministic pipeline.
    if config.TRAIN.AUGMENTATION:
        train_aug = Compose([basic_aug, HorizontalFlip(p=0.5)])
        val_aug = basic_aug
    else:
        train_aug = val_aug = basic_aug

    # Training and Validation Loaders:
    TrainPatchLoader = get_patch_loader(config)
    logging.info(f"Using {TrainPatchLoader}")

    train_set = TrainPatchLoader(
        config,
        split="train",
        is_transform=True,
        augmentations=train_aug,
        debug=debug,
    )
    logger.info(train_set)

    n_classes = train_set.n_classes
    val_set = TrainPatchLoader(
        config,
        split="val",
        is_transform=True,
        augmentations=val_aug,
        debug=debug,
    )

    logger.info(val_set)

    if debug:
        # Collect dataset/loader sizes so they can be dumped to JSON below.
        data_flow_dict = dict()

        data_flow_dict["train_patch_loader_length"] = len(train_set)
        data_flow_dict["validation_patch_loader_length"] = len(val_set)
        data_flow_dict["train_input_shape"] = train_set.seismic.shape
        data_flow_dict["train_label_shape"] = train_set.labels.shape
        data_flow_dict["n_classes"] = n_classes

        logger.info("Running in debug mode..")
        # Shrink both datasets: NUM_DEBUG_BATCHES training batches and a
        # single validation batch (capped by the dataset sizes).
        train_range = min(
            config.TRAIN.BATCH_SIZE_PER_GPU * config.NUM_DEBUG_BATCHES,
            len(train_set))
        logging.info(f"train range in debug mode {train_range}")
        train_set = data.Subset(train_set, range(train_range))
        valid_range = min(config.VALIDATION.BATCH_SIZE_PER_GPU, len(val_set))
        val_set = data.Subset(val_set, range(valid_range))

        data_flow_dict["train_length_subset"] = len(train_set)
        data_flow_dict["validation_length_subset"] = len(val_set)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set, num_replicas=world_size, rank=local_rank)
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_set, num_replicas=world_size, rank=local_rank)

    train_loader = data.DataLoader(
        train_set,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        num_workers=config.WORKERS,
        sampler=train_sampler,
    )
    val_loader = data.DataLoader(
        val_set,
        batch_size=config.VALIDATION.BATCH_SIZE_PER_GPU,
        num_workers=config.WORKERS,
        sampler=val_sampler)

    if debug:
        data_flow_dict["train_loader_length"] = len(train_loader)
        data_flow_dict["validation_loader_length"] = len(val_loader)
        fname = f"data_flow_train_{config_file_name}_{config.TRAIN.MODEL_DIR}.json"
        with open(fname, "w") as f:
            json.dump(data_flow_dict, f, indent=2)

    # Model:
    model = getattr(models, config.MODEL.NAME).get_seg_model(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Optimizer and LR Scheduler:
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.TRAIN.MAX_LR,
        momentum=config.TRAIN.MOMENTUM,
        weight_decay=config.TRAIN.WEIGHT_DECAY,
    )

    # One cosine cycle per snapshot; in debug mode the cycle is fixed at two
    # epochs' worth of iterations.
    epochs_per_cycle = config.TRAIN.END_EPOCH // config.TRAIN.SNAPSHOTS
    snapshot_duration = epochs_per_cycle * len(
        train_loader) if not debug else 2 * len(train_loader)
    # Learning-rate bounds are scaled by the number of processes.
    cosine_scheduler = CosineAnnealingScheduler(
        optimizer,
        "lr",
        config.TRAIN.MAX_LR * world_size,
        config.TRAIN.MIN_LR * world_size,
        cycle_size=snapshot_duration,
    )

    if distributed:
        # Warm up for 5 epochs from MAX_LR to MAX_LR * world_size. Only the
        # first half of the 10-epoch linear cycle is used before the cosine
        # scheduler takes over.
        warmup_duration = 5 * len(train_loader)
        warmup_scheduler = LinearCyclicalScheduler(
            optimizer,
            "lr",
            start_value=config.TRAIN.MAX_LR,
            end_value=config.TRAIN.MAX_LR * world_size,
            cycle_size=10 * len(train_loader),
        )
        scheduler = ConcatScheduler(
            schedulers=[warmup_scheduler, cosine_scheduler],
            durations=[warmup_duration])
    else:
        scheduler = cosine_scheduler

    # class weights are inversely proportional to the frequency of the classes in the training set
    class_weights = torch.tensor(config.DATASET.CLASS_WEIGHTS,
                                 device=device,
                                 requires_grad=False)

    # Loss:
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights,
                                          ignore_index=255,
                                          reduction="mean")

    # Model:
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True)
        # Silence the ignite engine logger on every non-master rank.
        # (A bitwise `&` here would bind tighter than `!=` and only silence
        # odd-numbered ranks.)
        if silence_other_ranks and local_rank != 0:
            logging.getLogger("ignite.engine.engine.Engine").setLevel(
                logging.WARNING)

    # Ignite trainer and evaluator:
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        prepare_batch,
                                        device=device)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    # Set to update the epoch parameter of our distributed data sampler so that we get
    # different shuffles
    trainer.add_event_handler(Events.EPOCH_STARTED,
                              update_sampler_epoch(train_loader))

    transform_fn = lambda output_dict: (output_dict["y_pred"].squeeze(),
                                        output_dict["mask"].squeeze())
    evaluator = create_supervised_evaluator(
        model,
        prepare_batch,
        metrics={
            "nll":
            Loss(criterion, output_transform=transform_fn, device=device),
            "pixacc":
            pixelwise_accuracy(n_classes,
                               output_transform=transform_fn,
                               device=device),
            "cacc":
            class_accuracy(n_classes,
                           output_transform=transform_fn,
                           device=device),
            "mca":
            mean_class_accuracy(n_classes,
                                output_transform=transform_fn,
                                device=device),
            "ciou":
            class_iou(n_classes, output_transform=transform_fn, device=device),
            "mIoU":
            mean_iou(n_classes, output_transform=transform_fn, device=device),
        },
        device=device,
    )

    # The model will be saved under: outputs/<config_file_name>/<model_dir>
    try:
        output_dir = generate_path(
            config.OUTPUT_DIR,
            git_branch(),
            git_hash(),
            config_file_name,
            config.TRAIN.MODEL_DIR,
            current_datetime(),
        )
    except Exception:
        # Fall back to a path without git branch/hash components
        # (presumably for runs outside a git checkout -- TODO confirm).
        # `Exception` rather than a bare except so that KeyboardInterrupt
        # and SystemExit still propagate.
        output_dir = generate_path(
            config.OUTPUT_DIR,
            config_file_name,
            config.TRAIN.MODEL_DIR,
            current_datetime(),
        )

    if local_rank == 0:  # Run only on master process
        # Logging:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            logging_handlers.log_training_output(
                log_interval=config.PRINT_FREQ),
        )
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  logging_handlers.log_lr(optimizer))

        # Checkpointing: snapshotting trained models to disk
        # The predicate is true whenever the trainer's iteration count is a
        # multiple of snapshot_duration, i.e. at the end of a cosine cycle.
        checkpoint_handler = SnapshotHandler(
            output_dir,
            config.MODEL.NAME,
            extract_metric_from("mIoU"),
            lambda: (trainer.state.iteration % snapshot_duration) == 0,
        )

        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                    {"model": model})

        # Tensorboard and Logging:
        summary_writer = create_summary_writer(
            log_dir=path.join(output_dir, "logs"))
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            tensorboard_handlers.log_lr(summary_writer, optimizer, "epoch"))
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            tensorboard_handlers.log_training_output(summary_writer))
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            tensorboard_handlers.log_validation_output(summary_writer))

    # Every rank runs the evaluator; only rank 0 (where summary_writer
    # exists) reports the results.
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        if local_rank == 0:  # Run only on master process
            tensorboard_handlers.log_results(engine,
                                             evaluator,
                                             summary_writer,
                                             n_classes,
                                             stage="Training")
            logging_handlers.log_metrics(engine, evaluator, stage="Training")
            logger.info("Logging training results..")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        if local_rank == 0:  # Run only on master process
            tensorboard_handlers.log_results(engine,
                                             evaluator,
                                             summary_writer,
                                             n_classes,
                                             stage="Validation")
            logging_handlers.log_metrics(engine, evaluator, stage="Validation")
            logger.info("Logging validation results..")
            # dump validation set metrics at the very end for debugging purposes
            if engine.state.epoch == config.TRAIN.END_EPOCH and debug:
                fname = f"metrics_{config_file_name}_{config.TRAIN.MODEL_DIR}.json"
                metrics = evaluator.state.metrics
                out_dict = {
                    x: metrics[x]
                    for x in ["nll", "pixacc", "mca", "mIoU"]
                }
                with open(fname, "w") as fid:
                    json.dump(out_dict, fid)
                log_msg = " ".join(f"{k}: {v}" for k, v in out_dict.items())
                logging.info(log_msg)

    logger.info("Starting training")
    trainer.run(train_loader,
                max_epochs=config.TRAIN.END_EPOCH,
                epoch_length=len(train_loader),
                seed=config.SEED)
    if local_rank == 0:
        summary_writer.close()