Example #1
def train(max_iter,
          dataset,
          sampler,
          const_z,
          G,
          G_ema,
          D,
          optimizer_G,
          optimizer_D,
          r1_lambda,
          pl_lambda,
          d_k,
          g_k,
          policy,
          device,
          amp,
          save=1000):

    status = Status(max_iter)
    pl_mean = 0.
    loss = GANLoss()
    scaler = GradScaler() if amp else None
    augment = functools.partial(DiffAugment, policy=policy)

    if G_ema is not None:
        G_ema.eval()

    while status.batches_done < max_iter:
        for index, real in enumerate(dataset):
            optimizer_G.zero_grad()
            optimizer_D.zero_grad()

            real = real.to(device)
            '''discriminator'''
            z = sampler()
            with autocast(amp):
                # D(x)
                real_aug = augment(real)
                real_prob = D(real_aug)
                # D(G(z))
                fake, _ = G(z)
                fake_aug = augment(fake)
                fake_prob = D(fake_aug.detach())
                # loss
                if status.batches_done % d_k == 0 \
                    and r1_lambda > 0 \
                    and status.batches_done != 0:
                    # lazy regularization
                    r1 = r1_penalty(real, D, scaler)
                    D_loss = r1 * r1_lambda * d_k
                else:
                    # gan loss on other iter
                    D_loss = loss.d_loss(real_prob, fake_prob)

            if scaler is not None:
                scaler.scale(D_loss).backward()
                scaler.step(optimizer_D)
            else:
                D_loss.backward()
                optimizer_D.step()
            '''generator'''
            z = sampler()
            with autocast(amp):
                # D(G(z))
                fake, style = G(z)
                fake_aug = augment(fake)
                fake_prob = D(fake_aug)
                # loss
                if status.batches_done % g_k == 0 \
                    and pl_lambda > 0 \
                    and status.batches_done != 0:
                    # lazy regularization
                    pl = pl_penalty(style, fake, pl_mean, scaler)
                    G_loss = pl * pl_lambda * g_k
                    avg_pl = np.mean(pl.detach().cpu().numpy())
                    pl_mean = update_pl_mean(pl_mean, avg_pl)
                else:
                    # gan loss on other iter
                    G_loss = loss.g_loss(fake_prob)

            if scaler is not None:
                scaler.scale(G_loss).backward()
                scaler.step(optimizer_G)
            else:
                G_loss.backward()
                optimizer_G.step()

            if G_ema is not None:
                update_ema(G, G_ema)

            # save
            if status.batches_done % save == 0:
                with torch.no_grad():
                    images, _ = G_ema(const_z)
                save_image(
                    images,
                    f'implementations/StyleGAN2/result/{status.batches_done}.jpg',
                    nrow=4,
                    normalize=True,
                    value_range=(-1, 1))
                torch.save(
                    G_ema.state_dict(),
                    f'implementations/StyleGAN2/result/G_{status.batches_done}.pt'
                )
            save_image(fake,
                       'running.jpg',
                       nrow=4,
                       normalize=True,
                       value_range=(-1, 1))

            # updates
            loss_dict = dict(
                G=G_loss.item() if not torch.isnan(G_loss).any() else 0,
                D=D_loss.item() if not torch.isnan(D_loss).any() else 0)
            status.update(loss_dict)
            if scaler is not None:
                scaler.update()

            if status.batches_done == max_iter:
                break

    status.plot()
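The loop above relies on an r1_penalty helper that is not shown here. A minimal sketch of a helper compatible with that call site, assuming the usual R1 gradient penalty on real images (an illustration, not the author's implementation):

import torch

def r1_penalty(real, D, scaler=None):
    # Hypothetical helper matching the call r1_penalty(real, D, scaler) above.
    real = real.detach().requires_grad_(True)
    pred = D(real)
    # Scale the output before autograd.grad so fp16 gradients stay representable,
    # then undo the scaling on the resulting gradient.
    out = scaler.scale(pred.sum()) if scaler is not None else pred.sum()
    (grad,) = torch.autograd.grad(out, real, create_graph=True)
    if scaler is not None:
        grad = grad / scaler.get_scale()
    return grad.pow(2).reshape(grad.shape[0], -1).sum(1).mean()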
Example #2
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #3
    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        generator: Generator,
        projector: torch.nn.Module,
        batch_size: int,
        iterations: int,
        device: torch.device,
        eval_freq: int = 1000,
        eval_iters: int = 100,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        grad_clip_max_norm: Optional[float] = None,
        writer: Optional[SummaryWriter] = None,
        save_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        mixed_precision: bool = False,
        train_projector: bool = True,
        feed_layers: Optional[List[int]] = None,
    ) -> None:

        # Logging
        self.logger = logging.getLogger()
        self.writer = writer

        # Saving
        self.save_path = save_path

        # Device
        self.device = device

        # Model
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.generator = generator
        self.projector = projector
        self.train_projector = train_projector
        self.feed_layers = feed_layers

        #  Eval
        self.eval_freq = eval_freq
        self.eval_iters = eval_iters

        # Scheduler
        self.scheduler = scheduler
        self.grad_clip_max_norm = grad_clip_max_norm

        # Batch & Iteration
        self.batch_size = batch_size
        self.iterations = iterations
        self.start_iteration = 0

        # Floating-point precision
        self.mixed_precision = mixed_precision and self.device.type == "cuda"
        self.scaler = GradScaler() if self.mixed_precision else None

        if checkpoint_path:
            self._load_from_checkpoint(checkpoint_path)

        # Metrics
        self.train_acc_metric = LossMetric()
        self.train_loss_metric = LossMetric()

        self.val_acc_metric = LossMetric()
        self.val_loss_metric = LossMetric()

        # Best
        self.best_loss = -1
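The constructor above stores a GradScaler alongside grad_clip_max_norm, but the update step itself is not part of this snippet. A minimal sketch of how the two would typically interact in one update, assuming only the attributes defined above (a hypothetical method, not part of the original class):

import torch

def _train_step_sketch(self, x, y):
    # Hypothetical method; under AMP, gradients must be unscaled before norm clipping.
    self.optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=self.mixed_precision):
        loss = self.loss_fn(self.model(x), y)
    if self.scaler is not None:
        self.scaler.scale(loss).backward()
        if self.grad_clip_max_norm is not None:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_max_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        loss.backward()
        if self.grad_clip_max_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_max_norm)
        self.optimizer.step()
    return loss.detach()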
Example #4
             cnfg.scale_down_emb, cnfg.freeze_tgtemb)

if use_cuda:
    mymodel.to(cuda_device)
    lossf.to(cuda_device)

optimizer = Optimizer(mymodel.parameters(),
                      lr=init_lr,
                      betas=adam_betas_default,
                      eps=ieps_adam_default,
                      weight_decay=cnfg.weight_decay,
                      amsgrad=use_ams)
optimizer.zero_grad(set_to_none=True)

use_amp = cnfg.use_amp and use_cuda
scaler = GradScaler() if use_amp else None

if multi_gpu:
    mymodel = DataParallelMT(mymodel,
                             device_ids=cuda_devices,
                             output_device=cuda_device.index,
                             host_replicate=True,
                             gather_output=False)
    lossf = DataParallelCriterion(lossf,
                                  device_ids=cuda_devices,
                                  output_device=cuda_device.index,
                                  replicate_once=True)

fine_tune_state = cnfg.fine_tune_state
if fine_tune_state is not None:
    logger.info("Load optimizer state from: " + fine_tune_state)
Example #5
def main(args):

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # config = []

    device = torch.device('cuda')

    model = None
    if args.arch == "UNet":
        model = UNet(args).to(device)
    else:
        raise NotImplementedError("architectures other than UNet haven't been added!")


    # update_lrs = nn.Parameter(args.update_lr*torch.ones(self.update_step, len(self.net.vars)), requires_grad=True)
    model.optimizer = optim.Adam(model.parameters(), lr=args.lr, eps=1e-7, amsgrad=True, weight_decay=args.weight_decay)
    model.lr_scheduler = optim.lr_scheduler.ExponentialLR(model.optimizer, args.exp_decay)

    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    
    #for name, param in model.named_parameters():
    #    print(name, param.size())
    print('Total trainable tensors:', num, flush=True)

    SUMMARY_INTERVAL=5
    TEST_PRINT_INTERVAL=SUMMARY_INTERVAL*5
    ITER_SAVE_INTERVAL=300
    EPOCH_SAVE_INTERVAL=5

    model_path = args.model_saving_path + args.model_name + "_batch_size_" + str(args.batch_size) + "_lr_" + str(args.lr) + "_data_" + str(args.data_folder.split('/')[-1]) + "_HIDDEN_DIM_" + str(args.HIDDEN_DIM)+"_regNorm_"+str(args.reg_norm)
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    
    ds = SimulationDataset(args.data_folder, total_sample_number = args.total_sample_number)
    torch.manual_seed(42)
    train_ds, test_ds = random_split(ds, [int(0.9*len(ds)), len(ds) - int(0.9*len(ds))])

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=0)

    train_mean = 0
    test_mean = 0
    # first get the mean-absolute-field value:
    for sample_batched in train_loader:
        train_mean += torch.mean(torch.abs(sample_batched["field"]))
    for sample_batched in test_loader:
        test_mean += torch.mean(torch.abs(sample_batched["field"]))
    train_mean /= len(train_loader)
    test_mean /= len(test_loader)

    print("total training samples: %d, total test samples: %d, train_abs_mean: %f, test_abs_mean: %f" % (len(train_ds), len(test_ds), train_mean, test_mean), flush=True)
    

    # for visualizing the graph:
    #writer = SummaryWriter('runs/'+args.model_name)

    #test_input = None
    #for sample in train_loader:
    #    test_input = sample['structure']
    #    break
    #writer.add_graph(model, test_input.to(device))
    #writer.close()
    

    df = pd.DataFrame(columns=['epoch','train_loss', 'train_phys_reg', 'test_loss', 'test_phys_reg'])

    train_loss_history = []
    train_phys_reg_history = []

    test_loss_history = []
    test_phys_reg_history = []

    start_epoch=0
    if (args.continue_train):
        print("Restoring weights from ", model_path+"/last_model.pt", flush=True)
        checkpoint = torch.load(model_path+"/last_model.pt")
        start_epoch=checkpoint['epoch']
        model = checkpoint['model']
        model.lr_scheduler = checkpoint['lr_scheduler']
        model.optimizer = checkpoint['optimizer']
        df = pd.read_csv(model_path + '/'+'df.csv')
        
    scaler = GradScaler()

    best_loss = 1e4
    last_epoch_data_loss = 1.0
    last_epoch_physical_loss = 1.0
    for step in range(start_epoch, args.epoch):
        print("epoch: ", step, flush=True)
        reg_norm = regConstScheduler(step, args, last_epoch_data_loss, last_epoch_physical_loss)
        # training
        for sample_batched in train_loader:
            model.optimizer.zero_grad()
            
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            with autocast():
                logits = model(x_batch_train, bn_training=True)
                #calculate the loss using the ground truth
                loss = model.loss_fn(logits, y_batch_train)
                # print("loss: ", loss, flush=True)

                pattern = (x_batch_train * (n_Si - n_air) + n_air) ** 2  # rescale the 0/1 pattern into dielectric constant
                fields = logits  # predicted fields
                FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)

                phys_reg = model.loss_fn(FD_H, fields[:, :, 1:(Nz-1), :]) * reg_norm
                loss = loss + phys_reg

            # backward and optimizer step run outside autocast, as recommended
            scaler.scale(loss).backward()
            scaler.step(model.optimizer)
            scaler.update()
            # loss.backward()
            # model.optimizer.step()

        #Save the weights at the end of each epoch
        checkpoint = {
                        'epoch': step,
                        'model': model,
                        'optimizer': model.optimizer,
                        'lr_scheduler': model.lr_scheduler
                     }
        torch.save(checkpoint, model_path+"/last_model.pt")


        # evaluation
        train_loss = 0
        train_phys_reg = 0
        for sample_batched in train_loader:
            x_batch_train, y_batch_train = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            with torch.no_grad():
                logits = model(x_batch_train, bn_training=False)
                loss = model.loss_fn(logits, y_batch_train)
                
                # Calculate physical residue
                pattern = (x_batch_train * (n_Si - n_air) + n_air) ** 2  # rescale the 0/1 pattern into dielectric constant
                fields = logits
                FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                phys_reg = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])*reg_norm

                train_loss += loss
                train_phys_reg += phys_reg

        train_loss /= len(train_loader)*train_mean
        train_phys_reg /= len(train_loader)

        test_loss = 0
        test_phys_reg = 0
        for sample_batched in test_loader:
            x_batch_test, y_batch_test = sample_batched['structure'].to(device), sample_batched['field'].to(device)
            
            with torch.no_grad():
                logits = model(x_batch_test, bn_training=False)
                loss = model.loss_fn(logits, y_batch_test)

                # Calculate physical residue
                pattern = (x_batch_test * (n_Si - n_air) + n_air) ** 2  # rescale the 0/1 pattern into dielectric constant
                fields = logits
                FD_H = H_to_H(fields, pattern, wavelength, Nx, Nz, dx, dz)
                phys_reg = model.loss_fn(FD_H, fields[:,:,1:(Nz-1),:])

                test_loss += loss
                test_phys_reg += phys_reg
        test_loss /= len(test_loader)*test_mean
        test_phys_reg /= len(test_loader)
        last_epoch_data_loss = test_loss
        last_epoch_physical_loss = test_phys_reg.detach().clone()

        test_phys_reg *= reg_norm
        

        print('train loss: %.5f, test loss: %.5f, train phys reg: %.5f, test phys reg: %.5f, last_physical_loss: %.5f' % (train_loss, test_loss, train_phys_reg, test_phys_reg, last_epoch_physical_loss), flush=True)

            
        model.lr_scheduler.step()

        df = pd.concat([df, pd.DataFrame([{'epoch': step+1, 'lr': str(model.lr_scheduler.get_last_lr()),
                                           'train_loss': train_loss.item(),
                                           'train_phys_reg': train_phys_reg.item(),
                                           'test_loss': test_loss.item(),
                                           'test_phys_reg': test_phys_reg.item()}])],
                       ignore_index=True)

        df.to_csv(model_path + '/'+'df.csv',index=False)

        if(test_loss<best_loss):
            best_loss = test_loss
            checkpoint = {
                            'epoch': step,
                            'model': model,
                            'optimizer': model.optimizer,
                            'lr_scheduler': model.lr_scheduler
                         }
            torch.save(checkpoint, model_path+"/best_model.pt")
Example #6
def main(dataset_path, batch_size=256, max_epochs=10):
    assert torch.cuda.is_available()
    assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
    torch.backends.cudnn.benchmark = True

    device = "cuda"

    train_loader, test_loader, eval_train_loader = get_train_eval_loaders(
        dataset_path, batch_size=batch_size)

    model = wide_resnet50_2(num_classes=100).to(device)
    optimizer = SGD(model.parameters(), lr=0.01)
    criterion = CrossEntropyLoss().to(device)

    scaler = GradScaler()

    def train_step(engine, batch):
        x = convert_tensor(batch[0], device, non_blocking=True)
        y = convert_tensor(batch[1], device, non_blocking=True)

        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            y_pred = model(x)
            loss = criterion(y_pred, y)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same precision that autocast used for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        return loss.item()

    trainer = Engine(train_step)
    timer = Timer(average=True)
    timer.attach(trainer, step=Events.EPOCH_COMPLETED)
    ProgressBar(persist=True).attach(
        trainer, output_transform=lambda out: {"batch loss": out})

    metrics = {"Accuracy": Accuracy(), "Loss": Loss(criterion)}

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)

    def log_metrics(engine, title):
        for name in metrics:
            print(f"\t{title} {name}: {engine.state.metrics[name]:.2f}")

    @trainer.on(Events.COMPLETED)
    def run_validation(_):
        print(f"- Mean elapsed time for 1 epoch: {timer.value()}")
        print("- Metrics:")
        with evaluator.add_event_handler(Events.COMPLETED, log_metrics,
                                         "Train"):
            evaluator.run(eval_train_loader)

        with evaluator.add_event_handler(Events.COMPLETED, log_metrics,
                                         "Test"):
            evaluator.run(test_loader)

    trainer.run(train_loader, max_epochs=max_epochs)
Example #7
def do_train(cfg, arguments, train_data_loader, test_data_loader, model,
             criterion, optimizer, lr_scheduler, check_pointer, device):
    meters = MetricLogger()
    evaluator = train_data_loader.dataset.evaluator
    summary_writer = None
    use_tensorboard = cfg.TRAIN.USE_TENSORBOARD
    if is_master_proc() and use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        summary_writer = SummaryWriter(
            log_dir=os.path.join(cfg.OUTPUT_DIR, 'tf_logs'))

    log_step = cfg.TRAIN.LOG_STEP
    save_epoch = cfg.TRAIN.SAVE_EPOCH
    eval_epoch = cfg.TRAIN.EVAL_EPOCH
    max_epoch = cfg.TRAIN.MAX_EPOCH
    gradient_accumulate_step = cfg.TRAIN.GRADIENT_ACCUMULATE_STEP

    start_epoch = arguments['cur_epoch']
    epoch_iters = len(train_data_loader)
    max_iter = (max_epoch - start_epoch) * epoch_iters
    current_iterations = 0

    if cfg.TRAIN.HYBRID_PRECISION:
        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

    synchronize()
    model.train()
    logger.info("Start training ...")
    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch))
    start_training_time = time.time()
    end = time.time()
    for cur_epoch in range(start_epoch, max_epoch + 1):
        if cfg.DATALOADER.SHUFFLE:
            shuffle_dataset(train_data_loader, cur_epoch)
        data_loader = Prefetcher(
            train_data_loader,
            device) if cfg.DATALOADER.PREFETCHER else train_data_loader
        for iteration, (images, targets) in enumerate(data_loader):
            if not cfg.DATALOADER.PREFETCHER:
                images = images.to(device=device, non_blocking=True)
                targets = targets.to(device=device, non_blocking=True)

            if cfg.TRAIN.HYBRID_PRECISION:
                # Runs the forward pass with autocasting.
                with autocast():
                    output_dict = model(images)
                    loss_dict = criterion(output_dict, targets)
                    loss = loss_dict[KEY_LOSS] / gradient_accumulate_step

                current_iterations += 1
                if current_iterations % gradient_accumulate_step != 0:
                    if isinstance(model, DistributedDataParallel):
                        # multi-gpu distributed training
                        with model.no_sync():
                            scaler.scale(loss).backward()
                    else:
                        scaler.scale(loss).backward()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    current_iterations = 0
                    optimizer.zero_grad()
            else:
                output_dict = model(images)
                loss_dict = criterion(output_dict, targets)
                loss = loss_dict[KEY_LOSS] / gradient_accumulate_step

                current_iterations += 1
                if current_iterations % gradient_accumulate_step != 0:
                    if isinstance(model, DistributedDataParallel):
                        # multi-gpu distributed training
                        with model.no_sync():
                            loss.backward()
                    else:
                        loss.backward()
                else:
                    loss.backward()
                    optimizer.step()
                    current_iterations = 0
                    optimizer.zero_grad()

            acc_list = evaluator.evaluate_train(output_dict, targets)
            update_stats(cfg.NUM_GPUS, meters, loss_dict, acc_list)

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time)
            if (iteration + 1) % log_step == 0:
                logger.info(
                    log_iter_stats(iteration, epoch_iters, cur_epoch,
                                   max_epoch, optimizer.param_groups[0]['lr'],
                                   meters))
            if is_master_proc() and summary_writer:
                global_step = (cur_epoch - 1) * epoch_iters + (iteration + 1)
                for name, meter in meters.meters.items():
                    summary_writer.add_scalar('{}/avg'.format(name),
                                              float(meter.avg),
                                              global_step=global_step)
                    summary_writer.add_scalar('{}/global_avg'.format(name),
                                              meter.global_avg,
                                              global_step=global_step)
                summary_writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          global_step=global_step)

        if cfg.DATALOADER.PREFETCHER:
            data_loader.release()
        logger.info(
            log_epoch_stats(epoch_iters, cur_epoch, max_epoch,
                            optimizer.param_groups[0]['lr'], meters))
        arguments["cur_epoch"] = cur_epoch
        lr_scheduler.step()
        if is_master_proc() and save_epoch > 0 and cur_epoch % save_epoch == 0 and cur_epoch != max_epoch:
            check_pointer.save("model_{:04d}".format(cur_epoch), **arguments)
        if eval_epoch > 0 and cur_epoch % eval_epoch == 0 and cur_epoch != max_epoch:
            if cfg.MODEL.NORM.PRECISE_BN:
                calculate_and_update_precise_bn(
                    train_data_loader,
                    model,
                    min(cfg.MODEL.NORM.NUM_BATCHES_PRECISE,
                        len(train_data_loader)),
                    cfg.NUM_GPUS > 0,
                )

            eval_results = do_evaluation(cfg,
                                         model,
                                         test_data_loader,
                                         device,
                                         cur_epoch=cur_epoch)
            model.train()
            if is_master_proc() and summary_writer:
                for key, value in eval_results.items():
                    summary_writer.add_scalar(f'eval/{key}',
                                              value,
                                              global_step=cur_epoch + 1)

    if eval_epoch > 0:
        logger.info('Start final evaluating...')
        torch.cuda.empty_cache()  # speed up evaluating after training finished
        eval_results = do_evaluation(cfg, model, test_data_loader, device)

        if is_master_proc() and summary_writer:
            for key, value in eval_results.items():
                summary_writer.add_scalar(f'eval/{key}',
                                          value,
                                          global_step=arguments["cur_epoch"])
            summary_writer.close()
    if is_master_proc():
        check_pointer.save("model_final", **arguments)
    # compute training time
    total_training_time = int(time.time() - start_training_time)
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / max_iter))
    return model
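The AMP and non-AMP branches above are nearly identical. As Example #2 illustrates, GradScaler(enabled=...) lets a single code path cover both cases, since a disabled scaler's scale() returns the loss unchanged and step() simply calls optimizer.step(). A condensed sketch of the inner update under that assumption (it omits the DDP no_sync handling shown above and is not the repository's code):

scaler = GradScaler(enabled=cfg.TRAIN.HYBRID_PRECISION)

with autocast(enabled=cfg.TRAIN.HYBRID_PRECISION):
    output_dict = model(images)
    loss = criterion(output_dict, targets)[KEY_LOSS] / gradient_accumulate_step
scaler.scale(loss).backward()
if (iteration + 1) % gradient_accumulate_step == 0:
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()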
Example #8
    def begin_fit(self):
        from torch.cuda.amp import GradScaler
        self.old_one_batch = self.learn.one_batch
        self.learn.one_batch = partial(mixed_precision_one_batch, self.learn)
        self.learn.scaler = GradScaler()
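The callback above replaces learn.one_batch with a mixed_precision_one_batch function that is not shown. A minimal sketch of such a batch step, assuming a simple learner object with model, loss_func, opt and scaler attributes (a stand-in for illustration, not the library's actual implementation):

from torch.cuda.amp import autocast

def mixed_precision_one_batch(learn, xb, yb):
    # Hypothetical stand-in; the attribute names are assumptions.
    learn.opt.zero_grad()
    with autocast():
        pred = learn.model(xb)
        loss = learn.loss_func(pred, yb)
    learn.scaler.scale(loss).backward()
    learn.scaler.step(learn.opt)
    learn.scaler.update()
    return loss.detach()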
Example #9
def main( gpu,cfg,args):
    # Network Builders

    load_gpu = gpu+args.start_gpu
    rank = gpu
    torch.cuda.set_device(load_gpu)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:{}'.format(args.port),
        world_size=args.gpu_num,
        rank=rank,
        timeout=datetime.timedelta(seconds=300))
            # self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model).cuda(self.gpu)


    if args.use_float16:
        from torch.cuda.amp import autocast as autocast, GradScaler
        scaler = GradScaler()
    else:
        scaler = None
        autocast = None

    label_num_=args.num_class
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder.lower(),
        fc_dim=cfg.MODEL.fc_dim,
        num_class=label_num_,
        weights=cfg.MODEL.weights_decoder)

    crit = nn.NLLLoss(ignore_index=255)

    if cfg.MODEL.arch_decoder.endswith('deepsup'):
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit, cfg.TRAIN.deep_sup_scale)
    else:
        segmentation_module = SegmentationModule(
            net_encoder, net_decoder, crit)

    if args.use_clipdataset:
        dataset_train = BaseDataset_longclip(args,'train')
    else:
        dataset_train = BaseDataset(
            args,
            'train'
            )

    sampler_train = torch.utils.data.distributed.DistributedSampler(dataset_train)
    loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.batchsize, shuffle=False,
                                               sampler=sampler_train, pin_memory=True,
                                               num_workers=args.workers)


    print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))

    dataset_val = BaseDataset(
        args,
        'val'
        )
    sampler_val = torch.utils.data.distributed.DistributedSampler(dataset_val)
    loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=args.batchsize, shuffle=False,
                                             sampler=sampler_val, pin_memory=True,
                                             num_workers=args.workers)
#    loader_val = torch.utils.data.DataLoader(dataset_val,batch_size=args.batchsize,shuffle=False,num_workers=args.workers)
    # create loader iterator
    

    # load nets into gpu

    segmentation_module = segmentation_module.cuda(load_gpu)

    segmentation_module= nn.SyncBatchNorm.convert_sync_batchnorm(segmentation_module)

    if args.resume_epoch!=0:
#        if dist.get_rank() == 0:
        to_load = torch.load(os.path.join('./resume','model_epoch_{}.pth'.format(args.resume_epoch)),map_location=torch.device("cuda:"+str(load_gpu)))
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in to_load.items():
            name = k[7:]  # strip the leading `module.` prefix (first 7 characters of the key)
            new_state_dict[name] = v  # map the stripped key to its original value
        cfg.TRAIN.start_epoch=args.resume_epoch
        segmentation_module.load_state_dict(new_state_dict)


    segmentation_module= torch.nn.parallel.DistributedDataParallel(
                    segmentation_module,
                device_ids=[load_gpu],
                find_unused_parameters=True)

    # Set up optimizers
#    nets = (net_encoder, net_decoder, crit)
    nets = segmentation_module
    optimizers = create_optimizers(segmentation_module, cfg)
    if args.resume_epoch!=0:
#        if dist.get_rank() == 0:
        optimizers.load_state_dict(torch.load(os.path.join('./resume','opt_epoch_{}.pth'.format(args.resume_epoch)),map_location=torch.device("cuda:"+str(load_gpu))))
        print('resume from epoch {}'.format(args.resume_epoch))

    # Main loop
    history = {'train': {'epoch': [], 'loss': [], 'acc': []}}

#    test(segmentation_module,loader_val,args)
    for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
        if dist.get_rank() == 0 and epoch==0:
            checkpoint(nets,optimizers, history, args, epoch+1)
        print('Epoch {}'.format(epoch))
        train(segmentation_module, loader_train, optimizers, history, epoch+1, cfg,args,load_gpu,scaler=scaler,autocast=autocast)

###################        # checkpointing
        if dist.get_rank() == 0 and (epoch+1)%10==0:
            checkpoint(segmentation_module,optimizers, history, args, epoch+1)
        if args.validation:
            test(segmentation_module,loader_val,args)

    print('Training Done!')
Example #10
    def __init__(self, c: Configs, name: str):
        self.name = name
        self.c = c
        # total number of samples for a single update
        self.envs = self.c.n_workers * self.c.env_per_worker
        self.batch_size = self.envs * self.c.worker_steps
        assert (self.batch_size %
                (self.c.n_update_per_epoch * self.c.mini_batch_size) == 0)
        self.update_batch_size = self.batch_size // self.c.n_update_per_epoch

        # #### Initialize
        self.total_games = 0

        # model for sampling
        self.model = Model(c.channels, c.blocks).to(device)

        # dynamic hyperparams
        self.cur_lr = self.c.lr()
        self.cur_reg_l2 = self.c.reg_l2()
        self.cur_step_reward = 0.
        self.cur_right_gain = 0.
        self.cur_fix_prob = 0.
        self.cur_neg_mul = 0.
        self.cur_entropy_weight = self.c.entropy_weight()
        self.cur_prob_reg_weight = self.c.prob_reg_weight()
        self.cur_target_prob_weight = self.c.target_prob_weight()
        self.cur_gamma = self.c.gamma()
        self.cur_lamda = self.c.lamda()

        # optimizer
        self.scaler = GradScaler()
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.cur_lr,
                                    weight_decay=self.cur_reg_l2)

        # initialize tensors for observations
        shapes = [(self.envs, *kTensorDim),
                  (self.envs, self.c.worker_steps, 3),
                  (self.envs, self.c.worker_steps)]
        types = [np.dtype('float32'), np.dtype('float32'), np.dtype('bool')]
        self.shms = [
            shared_memory.SharedMemory(create=True,
                                       size=math.prod(shape) * typ.itemsize)
            for shape, typ in zip(shapes, types)
        ]
        self.obs_np, self.rewards, self.done = [
            np.ndarray(shape, dtype=typ, buffer=shm.buf)
            for shm, shape, typ in zip(self.shms, shapes, types)
        ]
        # create workers
        shm = [(shm.name, shape, typ)
               for shm, shape, typ in zip(self.shms, shapes, types)]
        self.workers = [
            Worker(name, shm, self.w_range(i), 27 + i)
            for i in range(self.c.n_workers)
        ]
        self.set_game_param(self.c.right_gain(), self.c.fix_prob(),
                            self.c.neg_mul(), self.c.step_reward())
        for i in self.workers:
            i.child.send(('reset', None))
        for i in self.workers:
            i.child.recv()

        self.obs = obs_to_torch(self.obs_np, device)
Example #11
class Imagine(nn.Module):
    def __init__(
            self,
            *,
            text=None,
            img=None,
            clip_encoding=None,
            lr=1e-5,
            batch_size=4,
            gradient_accumulate_every=4,
            save_every=100,
            image_width=512,
            num_layers=16,
            epochs=20,
            iterations=1050,
            save_progress=True,
            seed=None,
            open_folder=True,
            save_date_time=False,
            start_image_path=None,
            start_image_train_iters=10,
            start_image_lr=3e-4,
            theta_initial=None,
            theta_hidden=None,
            lower_bound_cutout=0.1,  # should be smaller than 0.8
            upper_bound_cutout=1.0,
            saturate_bound=False,
            create_story=False,
            story_start_words=5,
            story_words_per_epoch=5,
    ):

        super().__init__()

        if exists(seed):
            tqdm.write(f'setting seed: {seed}')
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True
            
        # fields for story creation:
        self.create_story = create_story
        self.words = None
        self.all_words = text.split(" ") if text is not None else None
        self.num_start_words = story_start_words
        self.words_per_epoch = story_words_per_epoch
        if create_story:
            assert text is not None,  "We need text input to create a story..."
            # overwrite epochs to match story length
            num_words = len(self.all_words)
            self.epochs = 1 + (num_words - self.num_start_words) / self.words_per_epoch
            # add one epoch if not divisible
            self.epochs = int(self.epochs) if int(self.epochs) == self.epochs else int(self.epochs) + 1
            print("Running for ", self.epochs, "epochs")
        else: 
            self.epochs = epochs
        
        self.iterations = iterations
        self.image_width = image_width
        total_batches = self.epochs * self.iterations * batch_size * gradient_accumulate_every
        model = DeepDaze(
            total_batches=total_batches,
            batch_size=batch_size,
            image_width=image_width,
            num_layers=num_layers,
            theta_initial=theta_initial,
            theta_hidden=theta_hidden,
            lower_bound_cutout=lower_bound_cutout,
            upper_bound_cutout=upper_bound_cutout,
            saturate_bound=saturate_bound,
        ).cuda()

        self.model = model
        self.scaler = GradScaler()
        self.optimizer = AdamP(model.parameters(), lr)
        self.gradient_accumulate_every = gradient_accumulate_every
        self.save_every = save_every
        self.save_date_time = save_date_time
        self.open_folder = open_folder
        self.save_progress = save_progress
        self.text = text
        self.image = img
        self.textpath = create_text_path(text=text, img=img, encoding=clip_encoding)
        self.filename = self.image_output_path()
        
        # create coding to optimize for
        self.clip_img_transform = create_clip_img_transform(perceptor.input_resolution.item())
        self.clip_encoding = self.create_clip_encoding(text=text, img=img, encoding=clip_encoding)

        self.start_image = None
        self.start_image_train_iters = start_image_train_iters
        self.start_image_lr = start_image_lr
        if exists(start_image_path):
            file = Path(start_image_path)
            assert file.exists(), f'file does not exist at given starting image path {start_image_path}'
            image = Image.open(str(file))

            image_tensor = self.clip_img_transform(image)[None, ...].cuda()
            self.start_image = image_tensor
Example #12
def main_worker(gpu, args):

    args.gpu = gpu

    if args.distributed:
        args.rank = args.rank * torch.cuda.device_count() + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    print(args.rank, " gpu", args.gpu)

    torch.cuda.set_device(
        args.gpu
    )  # use this default device (same as args.device if not distributed)
    torch.backends.cudnn.benchmark = True

    if args.rank == 0:
        print("Batch size is:", args.batch_size, "epochs", args.epochs)

    #############
    # Create MONAI dataset
    training_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="training",
        base_dir=args.data_root,
    )
    validation_list = load_decathlon_datalist(
        data_list_file_path=args.dataset_json,
        data_list_key="validation",
        base_dir=args.data_root,
    )

    if args.quick:  # for debugging on a small subset
        training_list = training_list[:16]
        validation_list = validation_list[:16]

    train_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=args.tile_count,
            tile_size=args.tile_size,
            random_offset=True,
            background_val=255,
            return_list_of_dicts=True,
        ),
        RandFlipd(keys=["image"], spatial_axis=0, prob=0.5),
        RandFlipd(keys=["image"], spatial_axis=1, prob=0.5),
        RandRotate90d(keys=["image"], prob=0.5),
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    valid_transform = Compose([
        LoadImageD(keys=["image"],
                   reader=WSIReader,
                   backend="TiffFile",
                   dtype=np.uint8,
                   level=1,
                   image_only=True),
        LabelEncodeIntegerGraded(keys=["label"], num_classes=args.num_classes),
        TileOnGridd(
            keys=["image"],
            tile_count=None,
            tile_size=args.tile_size,
            random_offset=False,
            background_val=255,
            return_list_of_dicts=True,
        ),
        ScaleIntensityRangeD(keys=["image"],
                             a_min=np.float32(255),
                             a_max=np.float32(0)),
        ToTensord(keys=["image", "label"]),
    ])

    dataset_train = Dataset(data=training_list, transform=train_transform)
    dataset_valid = Dataset(data=validation_list, transform=valid_transform)

    train_sampler = DistributedSampler(
        dataset_train) if args.distributed else None
    val_sampler = DistributedSampler(
        dataset_valid, shuffle=False) if args.distributed else None

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=train_sampler,
        collate_fn=list_data_collate,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid,
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        multiprocessing_context="spawn",
        sampler=val_sampler,
        collate_fn=list_data_collate,
    )

    if args.rank == 0:
        print("Dataset training:", len(dataset_train), "validation:",
              len(dataset_valid))

    model = milmodel.MILModel(num_classes=args.num_classes,
                              pretrained=True,
                              mil_mode=args.mil_mode)

    best_acc = 0
    start_epoch = 0
    if args.checkpoint is not None:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
        if "best_acc" in checkpoint:
            best_acc = checkpoint["best_acc"]
        print("=> loaded checkpoint '{}' (epoch {}) (bestacc {})".format(
            args.checkpoint, start_epoch, best_acc))

    model.cuda(args.gpu)

    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], output_device=args.gpu)

    if args.validate:
        # if we only want to validate existing checkpoint
        epoch_time = time.time()
        val_loss, val_acc, qwk = val_epoch(model,
                                           valid_loader,
                                           epoch=0,
                                           args=args,
                                           max_tiles=args.tile_count)
        if args.rank == 0:
            print(
                "Final validation loss: {:.4f}".format(val_loss),
                "acc: {:.4f}".format(val_acc),
                "qwk: {:.4f}".format(qwk),
                "time {:.2f}s".format(time.time() - epoch_time),
            )

        exit(0)

    params = model.parameters()

    if args.mil_mode in ["att_trans", "att_trans_pyramid"]:
        m = model if not args.distributed else model.module
        params = [
            {
                "params":
                list(m.attention.parameters()) + list(m.myfc.parameters()) +
                list(m.net.parameters())
            },
            {
                "params": list(m.transformer.parameters()),
                "lr": 6e-6,
                "weight_decay": 0.1
            },
        ]

    optimizer = torch.optim.AdamW(params,
                                  lr=args.optim_lr,
                                  weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args.epochs,
                                                           eta_min=0)

    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        if args.rank == 0:
            print("Writing Tensorboard logs to ", writer.log_dir)
    else:
        writer = None

    ###RUN TRAINING
    n_epochs = args.epochs
    val_acc_max = 0.0

    scaler = None
    if args.amp:  # new native amp
        scaler = GradScaler()

    for epoch in range(start_epoch, n_epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)
            torch.distributed.barrier()

        print(args.rank, time.ctime(), "Epoch:", epoch)

        epoch_time = time.time()
        train_loss, train_acc = train_epoch(model,
                                            train_loader,
                                            optimizer,
                                            scaler=scaler,
                                            epoch=epoch,
                                            args=args)

        if args.rank == 0:
            print(
                "Final training  {}/{}".format(epoch, n_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "acc: {:.4f}".format(train_acc),
                "time {:.2f}s".format(time.time() - epoch_time),
            )

        if args.rank == 0 and writer is not None:
            writer.add_scalar("train_loss", train_loss, epoch)
            writer.add_scalar("train_acc", train_acc, epoch)

        if args.distributed:
            torch.distributed.barrier()

        b_new_best = False
        val_acc = 0
        if (epoch + 1) % args.val_every == 0:

            epoch_time = time.time()
            val_loss, val_acc, qwk = val_epoch(model,
                                               valid_loader,
                                               epoch=epoch,
                                               args=args,
                                               max_tiles=args.tile_count)
            if args.rank == 0:
                print(
                    "Final validation  {}/{}".format(epoch, n_epochs - 1),
                    "loss: {:.4f}".format(val_loss),
                    "acc: {:.4f}".format(val_acc),
                    "qwk: {:.4f}".format(qwk),
                    "time {:.2f}s".format(time.time() - epoch_time),
                )
                if writer is not None:
                    writer.add_scalar("val_loss", val_loss, epoch)
                    writer.add_scalar("val_acc", val_acc, epoch)
                    writer.add_scalar("val_qwk", qwk, epoch)

                val_acc = qwk

                if val_acc > val_acc_max:
                    print("qwk ({:.6f} --> {:.6f})".format(
                        val_acc_max, val_acc))
                    val_acc_max = val_acc
                    b_new_best = True

        if args.rank == 0 and args.logdir is not None:
            save_checkpoint(model,
                            epoch,
                            args,
                            best_acc=val_acc,
                            filename="model_final.pt")
            if b_new_best:
                print("Copying to model.pt new best model!!!!")
                shutil.copyfile(os.path.join(args.logdir, "model_final.pt"),
                                os.path.join(args.logdir, "model.pt"))

        scheduler.step()

    print("ALL DONE")
Example #13
def fit(data, fold=None, log=True):

    best_score = 0.0
    model = EfficientNet("tf_efficientnet_b0_ns").to(device)
    # model.load_state_dict(
    #     torch.load("/content/siim-isic_efficientnet_b0_2.ckpt")[
    #         "model_state_dict"
    #     ]
    # )
    # if log:
    #    neptune.init("utsav/SIIM-ISIC", api_token=NEPTUNE_API_TOKEN)
    #    neptune.create_experiment(
    #        FLAGS["exp_name"],
    #        exp_description,
    #        params=FLAGS,
    #        upload_source_files="*.txt",
    #    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=FLAGS["learning_rate"],
        weight_decay=FLAGS["weight_decay"],
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=0.5,
        cooldown=0,
        mode="min",
        patience=3,
        verbose=True,
        min_lr=1e-8,
    )

    datasets = get_datasets(data)

    # sampler
    # labels_vcount = y_train["target"].value_counts()
    # class_counts = [
    #     labels_vcount[0].astype(np.float32),
    #     labels_vcount[1].astype(np.float32),
    # ]
    # num_samples = sum(class_counts)
    # class_weights = [
    #     num_samples / class_counts[i] for i in range(len(class_counts))
    # ]
    # weights = [
    #     class_weights[y_train["target"].values[i]]
    #     for i in range(int(num_samples))
    # ]
    # sampler = WeightedRandomSampler(
    #     torch.DoubleTensor(weights), int(num_samples)
    # )

    # loaders
    train_loader = DataLoader(
        datasets["train"],
        batch_size=FLAGS["batch_size"],
        num_workers=FLAGS["num_workers"],
        shuffle=True,  # sampler=sampler,
        pin_memory=True,
    )
    val_loader = DataLoader(
        datasets["valid"],
        batch_size=FLAGS["batch_size"] * 2,
        shuffle=False,
        num_workers=FLAGS["num_workers"],
        drop_last=True,
    )

    scaler = GradScaler()
    # train loop
    for epoch in range(0, FLAGS["num_epochs"]):

        print("-" * 27 + f"Epoch #{epoch+1} started" + "-" * 27)

        train_loss = train_one_epoch(
            train_loader,
            model,
            optimizer,
            epoch,
            scheduler=None,
            scaler=scaler,
            log=log,
        )

        print(f"\nAverage loss for epoch #{epoch+1} : {train_loss:.5f}")
        val_output = val_one_epoch(val_loader, model)
        val_loss, auc_score, roc_plot, hist, error_scaled = val_output
        scheduler.step(error_scaled)

        # logs
        # if log:
        #     neptune.log_metric("AUC/val", auc_score)
        #     neptune.log_image("ROC/val", roc_plot)
        #     neptune.log_metric("Loss/val", val_loss)
        #     neptune.log_image("hist/val", hist)

        # checkpoint+upload
        if (auc_score > best_score) or (best_score - auc_score < 0.025):
            if auc_score > best_score:
                best_score = auc_score
            save_upload(
                model,
                optimizer,
                best_score,
                epoch,
                fold,
                exp_name=FLAGS["exp_name"],
            )

        print("-" * 28 + f"Epoch #{epoch+1} ended" + "-" * 28)

    # if log:
    #    neptune.stop()

    return model
Example #14
                              n_channels=args.outchannel,
                              reduction='mean').to(device)
    elif args.lossid == 3:
        loss_func = SSIM(data_range=1, channel=args.outchannel,
                         spatial_dims=3).to(device)
    else:
        sys.exit("Invalid Loss ID")
    if (args.lossid == 0 and args.plosstyp == "L1") or (args.lossid == 1):
        IsNegLoss = False
    else:
        IsNegLoss = True

    if (args.modelid == 7) or (args.modelid == 8):
        model.loss_func = loss_func

    scaler = GradScaler(enabled=args.amp)

    if args.chkpoint:
        chk = torch.load(args.chkpoint, map_location=device)
    elif args.finetune:
        if args.chkpointft:
            chk = torch.load(args.chkpointft, map_location=device)
        else:
            sys.exit("Finetune can't be performed if chkpointft not supplied")
    else:
        chk = None
        start_epoch = 0
        best_loss = float('-inf') if IsNegLoss else float('inf')

    if chk is not None:
        model.load_state_dict(chk['state_dict'])
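Like the to_save dict in Example #2, the GradScaler above has its own state (the current scale and growth tracker) that can travel with the checkpoint. A short sketch of restoring it, assuming the checkpoint dict also carries a 'scaler' entry, which the original snippet does not show:

    if chk is not None and args.amp and 'scaler' in chk:
        scaler.load_state_dict(chk['scaler'])
    # and when saving, something like:
    # torch.save({'state_dict': model.state_dict(), 'scaler': scaler.state_dict()}, path)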
Example #15
def main():
    torch.manual_seed(317)
    torch.backends.cudnn.benchmark = True

    train_logger = Logger(opt, "train")
    val_logger = Logger(opt, "val")

    start_epoch = 0
    print('Creating model...')
    model = get_model(opt.arch, opt.heads).to(opt.device)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    criterion = CtdetLoss(opt)

    print('Loading model...')
    if opt.load_model != '':
        model, optimizer, start_epoch = load_model(
            model, opt.load_model, optimizer, opt.lr, opt.lr_step)
    model = torch.nn.DataParallel(model)

    # amp
    scaler = GradScaler()

    print('Setting up data...')
    train_dataset = Dataset(opt, 'train')
    val_dataset = Dataset(opt, 'val')

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=16,
        pin_memory=True,
        drop_last=True
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )
    # cal left time
    time_stats = TimeMeter(opt.num_epochs, len(train_loader))

    for epoch in range(start_epoch + 1, opt.num_epochs + 1):
        print('train...')
        train(model, train_loader, criterion, optimizer,
              train_logger, opt, epoch, scaler, time_stats)

        if epoch % opt.val_intervals == 0:
            print('val...')
            val(model, val_loader, criterion, val_logger, opt, epoch)
            save_model(os.path.join(opt.save_dir, f'model_{epoch}.pth'),
                       epoch, model, optimizer)

        # update learning rate
        if epoch in opt.lr_step:
            lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
            print('Drop LR to', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    # without optimizer
    save_model(os.path.join(opt.save_dir, 'model_final.pth'), epoch, model)
Example #16
def train_imagenet():
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        # train_loader = xu.SampleGenerator(
        #     data=(
        #         torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
        #         torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        #     ),
        #     sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size(),
        # )
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size,
        )
        # test_loader = xu.SampleGenerator(
        #     data=(
        #         torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
        #         torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
        #     ),
        #     sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
        # )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
            ),
            sample_count=50000 // FLAGS.batch_size,
        )
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "val"),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]),
        )

        train_sampler, test_sampler = None, None
        # if xm.xrt_world_size() > 1:
        #     train_sampler = torch.utils.data.distributed.DistributedSampler(
        #         train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
        #     )
        #     test_sampler = torch.utils.data.distributed.DistributedSampler(
        #         test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False
        #     )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers,
        )

    torch.manual_seed(42)

    # device = xm.xla_device()
    device = torch.device("cuda")  # Default CUDA device
    model = get_model_property("model_fn")().to(device)
    writer = None
    # if xm.is_master_ordinal():
    #     writer = test_utils.get_summary_writer(FLAGS.logdir)
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=1e-4)
    # num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * xm.xrt_world_size())
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * 1)
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
        scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, "lr_scheduler_divide_every_n_epochs", None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer,
    )
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()
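    # CUDA AMP recipe used in train_loop_fn below: run the forward pass under
    # autocast, scale the loss before backward, then scaler.step()/scaler.update()
    # so optimizer steps with inf/NaN gradients are skipped and the scale adapts.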

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # optimizer.step()
            # xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                _train_update(device, step, loss, tracker, epoch, writer)
                # xm.add_step_closure(
                #     _train_update, args=(device, step, loss, tracker, epoch, writer)
                # )

    def test_loop_fn(loader, epoch):
        total_samples, correct = 0, 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            # without the XLA device loader, batches arrive on CPU; move them to CUDA
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            if step % FLAGS.log_steps == 0:
                test_utils.print_test_update(device, None, epoch, step)
                # xm.add_step_closure(test_utils.print_test_update, args=(device, None, epoch, step))
        accuracy = 100.0 * correct.item() / total_samples
        # accuracy = xm.mesh_reduce("test_accuracy", accuracy, np.mean)
        return accuracy

    # train_device_loader = pl.MpDeviceLoader(train_loader, device)
    # test_device_loader = pl.MpDeviceLoader(test_loader, device)
    accuracy, max_accuracy = 0.0, 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        print("Epoch {} train begin {}".format(epoch, test_utils.now()))
        train_loop_fn(train_loader, epoch)
        print("Epoch {} train end {}".format(epoch, test_utils.now()))
        accuracy = test_loop_fn(test_loader, epoch)
        print("Epoch {} test end {}, Accuracy={:.2f}".format(
            epoch, test_utils.now(), accuracy))
        max_accuracy = max(accuracy, max_accuracy)
        test_utils.write_to_summary(writer,
                                    epoch,
                                    dict_to_write={"Accuracy/test": accuracy},
                                    write_xla_metrics=True)
        # if FLAGS.metrics_debug:
        #     print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    print("Max Accuracy: {:.2f}%".format(max_accuracy))
    return max_accuracy
Example #17
    def __init__(self,
                 name='default',
                 results_dir='results',
                 models_dir='models',
                 base_dir='./',
                 optimizer="adam",
                 latent_dim=256,
                 image_size=128,
                 fmap_max=512,
                 transparent=False,
                 greyscale=False,
                 batch_size=4,
                 gp_weight=10,
                 gradient_accumulate_every=1,
                 attn_res_layers=[],
                 disc_output_size=5,
                 antialias=False,
                 lr=2e-4,
                 lr_mlp=1.,
                 ttur_mult=1.,
                 save_every=1000,
                 evaluate_every=1000,
                 trunc_psi=0.6,
                 aug_prob=None,
                 aug_types=['translation', 'cutout'],
                 dataset_aug_prob=0.,
                 calculate_fid_every=None,
                 is_ddp=False,
                 rank=0,
                 world_size=1,
                 log=False,
                 amp=False,
                 *args,
                 **kwargs):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(
            image_size
        ), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)
                ) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
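        # the generator and discriminator are updated with separate backward/step
        # calls, so each keeps its own loss-scale state; enabled=self.amp makes the
        # scalers no-ops when amp is off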
        self.G_scaler = GradScaler(enabled=self.amp)
        self.D_scaler = GradScaler(enabled=self.amp)
Example #18
def train():

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'],
                          mode='RGB'),
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.4, 1),
                                                   shear=(-5, 5)),
                       A.PairRandomHorizontalFlip(),
                       A.PairRandomBoxBlur(0.1, 5),
                       A.PairRandomSharpen(0.1),
                       A.PairApplyOnlyAtIndices([1],
                                                T.ColorJitter(
                                                    0.15, 0.15, 0.15, 0.05)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 2),
                                                  shear=(-5, 5)),
                          T.RandomHorizontalFlip(),
                          A.RandomBoxBlur(0.1, 5),
                          A.RandomSharpen(0.1),
                          T.ColorJitter(0.15, 0.15, 0.15, 0.05),
                          T.ToTensor()
                      ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'],
                          mode='RGB')
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.3, 1),
                                                   shear=(-5, 5)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 1.2),
                                                  shear=(-5, 5)),
                          T.ToTensor()
                      ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    optimizer = Adam([{
        'params': model.backbone.parameters(),
        'lr': 1e-4
    }, {
        'params': model.aspp.parameters(),
        'lr': 5e-4
    }, {
        'params': model.decoder.parameters(),
        'lr': 5e-4
    }])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr),
                true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr,
                                                       true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 *
                                                          random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5),
                                            translate=(0.2, 0.2),
                                            scale=(0.5, 1.5),
                                            shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)), ) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                    torch.randn_like(true_src[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                    torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1),
                    translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha,
                                    true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5),
                                 step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5),
                                 step)
                writer.add_image('train_pred_com',
                                 make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5),
                                 step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5),
                                 step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5),
                                 step)

            del true_pha, true_fgr, true_bgr
            del pred_pha, pred_fgr, pred_err

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth'
                )

        torch.save(model.state_dict(),
                   f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
Example #19
    def __init__(self, model_set, device, config):
        self.model = model_set['model']
        self.loss = model_set['loss']
        self.optimizer = model_set['optimizer']
        self.device = device
        self.scaler = GradScaler()
Example #20
def main(**kwargs):
    fp16 = kwargs['fp16']
    batch_size = kwargs['batch_size']
    num_workers = kwargs['num_workers']
    pin_memory = kwargs['pin_memory']
    patience_limit = kwargs['patience']
    initial_lr = kwargs['learning_rate']
    gradient_accumulation_steps = kwargs['grad_accumulation_steps']
    device = kwargs['device']
    epochs = kwargs['epochs']
    freeze_bert = kwargs['freeze_bert_layers']
    slots = kwargs['slots']

    if fp16:
        scaler = GradScaler()
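        # the scaler exists only when fp16 is enabled; every later use of it
        # below is guarded by the same flag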

    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', model_max_length=128
    )  # for TM_1, out of 303066 samples, 5 are above 128 tokens

    if kwargs['dataset'] == "TM":
        train_data, val_data = load_taskmaster_datasets(
            utils.datasets,
            tokenizer,
            train_percent=0.9,
            for_testing_purposes=kwargs['testing_for_bugs'])

    if kwargs['dataset'] == "MW":
        train_data = load_multiwoz_dataset("multi-woz/train_dials.json",
                                           tokenizer, slots,
                                           kwargs['testing_for_bugs'])
        val_data = load_multiwoz_dataset("multi-woz/dev_dials.json", tokenizer,
                                         slots, kwargs['testing_for_bugs'])

    if kwargs['dataset'] == "MW22":
        train_data = load_MW_22_dataset_training(tokenizer, slots,
                                                 kwargs['testing_for_bugs'])
        val_data = load_MW_22_dataset_validation(tokenizer, slots,
                                                 kwargs['testing_for_bugs'])

    train_dataset = VE_dataset(train_data)
    val_dataset = VE_dataset(val_data)

    collator = collate_class(tokenizer.pad_token_id)
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   collate_fn=collator,
                                                   num_workers=num_workers,
                                                   pin_memory=pin_memory)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=int(batch_size /
                                                                2),
                                                 shuffle=True,
                                                 collate_fn=collator,
                                                 num_workers=num_workers,
                                                 pin_memory=pin_memory)

    # BertForValueExtraction
    if kwargs['model_path'] and os.path.isdir(kwargs['model_path']):
        from_pretrained = kwargs['model_path']
    else:
        from_pretrained = 'bert-base-uncased'
    model = BertForValueExtraction(num_labels=len(label2id.keys()),
                                   from_pretrained=from_pretrained,
                                   freeze_bert=freeze_bert)
    model.to(device)
    model.train()

    optimizer = torch.optim.Adam(params=model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'max',
                                                           factor=0.5,
                                                           patience=1,
                                                           min_lr=initial_lr /
                                                           100,
                                                           verbose=True)

    best_acc, count = 0, 0

    for epoch in range(epochs):
        # train loop
        total_loss = 0
        optimizer.zero_grad()
        pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
        for i, batch in pbar:

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            labels = batch['labels'].to(device)
            text = batch['text']
            # print(f"input_ids: {input_ids.shape} - {input_ids}")
            # print(f"attention_mask: {attention_mask.shape} - {attention_mask}")
            # print(f"token_type_ids: {token_type_ids.shape} - {token_type_ids}")
            # print(f"labels: {labels.shape} - {labels}")
            if fp16:
                with autocast():
                    loss = model.calculate_loss(input_ids=input_ids,
                                                attention_mask=attention_mask,
                                                token_type_ids=token_type_ids,
                                                labels=labels)
                    total_loss += loss.item()
                    loss = loss / gradient_accumulation_steps

                scaler.scale(loss).backward()

            else:
                loss = model.calculate_loss(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            token_type_ids=token_type_ids,
                                            labels=labels)
                loss.backward()
                total_loss += loss.item()
            if ((i + 1) % gradient_accumulation_steps) == 0:
                if fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            batch_num = ((i + 1) / gradient_accumulation_steps)
            pbar.set_description(f"Loss: {total_loss/batch_num:.4f}")

        # validation loop
        model.eval()
        with torch.no_grad():
            TP, FP, FN, TN = model.evaluate(val_dataloader, device)

            pr = utils.calculate_precision(TP, FP)
            re = utils.calculate_recall(TP, FN)
            F1 = utils.calculate_F1(TP, FP, FN)
            acc = utils.calculate_accuracy(TP, FP, FN, TN)
            balanced_acc = utils.calculate_balanced_accuracy(TP, FP, FN, TN)
            print(
                f"Validation: pr {pr:.4f} - re {re:.4f} - F1 {F1:.4f} - acc {acc:.4f} - balanced acc {balanced_acc:.4f}"
            )
            scheduler.step(balanced_acc)

            if balanced_acc > best_acc:
                best_acc = balanced_acc
                count = 0
                if kwargs['model_path']:
                    model.save_(f"{kwargs['model_path']}-ACC{best_acc:.4f}")

            else:
                count += 1

            if count == patience_limit:
                print("ran out of patience stopping early")
                break

        model.train()
Example #21
    def before_fit(self):
        self.old_one_batch = self.learn.one_batch
        self.learn.one_batch = partial(mixed_precision_one_batch, self.learn)
        self.learn.scaler = GradScaler(**self.scaler_kwargs)
Example #22
    def __init__(
        self,
        config,
        tester,
        monitor,
        rank='cuda',
        world_size=0,
    ):
        self.config = config
        self.rank = rank
        self.world_size = world_size
        self.tester = tester
        # base
        self.tag = config['base']['tag']
        self.runs_dir = config['base']['runs_dir']
        #
        self.max_epochs = config['training'].get('num_epochs', 400)
        self.batch_size = config["training"].get('batch_size', 4)
        self.num_workers = config['training'].get('num_workers', 0)
        self.save_interval = config['training'].get('save_interval', 50)

        self.load_pretrain_model = config['training'].get(
            'load_pretrain_model', False)
        self.pretrain_model = config['training'].get('pretrain_model', None)
        self.load_optimizer = config['training'].get('load_optimizer', False)
        self.val_check_interval = config['training'].get(
            'val_check_interval', 50)
        self.training_log_interval = config['training'].get(
            'training_log_interval', 1)
        self.use_amp = config['training'].get('use_amp', False)
        self.save_latest_only = config['training'].get('save_latest_only',
                                                       False)

        if self.rank != 'cuda' and self.world_size > 0:
            # read the ddp sub-config before using it
            self.ddp_config = config['training'].get('ddp', dict())
            self.master_addr = self.ddp_config.get('master_addr', 'localhost')
            self.master_port = str(self.ddp_config.get('master_port', '25700'))
            self.dist_url = 'tcp://' + self.master_addr + ":" + self.master_port
            torch.distributed.init_process_group(backend="nccl",
                                                 init_method=self.dist_url,
                                                 world_size=self.world_size,
                                                 rank=self.rank)

        self.logging_available = (self.rank == 0 or self.rank == 'cuda')
        self.trainloader = None
        self.optimizer = None
        self.scheduler = None
        self.init_timers()

        if self.use_amp:
            self.scalar = GradScaler()
            print("Debug settings: use amp=", self.use_amp)

        # Logging, in GPU 0
        self.recorder_mode = config['logger'].get("recorder_reduction", "sum")
        if self.logging_available:
            print("Logger at Process(rank=0)")
            self.recorder = Recorder(reduction=self.recorder_mode)
            self.recorder_test = Recorder(reduction=self.recorder_mode)
            self.logger = None
            self.csvlogger = CSVLogger(tfilename(self.runs_dir, "best_record"))
            self.monitor = monitor
            self.tester = tester
Example #23
    def _maybe_init_amp(self):
        if self.fp16 and self.amp_grad_scaler is None and torch.cuda.is_available():
            self.amp_grad_scaler = GradScaler()
Example #24
def train_fn(train_loader, teacher_model, model, criterion, optimizer, epoch,
             scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    grad_norm = 0.0  # last clipped gradient norm, updated at each optimizer step
    for step, (images, images_annot, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        with torch.no_grad():
            teacher_features, _, _ = teacher_model(images_annot.to(device))
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                features, _, y_preds = model(images)
                loss = criterion(teacher_features, features, y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            # backward outside autocast; gradients are scaled to avoid fp16 underflow
            scaler.scale(loss).backward()
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                # unscale once per optimizer step so clipping acts on true gradients
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), CFG.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                global_step += 1
        elif CFG.device == 'TPU':
            features, _, y_preds = model(images)
            loss = criterion(teacher_features, features, y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f}  '
                      #'LR: {lr:.6f}  '
                      .format(
                       epoch+1, step, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(train_loader)),
                       grad_norm=grad_norm,
                       #lr=scheduler.get_lr()[0],
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f}  '
                                #'LR: {lr:.6f}  '
                                .format(
                                epoch+1, step, len(train_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(train_loader)),
                                grad_norm=grad_norm,
                                #lr=scheduler.get_lr()[0],
                                ))
    return losses.avg
Example #25
def train(model, dataloaders, opts):
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    scaler = GradScaler()

    global_step = 0
    if opts.rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps, desc=opts.model)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'),
                    exist_ok=True)  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {opts.n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(dataloaders['train'].dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    best_ckpt = 0
    best_eval = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(dataloaders['train']):
            targets = batch['targets']
            n_examples += targets.size(0)

            with autocast():
                _, loss, _ = model(**batch, compute_loss=True)
                loss = loss.mean()

            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            scaler.scale(loss).backward()
            if not delay_unscale:
                # gather gradients from every processes
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    grad_norm = clip_grad_norm_(model.parameters(),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{opts.model}: {n_epoch}-{global_step}: '
                                f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s '
                                f'best_acc-{best_eval * 100:.2f}')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)

                if global_step % opts.valid_steps == 0:
                    log = evaluation(
                        model,
                        dict(
                            filter(lambda x: x[0].startswith('val'),
                                   dataloaders.items())), opts, global_step)
                    if log['val/acc'] > best_eval:
                        best_ckpt = global_step
                        best_eval = log['val/acc']
                        pbar.set_description(
                            f'{opts.model}: {n_epoch}-{best_ckpt} best_acc-{best_eval * 100:.2f}'
                        )
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")
        # if n_epoch >= opts.num_train_epochs:
        #     break
    return best_ckpt
Example #26
def model_train(
        config: ModelConfigBase,
        run_recovery: Optional[RunRecovery] = None) -> ModelTrainingResults:
    """
    The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds
    to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training.

    :param config: The arguments which specify all required information.
    :param run_recovery: Recovery information to restart training from an existing run.
    :raises TypeError: If the arguments are of the wrong type.
    :raises ValueError: When there are issues loading a previous checkpoint.
    """
    # Save the dataset files for later use in cross validation analysis
    config.write_dataset_files()

    # set the random seed for all libraries
    ml_util.set_random_seed(config.get_effective_random_seed(),
                            "Model Training")

    logging.debug("Creating the PyTorch model.")

    # Create the train loader and validation loader to load images from the dataset
    data_loaders = config.create_data_loaders()

    # Get the path to the checkpoint to recover from
    checkpoint_path = get_recovery_path_train(run_recovery=run_recovery,
                                              epoch=config.start_epoch)
    models_and_optimizer = ModelAndInfo(
        config=config,
        model_execution_mode=ModelExecutionMode.TRAIN,
        checkpoint_path=checkpoint_path
        if config.should_load_checkpoint_for_training() else None)

    # Create the main model
    # If continuing from a previous run at a specific epoch, then load the previous model.
    model_loaded = models_and_optimizer.try_create_model_and_load_from_checkpoint(
    )
    if not model_loaded:
        raise ValueError(
            "There was no checkpoint file available for the model for given start_epoch {}"
            .format(config.start_epoch))

    # Print out a detailed breakdown of layers, memory consumption and time.
    generate_and_print_model_summary(config, models_and_optimizer.model)

    # Move model to GPU and adjust for multiple GPUs
    models_and_optimizer.adjust_model_for_gpus()

    # Create the mean teacher model and move to GPU
    if config.compute_mean_teacher_model:
        mean_teacher_model_loaded = models_and_optimizer.try_create_mean_teacher_model_load_from_checkpoint_and_adjust(
        )
        if not mean_teacher_model_loaded:
            raise ValueError(
                "There was no checkpoint file available for the mean teacher model for given start_epoch {}"
                .format(config.start_epoch))

    # Create optimizer
    optimizer_loaded = models_and_optimizer.try_create_optimizer_and_load_from_checkpoint(
    )
    if not optimizer_loaded:
        raise ValueError(
            "There was no checkpoint file available for the optimizer for given start_epoch {}"
            .format(config.start_epoch))

    # Create checkpoint directory for this run if it doesn't already exist
    logging.info("Models are saved at {}".format(config.checkpoint_folder))
    if not os.path.isdir(config.checkpoint_folder):
        os.makedirs(config.checkpoint_folder)

    # Create the SummaryWriters for Tensorboard
    writers = create_summary_writers(config)
    config.create_dataframe_loggers()

    # Create LR scheduler
    l_rate_scheduler = SchedulerWithWarmUp(config,
                                           models_and_optimizer.optimizer)

    # Training loop
    logging.info("Starting training")
    train_results_per_epoch, val_results_per_epoch, learning_rates_per_epoch = [], [], []

    resource_monitor = None
    if config.monitoring_interval_seconds > 0:
        # initialize and start GPU monitoring
        resource_monitor = ResourceMonitor(
            interval_seconds=config.monitoring_interval_seconds,
            tb_log_file_path=str(config.logs_folder / "diagnostics"))
        resource_monitor.start()

    gradient_scaler = GradScaler(
    ) if config.use_gpu and config.use_mixed_precision else None
    optimal_temperature_scale_values = []
    for epoch in config.get_train_epochs():
        logging.info("Starting epoch {}".format(epoch))
        save_epoch = config.should_save_epoch(
            epoch) and models_and_optimizer.optimizer is not None

        # store the learning rates used for each epoch
        epoch_lrs = l_rate_scheduler.get_last_lr()
        learning_rates_per_epoch.append(epoch_lrs)

        train_val_params: TrainValidateParameters = \
            TrainValidateParameters(data_loader=data_loaders[ModelExecutionMode.TRAIN],
                                    model=models_and_optimizer.model,
                                    mean_teacher_model=models_and_optimizer.mean_teacher_model,
                                    epoch=epoch,
                                    optimizer=models_and_optimizer.optimizer,
                                    gradient_scaler=gradient_scaler,
                                    epoch_learning_rate=epoch_lrs,
                                    summary_writers=writers,
                                    dataframe_loggers=config.metrics_data_frame_loggers,
                                    in_training_mode=True)
        training_steps = create_model_training_steps(config, train_val_params)
        train_epoch_results = train_or_validate_epoch(training_steps)
        train_results_per_epoch.append(train_epoch_results.metrics)

        metrics.validate_and_store_model_parameters(writers.train, epoch,
                                                    models_and_optimizer.model)
        # Run without adjusting weights on the validation set
        train_val_params.in_training_mode = False
        train_val_params.data_loader = data_loaders[ModelExecutionMode.VAL]
        # if temperature scaling is enabled then do not save validation metrics for the checkpoint epochs
        # as these will be re-computed after performing temperature scaling on the validation set.
        if isinstance(config, SequenceModelBase):
            train_val_params.save_metrics = not (
                save_epoch and config.temperature_scaling_config)

        training_steps = create_model_training_steps(config, train_val_params)
        val_epoch_results = train_or_validate_epoch(training_steps)
        val_results_per_epoch.append(val_epoch_results.metrics)

        if config.is_segmentation_model:
            metrics.store_epoch_stats_for_segmentation(
                config.outputs_folder, epoch, epoch_lrs,
                train_epoch_results.metrics, val_epoch_results.metrics)

        if save_epoch:
            # perform temperature scaling if required
            if isinstance(
                    config,
                    SequenceModelBase) and config.temperature_scaling_config:
                optimal_temperature, scaled_val_results = \
                    temperature_scaling_steps(config, train_val_params, val_epoch_results)
                optimal_temperature_scale_values.append(optimal_temperature)
                # overwrite the metrics for the epoch with the metrics from the temperature scaled model
                val_results_per_epoch[-1] = scaled_val_results.metrics

            models_and_optimizer.save_checkpoint(epoch)

        # Updating the learning rate should happen at the end of the training loop, so that the
        # initial learning rate will be used for the very first epoch.
        l_rate_scheduler.step()

    model_training_results = ModelTrainingResults(
        train_results_per_epoch=train_results_per_epoch,
        val_results_per_epoch=val_results_per_epoch,
        learning_rates_per_epoch=learning_rates_per_epoch,
        optimal_temperature_scale_values_per_checkpoint_epoch=
        optimal_temperature_scale_values)

    logging.info("Finished training")

    # Upload visualization directory to AML run context to be able to see it
    # in the Azure UI.
    if config.max_batch_grad_cam > 0 and config.visualization_folder.exists():
        RUN_CONTEXT.upload_folder(name=VISUALIZATION_FOLDER,
                                  path=str(config.visualization_folder))

    writers.close_all()
    config.metrics_data_frame_loggers.close_all()
    if resource_monitor:
        # stop the resource monitoring process
        resource_monitor.kill()

    return model_training_results
Example #27
    # loss
    if config.get('focalloss', False):
        criterion = Sigmoid_focal_loss(config['focal_gamma'],
                                       config['focal_alpha'])
    elif config['weighted_BCE']:
        criterion = WeightedBCELoss(weight=config['bce_wieghts'])
    else:
        criterion = nn.BCEWithLogitsLoss()
    # metric variables

    best_val = None  # Best validation score within this fold
    patience = es_patience  # Current patience counter

    if config['mixed_prec']:
        scaler = GradScaler()
    # train on each epoch
    for epoch in range(epochs):

        correct = 0
        epoch_loss = 0
        model.train()

        # train on each mini-batch
        for x, y in train_loader:
            if y.size(0) == 1:
                continue
            x[0] = x[0].float().to(device)
            x[1] = x[1].float().to(device)
            y = y.float().to(device)
            optimizer.zero_grad()
Example #28
def main(opts):
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    opts.size = hvd.size()
    n_gpu = hvd.size()  # number of participating GPU workers (used in logging below)
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(device, n_gpu, hvd.rank(),
                                              opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    # data loaders
    DatasetCls = DATA_REGISTRY[opts.dataset_cls]
    EvalDatasetCls = DATA_REGISTRY[opts.eval_dataset_cls]
    splits, dataloaders = create_dataloaders(DatasetCls, EvalDatasetCls, opts)

    # Prepare model
    model = build_model(opts)
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    scaler = GradScaler()

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps, desc=opts.model)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        os.makedirs(join(opts.output_dir, 'results'),
                    exist_ok=True)  # store val predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(dataloaders['train'].dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    best_ckpt = 0
    best_eval = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(dataloaders['train']):
            targets = batch['targets']
            del batch['gather_index']
            n_examples += targets.size(0)

            with autocast():
                original_loss, enlarged_loss = model(**batch,
                                                     compute_loss=True)
                if opts.candidates == 'original':
                    loss = original_loss
                elif opts.candidates == 'enlarged':
                    loss = enlarged_loss
                elif opts.candidates == 'combined':
                    loss = original_loss + enlarged_loss
                else:
                    raise AssertionError("No such loss!")

                loss = loss.mean()

            delay_unscale = (step + 1) % opts.gradient_accumulation_steps != 0
            scaler.scale(loss).backward()
            if not delay_unscale:
                # gather gradients from every processes
                # do this before unscaling to make sure every process uses
                # the same gradient scale
                grads = [
                    p.grad.data for p in model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss',
                    sum(l.val for l in losses) / len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    grad_norm = clip_grad_norm_(model.parameters(),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time() - start))
                    LOGGER.info(f'{opts.model}: {n_epoch}-{global_step}: '
                                f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s '
                                f'best_acc-{best_eval * 100:.2f}')
                    TB_LOGGER.add_scalar('perf/ex_per_s', ex_per_sec,
                                         global_step)

                if global_step % opts.valid_steps == 0:
                    log = evaluation(
                        model,
                        dict(
                            filter(lambda x: x[0].startswith('val'),
                                   dataloaders.items())), opts, global_step)
                    if log['val/acc'] > best_eval:
                        best_ckpt = global_step
                        best_eval = log['val/acc']
                        pbar.set_description(
                            f'{opts.model}: {n_epoch}-{best_ckpt} best_acc-{best_eval * 100:.2f}'
                        )
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")

    sum(all_gather_list(opts.rank))

    best_pt = f'{opts.output_dir}/ckpt/model_step_{best_ckpt}.pt'
    model.load_state_dict(torch.load(best_pt), strict=False)
    evaluation(model,
               dict(filter(lambda x: x[0] != 'train', dataloaders.items())),
               opts, best_ckpt)
Example #29
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)
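    # with enabled=False the scaler is a no-op, so the same train_step works
    # with and without mixed precision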

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), \
            f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
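
For context, a minimal driver for create_trainer could look like the sketch below. This is an assumption, not part of the original example: build_model_and_data, the config keys, and the logger setup are placeholders chosen to match what create_trainer reads above.

import logging

import ignite.distributed as idist

# Hypothetical config; the keys mirror the ones create_trainer reads.
config = {
    "with_amp": True,
    "checkpoint_every": 1000,
    "log_every_iters": 15,
    "resume_from": None,
    "num_epochs": 20,
    "backend": None,  # e.g. "nccl" for multi-GPU runs
}


def training(local_rank, cfg):
    # build_model_and_data is a placeholder for whatever constructs the model,
    # optimizer, criterion, lr_scheduler and train DataLoader.
    model, optimizer, criterion, lr_scheduler, train_loader = build_model_and_data(cfg)
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler,
                             train_loader.sampler, cfg,
                             logger=logging.getLogger("trainer"))
    trainer.run(train_loader, max_epochs=cfg["num_epochs"])


with idist.Parallel(backend=config["backend"]) as parallel:
    parallel.run(training, config)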
Example #30
    def train(self,
              num_epochs,
              steps_per_epoch,
              validate_interval,
              gradient_accumulation=1,
              amp=True,
              verbosity=None,
              grid_mask=None,
              mix=None,
              dist_val=False):
        # Epochs are 0-indexed
        self.num_epochs = num_epochs
        self.steps_per_epoch = steps_per_epoch
        self.validate_interval = validate_interval
        self.gradient_accumulation = gradient_accumulation
        self.amp = amp
        if self.amp:
            self.scaler = GradScaler()
        if not verbosity:
            verbosity = self.steps_per_epoch // 10
        verbosity = max(1, verbosity)
        self.grid_mask = grid_mask
        self.mix = mix
        tic = datetime.datetime.now()
        while 1:
            self.train_step()
            self.steps += 1
            if self.scheduler.update == 'on_batch':
                self.scheduler_step()
            # Check: print training progress
            if self.steps % verbosity == 0 and self.steps > 0:
                self.print_progress()
            # Check: run validation
            if self.check_validation():
                self.print('VALIDATING ...')
                validation_start_time = datetime.datetime.now()
                # Start validation
                self.model.eval()
                if dist_val:
                    self.evaluator.validate(self.model,
                                            self.criterion,
                                            str(self.current_epoch).zfill(
                                                len(str(self.num_epochs))),
                                            save_pickle=True)
                    if self.local_rank == 0:
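                        # busy-wait until every rank has written its ".done" marker file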
                        while 1:
                            finished = glob.glob(
                                osp.join(self.evaluator.save_checkpoint_dir,
                                         '.done_rank*.txt'))
                            if len(finished) == self.world_size:
                                break
                        _ = os.system(
                            f'rm {osp.join(self.evaluator.save_checkpoint_dir, ".done_rank*.txt")}'
                        )
                        time.sleep(1)
                        predictions = glob.glob(
                            osp.join(self.evaluator.save_checkpoint_dir,
                                     '.tmp_preds_rank*.pkl'))
                        # Combine and calculate validation stats
                        predictions = [
                            self.load_pickle(p) for p in predictions
                        ]
                        y_true = np.concatenate(
                            [p['y_true'] for p in predictions])
                        y_pred = np.concatenate(
                            [p['y_pred'] for p in predictions])
                        losses = np.concatenate(
                            [p['losses'] for p in predictions])
                        del predictions
                        valid_metric = self.evaluator.calculate_metrics(
                            y_true, y_pred, losses)
                        self.evaluator.save_checkpoint(self.model,
                                                       valid_metric, y_true,
                                                       y_pred)
                        self.print('Validation took {} !'.format(
                            datetime.datetime.now() - validation_start_time))
                elif self.local_rank == 0:
                    y_true, y_pred, losses = self.evaluator.validate(
                        self.model,
                        self.criterion,
                        str(self.current_epoch).zfill(len(str(
                            self.num_epochs))),
                        save_pickle=False)
                    valid_metric = self.evaluator.calculate_metrics(
                        y_true, y_pred, losses)
                    self.evaluator.save_checkpoint(self.model, valid_metric,
                                                   y_true, y_pred)
                    self.print(
                        'Validation took {} !'.format(datetime.datetime.now() -
                                                      validation_start_time))

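                # NOTE: valid_metric above is computed only when local_rank == 0,
                # so an 'on_valid' scheduler step assumes a single process here.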
                if self.scheduler.update == 'on_valid':
                    self.scheduler.step(valid_metric)
                # End validation
                self.model.train()
            # Check: end of epoch
            if self.check_end_epoch():
                if self.scheduler.update == 'on_epoch':
                    self.scheduler.step()
                self.current_epoch += 1
                self.steps = 0
                # RESET BEST MODEL IF USING COSINEANNEALINGWARMRESTARTS
                if 'warmrestarts' in str(self.scheduler).lower():
                    if self.current_epoch % self.scheduler.T_0 == 0:
                        self.evaluator.reset_best()
            #
            if self.evaluator.check_stopping():
                # Early stopping: force current_epoch to num_epochs so that
                # check_end_train() below ends training (epochs are 0-indexed
                # and current_epoch was already incremented at epoch end).
                self.current_epoch = num_epochs
            if self.check_end_train():
                # Break the while loop
                break
        self.print('TRAINING : END')
        self.print('Training took {}\n'.format(datetime.datetime.now() - tic))
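
The per-batch work in this example is delegated to self.train_step(), which is not shown. Given the amp and gradient_accumulation options accepted above, that step typically follows the pattern sketched below. This is a sketch under stated assumptions, not the author's implementation: self.loader, self.device, self.criterion and self.optimizer are presumed attributes, and autocast is assumed to be imported from torch.cuda.amp as in the other examples.

    def train_step(self):
        # Sketch only: assumes self.loader yields (inputs, targets) batches and
        # that self.scaler was created in train() when amp=True.
        x, y = next(self.loader)
        x, y = x.to(self.device), y.to(self.device)
        with autocast(enabled=self.amp):
            loss = self.criterion(self.model(x), y)
            # spread one optimizer step over several batches
            loss = loss / self.gradient_accumulation
        if self.amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
        # only step the optimizer every `gradient_accumulation` batches
        if (self.steps + 1) % self.gradient_accumulation == 0:
            if self.amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()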