Example #1
def train_21k(model, train_loader, val_loader, optimizer,
              semantic_softmax_processor, met, args):
    # set loss
    loss_fn = SemanticSoftmaxLoss(semantic_softmax_processor)

    # set scheduler
    scheduler = lr_scheduler.OneCycleLR(optimizer,
                                        max_lr=args.lr,
                                        steps_per_epoch=len(train_loader),
                                        epochs=args.epochs,
                                        pct_start=0.1,
                                        cycle_momentum=False,
                                        div_factor=20)

    # set scaler
    scaler = GradScaler()
    for epoch in range(args.epochs):
        if num_distrib() > 1:
            train_loader.sampler.set_epoch(epoch)

        # train epoch
        print_at_master("\nEpoch {}".format(epoch))
        epoch_start_time = time.time()
        for i, (input, target) in enumerate(train_loader):
            with autocast():  # mixed precision
                output = model(input)
                loss = loss_fn(output, target)  # note - loss also in fp16
            model.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        epoch_time = time.time() - epoch_start_time
        print_at_master(
            "\nFinished Epoch, Training Rate: {:.1f} [img/sec]".format(
                len(train_loader) * args.batch_size / epoch_time *
                max(num_distrib(), 1)))

        # validation epoch
        validate_21k(val_loader, model, met)
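For reference, every example in this collection follows the same native-AMP recipe: run the forward pass under autocast, scale the loss before backward, then let the scaler drive the optimizer step and rescaling. Below is a minimal sketch of that core loop, assuming a model, data loader, optimizer and loss function already exist; all names are placeholders, not taken from the example above.

# Minimal native-AMP training loop (PyTorch >= 1.6); `model`, `loader`,
# `optimizer` and `loss_fn` are placeholder names.
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for inputs, targets in loader:
    optimizer.zero_grad()
    with autocast():                   # forward pass runs in mixed precision
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
    scaler.scale(loss).backward()      # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)             # unscales gradients; skips the step on inf/NaN
    scaler.update()                    # adjusts the scale factor for the next iteration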
Example #2
def train_fn(train_loader, model, criterion, optimizer, config, device):
    assert hasattr(config, "apex"), "Please create apex(bool) attribute"
    assert hasattr(
        config, "gradient_accumulation_steps"
    ), "Please create gradient_accumulation_steps(int default=1) attribute"

    model.train()
    if config.apex:
        scaler = GradScaler()
    losses = AverageMeter()
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if config.apex:
            with autocast():
                y_preds = model(images)
                loss = criterion(y_preds.view(-1), labels)
        else:
            y_preds = model(images)
            loss = criterion(y_preds.view(-1), labels)
        # record loss
        losses.update(loss.item(), batch_size)
        if config.gradient_accumulation_steps > 1:
            loss = loss / config.gradient_accumulation_steps
        if config.apex:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % config.gradient_accumulation_steps == 0:
            if config.apex:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()
        del loss
        gc.collect()
    return losses.avg
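train_fn above divides the loss by gradient_accumulation_steps and only steps the optimizer every N batches. Here is the same accumulation pattern in isolation, as a sketch; the model, loader, criterion and accumulation count are placeholder names rather than the config-driven objects used above.

# Gradient accumulation with GradScaler: backward on every batch, step every `accum_steps`.
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    with autocast():
        loss = criterion(model(x), y) / accum_steps  # average over the accumulation window
    scaler.scale(loss).backward()                    # scaled gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()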
Example #3
class task3_train():
    def __init__(self, device, train_loader, val_loader, model, MODEL_PATH):
        self.device = device
        self.model_path = MODEL_PATH
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.model = model.to(device)
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=1e-2,
                                   momentum=0.9,
                                   weight_decay=5e-4)
        self.criterion = CriterionOhemDSN(thresh=0.7, min_kept=100000)
        self.scaler = GradScaler()

    def train(self):
        self.model.train()

        train_loss = 0

        for batch_idx, (inputs, GT) in enumerate(self.train_loader):
            inputs, GT = inputs.to(self.device), GT.long().to(self.device)
            self.optimizer.zero_grad()
            with autocast():
                SR = self.model(inputs)
                loss = self.criterion(SR, GT)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            train_loss += loss.item()

        print('Training Loss: %.4f' % (train_loss))

    def saveModel(self):
        torch.save(self.model.state_dict(), self.model_path + "model1.pth")
Example #4
class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = GradScaler()

    def __repr__(self) -> str:
        return repr(self.__class__.__name__)

    def __call__(
        self,
        loss,
        optimizer,
        step,
        accum_grad,
        clip_grad=None,
        parameters=None,
        create_graph=False,
    ):
        self._scaler.scale(loss /
                           accum_grad).backward(create_graph=create_graph)
        if step % accum_grad == 0:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(
                    optimizer
                )  # unscale the gradients of optimizer's assigned params in-place
                nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
            optimizer.zero_grad()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
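A hypothetical driver loop for the NativeScaler wrapper above, assuming a model, optimizer, data loader and criterion are already built; all of these names are placeholders.

# Sketch: driving NativeScaler with gradient clipping and 2-step accumulation.
from torch.cuda.amp import autocast

loss_scaler = NativeScaler()
accum_grad = 2
for step, (x, y) in enumerate(loader, start=1):
    with autocast():
        loss = criterion(model(x), y)
    loss_scaler(loss, optimizer, step, accum_grad,
                clip_grad=1.0, parameters=model.parameters())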
Example #5
def def_model_backward(losses: Dict[str, Tensor],
                       model: Module,
                       scaler: GradScaler = None):
    """
    Default function to perform a backwards pass for a model and the calculated losses
    Calls backwards for the DEFAULT_LOSS_KEY in losses Dict

    :param losses: the losses dictionary containing named tensors,
                   DEFAULT_LOSS_KEY is expected to exist and backwards is called on that
    :param model: the model to run the backward for
    :param scaler: GradScaler object for running in mixed precision with amp. If scaler
        is not None will call scaler.scale on the loss object. Default is None
    """
    # assume loss is at default loss key
    loss = losses[DEFAULT_LOSS_KEY]
    if scaler is not None:
        loss = scaler.scale(loss)
    loss.backward()
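Note that def_model_backward only scales the loss and runs backward; the caller is still responsible for finishing the update. A sketch of the completing calls under the same convention, with the optimizer, scaler and losses dict as placeholders:

# Sketch: completing the optimizer step after def_model_backward.
def_model_backward(losses, model, scaler=scaler)
if scaler is not None:
    scaler.step(optimizer)   # unscales gradients and steps, skipping on inf/NaN
    scaler.update()
else:
    optimizer.step()
optimizer.zero_grad()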
Example #6
class Trainer():
    def __init__(self,
                 name='default',
                 results_dir='results',
                 models_dir='models',
                 base_dir='./',
                 optimizer="adam",
                 latent_dim=256,
                 image_size=128,
                 fmap_max=512,
                 transparent=False,
                 greyscale=False,
                 batch_size=4,
                 gp_weight=10,
                 gradient_accumulate_every=1,
                 attn_res_layers=[],
                 disc_output_size=5,
                 antialias=False,
                 lr=2e-4,
                 lr_mlp=1.,
                 ttur_mult=1.,
                 save_every=1000,
                 evaluate_every=1000,
                 trunc_psi=0.6,
                 aug_prob=None,
                 aug_types=['translation', 'cutout'],
                 dataset_aug_prob=0.,
                 calculate_fid_every=None,
                 is_ddp=False,
                 rank=0,
                 world_size=1,
                 log=False,
                 amp=False,
                 *args,
                 **kwargs):
        self.GAN_params = [args, kwargs]
        self.GAN = None

        self.name = name

        base_dir = Path(base_dir)
        self.base_dir = base_dir
        self.results_dir = base_dir / results_dir
        self.models_dir = base_dir / models_dir
        self.config_path = self.models_dir / name / '.config.json'

        assert is_power_of_two(
            image_size
        ), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
        assert all(
            map(is_power_of_two, attn_res_layers)
        ), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

        self.optimizer = optimizer
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.fmap_max = fmap_max
        self.transparent = transparent
        self.greyscale = greyscale

        assert (int(self.transparent) + int(self.greyscale)
                ) < 2, 'you can only set either transparency or greyscale'

        self.aug_prob = aug_prob
        self.aug_types = aug_types

        self.lr = lr
        self.ttur_mult = ttur_mult
        self.batch_size = batch_size
        self.gradient_accumulate_every = gradient_accumulate_every

        self.gp_weight = gp_weight

        self.evaluate_every = evaluate_every
        self.save_every = save_every
        self.steps = 0

        self.generator_top_k_gamma = 0.99
        self.generator_top_k_frac = 0.5

        self.attn_res_layers = attn_res_layers
        self.disc_output_size = disc_output_size
        self.antialias = antialias

        self.d_loss = 0
        self.g_loss = 0
        self.last_gp_loss = None
        self.last_recon_loss = None
        self.last_fid = None

        self.init_folders()

        self.loader = None
        self.dataset_aug_prob = dataset_aug_prob

        self.calculate_fid_every = calculate_fid_every

        self.is_ddp = is_ddp
        self.is_main = rank == 0
        self.rank = rank
        self.world_size = world_size

        self.syncbatchnorm = is_ddp

        self.amp = amp
        self.G_scaler = GradScaler(enabled=self.amp)
        self.D_scaler = GradScaler(enabled=self.amp)

    @property
    def image_extension(self):
        return 'jpg' if not self.transparent else 'png'

    @property
    def checkpoint_num(self):
        return floor(self.steps // self.save_every)

    def init_GAN(self):
        args, kwargs = self.GAN_params

        # set some global variables before instantiating GAN

        global norm_class
        global Blur

        norm_class = nn.SyncBatchNorm if self.syncbatchnorm else nn.BatchNorm2d
        Blur = nn.Identity if not self.antialias else Blur

        # handle bugs when
        # switching from multi-gpu back to single gpu

        if self.syncbatchnorm and not self.is_ddp:
            import torch.distributed as dist
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ['MASTER_PORT'] = '12355'
            dist.init_process_group('nccl', rank=0, world_size=1)

        # instantiate GAN

        self.GAN = LightweightGAN(optimizer=self.optimizer,
                                  lr=self.lr,
                                  latent_dim=self.latent_dim,
                                  attn_res_layers=self.attn_res_layers,
                                  image_size=self.image_size,
                                  ttur_mult=self.ttur_mult,
                                  fmap_max=self.fmap_max,
                                  disc_output_size=self.disc_output_size,
                                  transparent=self.transparent,
                                  greyscale=self.greyscale,
                                  rank=self.rank,
                                  *args,
                                  **kwargs)

        if self.is_ddp:
            ddp_kwargs = {
                'device_ids': [self.rank],
                'output_device': self.rank,
                'find_unused_parameters': True
            }

            self.G_ddp = DDP(self.GAN.G, **ddp_kwargs)
            self.D_ddp = DDP(self.GAN.D, **ddp_kwargs)
            self.D_aug_ddp = DDP(self.GAN.D_aug, **ddp_kwargs)

    def write_config(self):
        self.config_path.write_text(json.dumps(self.config()))

    def load_config(self):
        config = self.config(
        ) if not self.config_path.exists() else json.loads(
            self.config_path.read_text())
        self.image_size = config['image_size']
        self.transparent = config['transparent']
        self.syncbatchnorm = config['syncbatchnorm']
        self.disc_output_size = config['disc_output_size']
        self.greyscale = config.pop('greyscale', False)
        self.attn_res_layers = config.pop('attn_res_layers', [])
        self.optimizer = config.pop('optimizer', 'adam')
        self.fmap_max = config.pop('fmap_max', 512)
        del self.GAN
        self.init_GAN()

    def config(self):
        return {
            'image_size': self.image_size,
            'transparent': self.transparent,
            'greyscale': self.greyscale,
            'syncbatchnorm': self.syncbatchnorm,
            'disc_output_size': self.disc_output_size,
            'optimizer': self.optimizer,
            'attn_res_layers': self.attn_res_layers
        }

    def set_data_src(self, folder):
        self.dataset = ImageDataset(folder,
                                    self.image_size,
                                    transparent=self.transparent,
                                    greyscale=self.greyscale,
                                    aug_prob=self.dataset_aug_prob)
        sampler = DistributedSampler(self.dataset,
                                     rank=self.rank,
                                     num_replicas=self.world_size,
                                     shuffle=True) if self.is_ddp else None
        dataloader = DataLoader(
            self.dataset,
            num_workers=math.ceil(NUM_CORES / self.world_size),
            batch_size=math.ceil(self.batch_size / self.world_size),
            sampler=sampler,
            shuffle=not self.is_ddp,
            drop_last=True,
            pin_memory=True)
        self.loader = cycle(dataloader)

        # automatically set the augmentation probability for the user if the dataset is small
        num_samples = len(self.dataset)
        if not exists(self.aug_prob) and num_samples < 1e5:
            self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
            print(
                f'autosetting augmentation probability to {round(self.aug_prob * 100)}%'
            )

    def train(self):
        assert exists(
            self.loader
        ), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
        device = torch.device(f'cuda:{self.rank}')

        if not exists(self.GAN):
            self.init_GAN()

        self.GAN.train()
        total_disc_loss = torch.zeros([], device=device)
        total_gen_loss = torch.zeros([], device=device)

        batch_size = math.ceil(self.batch_size / self.world_size)

        image_size = self.GAN.image_size
        latent_dim = self.GAN.latent_dim

        aug_prob = default(self.aug_prob, 0)
        aug_types = self.aug_types
        aug_kwargs = {'prob': aug_prob, 'types': aug_types}

        G = self.GAN.G if not self.is_ddp else self.G_ddp
        D = self.GAN.D if not self.is_ddp else self.D_ddp
        D_aug = self.GAN.D_aug if not self.is_ddp else self.D_aug_ddp

        apply_gradient_penalty = self.steps % 4 == 0

        # amp related contexts and functions

        amp_context = autocast if self.amp else null_context

        # train discriminator
        self.GAN.D_opt.zero_grad()
        for i in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[D_aug, G]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
            image_batch = next(self.loader).cuda(self.rank)
            image_batch.requires_grad_()

            with amp_context():
                with torch.no_grad():
                    generated_images = G(latents)

                fake_output, fake_output_32x32, _ = D_aug(generated_images,
                                                          detach=True,
                                                          **aug_kwargs)

                real_output, real_output_32x32, real_aux_loss = D_aug(
                    image_batch, calc_aux_loss=True, **aug_kwargs)

                real_output_loss = real_output
                fake_output_loss = fake_output

                divergence = hinge_loss(real_output_loss, fake_output_loss)
                divergence_32x32 = hinge_loss(real_output_32x32,
                                              fake_output_32x32)
                disc_loss = divergence + divergence_32x32

                aux_loss = real_aux_loss
                disc_loss = disc_loss + aux_loss

            if apply_gradient_penalty:
                outputs = [real_output, real_output_32x32]
                outputs = list(map(self.D_scaler.scale,
                                   outputs)) if self.amp else outputs

                scaled_gradients = torch_grad(
                    outputs=outputs,
                    inputs=image_batch,
                    grad_outputs=list(
                        map(
                            lambda t: torch.ones(t.size(),
                                                 device=image_batch.device),
                            outputs)),
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]

                inv_scale = (1. /
                             self.D_scaler.get_scale()) if self.amp else 1.
                gradients = scaled_gradients * inv_scale

                with amp_context():
                    gradients = gradients.reshape(batch_size, -1)
                    gp = self.gp_weight * (
                        (gradients.norm(2, dim=1) - 1)**2).mean()

                    if not torch.isnan(gp):
                        disc_loss = disc_loss + gp
                        self.last_gp_loss = gp.clone().detach().item()

            with amp_context():
                disc_loss = disc_loss / self.gradient_accumulate_every

            disc_loss.register_hook(raise_if_nan)
            self.D_scaler.scale(disc_loss).backward()
            total_disc_loss += divergence

        self.last_recon_loss = aux_loss.item()
        self.d_loss = float(total_disc_loss.item() /
                            self.gradient_accumulate_every)
        self.D_scaler.step(self.GAN.D_opt)
        self.D_scaler.update()

        # train generator

        self.GAN.G_opt.zero_grad()

        for i in gradient_accumulate_contexts(self.gradient_accumulate_every,
                                              self.is_ddp,
                                              ddps=[G, D_aug]):
            latents = torch.randn(batch_size, latent_dim).cuda(self.rank)

            with amp_context():
                generated_images = G(latents)
                fake_output, fake_output_32x32, _ = D_aug(
                    generated_images, **aug_kwargs)
                fake_output_loss = fake_output.mean(
                    dim=1) + fake_output_32x32.mean(dim=1)

                epochs = (self.steps * batch_size *
                          self.gradient_accumulate_every) / len(self.dataset)
                k_frac = max(self.generator_top_k_gamma**epochs,
                             self.generator_top_k_frac)
                k = math.ceil(batch_size * k_frac)

                if k != batch_size:
                    fake_output_loss, _ = fake_output_loss.topk(k=k,
                                                                largest=False)

                loss = fake_output_loss.mean()
                gen_loss = loss

                gen_loss = gen_loss / self.gradient_accumulate_every

            gen_loss.register_hook(raise_if_nan)
            self.G_scaler.scale(gen_loss).backward()
            total_gen_loss += loss

        self.g_loss = float(total_gen_loss.item() /
                            self.gradient_accumulate_every)
        self.G_scaler.step(self.GAN.G_opt)
        self.G_scaler.update()

        # calculate moving averages

        if self.is_main and self.steps % 10 == 0 and self.steps > 20000:
            self.GAN.EMA()

        if self.is_main and self.steps <= 25000 and self.steps % 1000 == 2:
            self.GAN.reset_parameter_averaging()

        # save from NaN errors

        if any(torch.isnan(l) for l in (total_gen_loss, total_disc_loss)):
            print(
                f'NaN detected for generator or discriminator. Loading from checkpoint #{self.checkpoint_num}'
            )
            self.load(self.checkpoint_num)
            raise NanException

        del total_disc_loss
        del total_gen_loss

        # periodically save results

        if self.is_main:
            if self.steps % self.save_every == 0:
                self.save(self.checkpoint_num)

            if self.steps % self.evaluate_every == 0 or (
                    self.steps % 100 == 0 and self.steps < 20000):
                self.evaluate(floor(self.steps / self.evaluate_every))

            if exists(
                    self.calculate_fid_every
            ) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
                num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
                fid = self.calculate_fid(num_batches)
                self.last_fid = fid

                with open(
                        str(self.results_dir / self.name / f'fid_scores.txt'),
                        'a') as f:
                    f.write(f'{self.steps},{fid}\n')

        self.steps += 1

    @torch.no_grad()
    def evaluate(self, num=0, num_image_tiles=8, trunc=1.0):
        self.GAN.eval()

        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents = torch.randn((num_rows**2, latent_dim)).cuda(self.rank)

        # regular

        generated_images = self.generate_truncated(self.GAN.G, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}.{ext}'),
                                     nrow=num_rows)

        # moving averages

        generated_images = self.generate_truncated(self.GAN.GE, latents)
        torchvision.utils.save_image(generated_images,
                                     str(self.results_dir / self.name /
                                         f'{str(num)}-ema.{ext}'),
                                     nrow=num_rows)

    @torch.no_grad()
    def calculate_fid(self, num_batches):
        from pytorch_fid import fid_score
        torch.cuda.empty_cache()

        real_path = str(self.results_dir / self.name / 'fid_real') + '/'
        fake_path = str(self.results_dir / self.name / 'fid_fake') + '/'

        # remove any existing files used for fid calculation and recreate directories
        rmtree(real_path, ignore_errors=True)
        rmtree(fake_path, ignore_errors=True)
        os.makedirs(real_path)
        os.makedirs(fake_path)

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving reals'):
            real_batch = next(self.loader)
            for k in range(real_batch.size(0)):
                torchvision.utils.save_image(
                    real_batch[k, :, :, :], real_path +
                    '{}.png'.format(k + batch_num * self.batch_size))

        # generate a bunch of fake images in results / name / fid_fake
        self.GAN.eval()
        ext = self.image_extension

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        for batch_num in tqdm(range(num_batches),
                              desc='calculating FID - saving generated'):
            # latents and noise
            latents = torch.randn(self.batch_size, latent_dim).cuda(self.rank)

            # moving averages
            generated_images = self.generate_truncated(self.GAN.GE, latents)

            for j in range(generated_images.size(0)):
                torchvision.utils.save_image(
                    generated_images[j, :, :, :],
                    str(
                        Path(fake_path) /
                        f'{str(j + batch_num * self.batch_size)}-ema.{ext}'))

        return fid_score.calculate_fid_given_paths([real_path, fake_path], 256,
                                                   True, 2048)

    @torch.no_grad()
    def generate_truncated(self, G, style, trunc_psi=0.75, num_image_tiles=8):
        generated_images = evaluate_in_chunks(self.batch_size, G, style)
        return generated_images.clamp_(0., 1.)

    @torch.no_grad()
    def generate_interpolation(self,
                               num=0,
                               num_image_tiles=8,
                               trunc=1.0,
                               num_steps=100,
                               save_frames=False):
        self.GAN.eval()
        ext = self.image_extension
        num_rows = num_image_tiles

        latent_dim = self.GAN.latent_dim
        image_size = self.GAN.image_size

        # latents and noise

        latents_low = torch.randn(num_rows**2, latent_dim).cuda(self.rank)
        latents_high = torch.randn(num_rows**2, latent_dim).cuda(self.rank)

        ratios = torch.linspace(0., 8., num_steps)

        frames = []
        for ratio in tqdm(ratios):
            interp_latents = slerp(ratio, latents_low, latents_high)
            generated_images = self.generate_truncated(self.GAN.GE,
                                                       interp_latents)
            images_grid = torchvision.utils.make_grid(generated_images,
                                                      nrow=num_rows)
            pil_image = transforms.ToPILImage()(images_grid.cpu())

            if self.transparent:
                background = Image.new('RGBA', pil_image.size, (255, 255, 255))
                pil_image = Image.alpha_composite(background, pil_image)

            frames.append(pil_image)

        frames[0].save(str(self.results_dir / self.name / f'{str(num)}.gif'),
                       save_all=True,
                       append_images=frames[1:],
                       duration=80,
                       loop=0,
                       optimize=True)

        if save_frames:
            folder_path = (self.results_dir / self.name / f'{str(num)}')
            folder_path.mkdir(parents=True, exist_ok=True)
            for ind, frame in enumerate(frames):
                frame.save(str(folder_path / f'{str(ind)}.{ext}'))

    def print_log(self):
        data = [('G', self.g_loss), ('D', self.d_loss),
                ('GP', self.last_gp_loss), ('SS', self.last_recon_loss),
                ('FID', self.last_fid)]

        data = [d for d in data if exists(d[1])]
        log = ' | '.join(map(lambda n: f'{n[0]}: {n[1]:.2f}', data))
        print(log)

    def model_name(self, num):
        return str(self.models_dir / self.name / f'model_{num}.pt')

    def init_folders(self):
        (self.results_dir / self.name).mkdir(parents=True, exist_ok=True)
        (self.models_dir / self.name).mkdir(parents=True, exist_ok=True)

    def clear(self):
        rmtree(str(self.models_dir / self.name), True)
        rmtree(str(self.results_dir / self.name), True)
        rmtree(str(self.config_path), True)
        self.init_folders()

    def save(self, num):
        save_data = {
            'GAN': self.GAN.state_dict(),
            'version': __version__,
            'G_scaler': self.G_scaler.state_dict(),
            'D_scaler': self.D_scaler.state_dict()
        }

        torch.save(save_data, self.model_name(num))
        self.write_config()

    def load(self, num=-1):
        self.load_config()

        name = num
        if num == -1:
            file_paths = [
                p for p in Path(self.models_dir / self.name).glob('model_*.pt')
            ]
            saved_nums = sorted(
                map(lambda x: int(x.stem.split('_')[1]), file_paths))
            if len(saved_nums) == 0:
                return
            name = saved_nums[-1]
            print(f'continuing from previous epoch - {name}')

        self.steps = name * self.save_every

        load_data = torch.load(self.model_name(name))

        if 'version' in load_data and self.is_main:
            print(f"loading from version {load_data['version']}")

        try:
            self.GAN.load_state_dict(load_data['GAN'])
        except Exception as e:
            print(
                'unable to load saved model. please try downgrading the package to the version specified by the saved model'
            )
            raise e

        if 'G_scaler' in load_data:
            self.G_scaler.load_state_dict(load_data['G_scaler'])
        if 'D_scaler' in load_data:
            self.D_scaler.load_state_dict(load_data['D_scaler'])
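The gradient-penalty block in Trainer.train follows the usual recipe for penalties under AMP: scale the network outputs before calling autograd.grad, then multiply the resulting gradients by the inverse of the current scale before forming the penalty. Below is a condensed sketch of just that idea; the tensor names are placeholders, image_batch is assumed to have requires_grad set (as in the code above), and the scaler is assumed to be enabled on CUDA tensors.

# Gradient penalty under AMP, condensed from the discriminator block above.
import torch
from torch.autograd import grad as torch_grad
from torch.cuda.amp import autocast

scaled_out = scaler.scale(real_output)                      # scale the outputs, not the penalty
scaled_grads = torch_grad(outputs=scaled_out, inputs=image_batch,
                          grad_outputs=torch.ones_like(scaled_out),
                          create_graph=True, retain_graph=True, only_inputs=True)[0]
grads = scaled_grads * (1.0 / scaler.get_scale())           # undo the scaling on the gradients
with autocast():
    gp = gp_weight * ((grads.reshape(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()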
Example #7
                    if IsDeepSup:
                        sys.exit("Not Implimented yet")
                    else:
                        out, _, _ = model(images)
                        loss = loss_func(out, gt)
                elif type(model) is ThisNewNet:
                    out, loss = model(images, gt=gt)
                else:
                    out = model(images)
                    loss = loss_func(out, gt)

            if IsNegLoss:
                loss = -loss

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss = round((-loss).data.item(),4) if IsNegLoss else round(loss.data.item(),4)
            train_loss.append(loss)
            runningLoss.append(loss)
            logging.info('[%d/%d][%d/%d] Train Loss: %.4f' % ((epoch+1), args.epochs, i, len(train_loader), loss))

            if i % args.logfreq == 0:
                niter = epoch*len(train_loader)+i
                tb_writer.add_scalar('Train/Loss', loss_reducer(runningLoss), niter)
                # tensorboard_images(tb_writer, inp, out.detach(), gt, epoch, 'train')
                runningLoss = []
        
        if epoch % args.savefreq == 0:            
Example #8
def train(audio_model, train_loader, test_loader, lr=0.001, n_epochs=10):
    device = torch.device("cuda:2")
    torch.set_grad_enabled(True)

    batch_time = AverageMeter()
    per_sample_time = AverageMeter()
    data_time = AverageMeter()
    per_sample_data_time = AverageMeter()
    loss_meter = AverageMeter()
    per_sample_dnn_time = AverageMeter()
    progress = []

    best_epoch, best_cum_epoch, best_mAP, best_acc, best_cum_mAP = 0, 0, -np.inf, -np.inf, -np.inf
    global_step, epoch = 0, 0
    start_time = time.time()

    def _save_progress():
        progress.append([
            epoch, global_step, best_epoch, best_mAP,
            time.time() - start_time
        ])

        with open("progress.pkl", "wb") as f:
            pickle.dump(progress, f)

    if not isinstance(audio_model, nn.DataParallel):
        audio_model = nn.DataParallel(audio_model, [2, 1], 2)

    audio_model = audio_model.to(device)
    trainables = [p for p in audio_model.parameters() if p.requires_grad]
    print('Total parameter number is : {:.3f} million'.format(
        sum(p.numel() for p in audio_model.parameters()) / 1e6))
    print('Total trainable parameter number is : {:.3f} million'.format(
        sum(p.numel() for p in trainables) / 1e6))
    optimizer = torch.optim.Adam(trainables,
                                 lr,
                                 weight_decay=5e-7,
                                 betas=(0.95, 0.999))

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     list(range(5, 26)),
                                                     gamma=0.85)
    main_metrics = 'acc'
    loss_fn = nn.CrossEntropyLoss()
    warmup = False
    print(
        'now training with main metrics: {:s}, loss function: {:s}, learning rate scheduler: {:s}'
        .format(str(main_metrics), str(loss_fn), str(scheduler)))

    epoch += 1
    scaler = GradScaler()
    print("current #steps=%s, #epochs=%s" % (global_step, epoch))
    print("start training...")
    result = np.zeros([n_epochs, 10])
    audio_model.train()
    while epoch < n_epochs + 1:
        begin_time = time.time()
        end_time = time.time()
        audio_model.train()
        print('---------------')
        print(datetime.datetime.now())
        print("current #epochs=%s, #steps=%s" % (epoch, global_step))

        for i, d in enumerate(train_loader):
            audio_input = d['input']
            labels = d['label']
            B = audio_input.size(0)
            audio_input = audio_input.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            data_time.update(time.time() - end_time)
            per_sample_data_time.update(
                (time.time() - end_time) / audio_input.shape[0])
            dnn_start_time = time.time()

            if global_step <= 1000 and global_step % 50 == 0 and warmup == True:
                warm_lr = (global_step / 1000) * lr
                for param_groups in optimizer.param_groups:
                    param_groups['lr'] = warm_lr
                print('warm-up learning rate is {:f}'.format(
                    optimizer.param_groups[0]['lr']))

            with autocast():
                audio_output = audio_model(audio_input)
                if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
                    loss = loss_fn(audio_output, labels)
                else:
                    loss = loss_fn(audio_output, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_meter.update(loss.item(), B)
            batch_time.update(time.time() - end_time)
            per_sample_time.update(
                (time.time() - end_time) / audio_input.shape[0])
            per_sample_dnn_time.update(
                (time.time() - dnn_start_time) / audio_input.shape[0])

            print_step = global_step % 50 == 0
            early_print_step = epoch == 0 and global_step % (50 / 10) == 0
            print_step = print_step or early_print_step

            if print_step and global_step != 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Total Time {per_sample_time.avg:.5f}\t'
                      'Data Time {per_sample_data_time.avg:.5f}\t'
                      'DNN Time {per_sample_dnn_time.avg:.5f}\t'
                      'Train Loss {loss_meter.avg:.4f}\t'.format(
                          epoch,
                          i,
                          len(train_loader),
                          per_sample_time=per_sample_time,
                          per_sample_data_time=per_sample_data_time,
                          per_sample_dnn_time=per_sample_dnn_time,
                          loss_meter=loss_meter),
                      flush=True)
                if np.isnan(loss_meter.avg):
                    print("training diverged...")
                    return
            end_time = time.time()
            global_step += 1

        print('start validation')

        stats, valid_loss = validate(audio_model, test_loader, epoch=10)

        #mAP = np.mean([stat['AP'] for stat in stats])
        #mAUC = np.mean([stat['auc'] for stat in stats])
        acc = stats[0]['acc']

        #middle_ps = [stat['precisions'][int(len(stat['precisions'])/2)] for stat in stats]
        #middle_rs = [stat['recalls'][int(len(stat['recalls'])/2)] for stat in stats]
        #average_precision = np.mean(middle_ps)
        #average_recall = np.mean(middle_rs)

        print("acc: {:.6f}".format(acc))
        #print("AUC: {:.6f}".format(mAUC))
        #print("Avg Precision: {:.6f}".format(average_precision))
        #print("Avg Recall: {:.6f}".format(average_recall))
        print("train_loss: {:.6f}".format(loss_meter.avg))
        print("valid_loss: {:.6f}".format(valid_loss))

        if acc > best_acc:
            best_acc = acc
            if main_metrics == 'acc':
                best_epoch = epoch

        scheduler.step()
        print('Epoch-{0} lr: {1}'.format(epoch,
                                         optimizer.param_groups[0]['lr']))

        _save_progress()
        finish_time = time.time()
        print('epoch {:d} training time: {:.3f}'.format(
            epoch, finish_time - begin_time))

        epoch += 1

        batch_time.reset()
        per_sample_time.reset()
        data_time.reset()
        per_sample_data_time.reset()
        loss_meter.reset()
        per_sample_dnn_time.reset()

        #mAP = np.mean([stat['AP'] for stat in stats])
        #mAUC = np.mean([stat['auc'] for stat in stats])
        #middle_ps = [stat['precisions'][int(len(stat['precisions'])/2)] for stat in stats]
        #middle_rs = [stat['recalls'][int(len(stat['recalls'])/2)] for stat in stats]
        #average_precision = np.mean(middle_ps)
        #average_recall = np.mean(middle_rs)
        #wa_result = [mAP, mAUC, average_precision, average_recall, d_prime(mAUC)]
        print('---------------Training Finished---------------')
        print('weighted averaged model results')
        #print("mAP: {:.6f}".format(mAP))
        #print("AUC: {:.6f}".format(mAUC))
        #print("Avg Precision: {:.6f}".format(average_precision))
        #print("Avg Recall: {:.6f}".format(average_recall))
        #print("d_prime: {:.6f}".format(d_prime(mAUC)))
        print("train_loss: {:.6f}".format(loss_meter.avg))
        print("valid_loss: {:.6f}".format(valid_loss))
Example #9
    def train_epoch(self, model: Reader, optimizer: torch.optim.Optimizer,
                    scaler: GradScaler, train: DataLoader, val: DataLoader,
                    scheduler: torch.optim.lr_scheduler.LambdaLR) -> float:
        """
        Performs one training epoch.

        :param model: The model you are training.
        :type model: Reader
        :param optimizer: Use this optimizer for training.
        :type optimizer: torch.optim.Optimizer
        :param scaler: Scaler for gradients when the mixed precision is used.
        :type scaler: GradScaler
        :param train: The train dataset loader.
        :type train: DataLoader
        :param val: The validation dataset loader.
        :type val: DataLoader
        :param scheduler: Learning rate scheduler.
        :type scheduler: torch.optim.lr_scheduler.LambdaLR
        :return: Best achieved exact match among validations.
        :rtype: float
        """

        model.train()
        loss_sum = 0
        samples = 0
        startTime = time.time()

        total_tokens = 0
        optimizer.zero_grad()

        initStep = 0
        if self.resumeSkip is not None:
            initStep = self.resumeSkip
            self.resumeSkip = None

        iterator = tqdm(enumerate(train), total=len(train), initial=initStep)

        bestExactMatch = 0.0

        for current_it, batch in iterator:
            batch: ReaderBatch
            lastScale = scaler.get_scale()
            self.n_iter += 1

            batchOnDevice = batch.to(self.device)
            samples += 1

            try:
                with torch.cuda.amp.autocast(
                        enabled=self.config["mixed_precision"]):
                    startScores, endScores, jointScore, selectionScore = self._useModel(
                        model, batchOnDevice)

                    # according to the config we can get the following loss combinations
                    # join components
                    # independent components
                    # join components with HardEM
                    # independent components with HardEM

                    logSpanProb = None
                    if not self.config["independent_components_in_loss"]:
                        # joined components in loss
                        logSpanProb = Reader.scores2logSpanProb(
                            startScores, endScores, jointScore, selectionScore)

                    # The user may want to use hardEMLoss with a certain probability.
                    # In the original article this is not written clearly and it seems like it is the other way around.
                    # After consulting the authors, the idea became clear.

                    if self.config["hard_em_steps"] > 0 and \
                            random.random() <= min(self.update_it/self.config["hard_em_steps"], self.config["max_hard_em_prob"]):
                        # loss is calculated for the max answer span with max probability
                        if self.config["independent_components_in_loss"]:
                            loss = Reader.hardEMIndependentComponentsLoss(
                                startScores, endScores, jointScore,
                                selectionScore, batchOnDevice.answersMask)
                        else:
                            loss = Reader.hardEMLoss(logSpanProb,
                                                     batchOnDevice.answersMask)
                    else:
                        # loss is calculated for all answer spans
                        if self.config["independent_components_in_loss"]:
                            loss = Reader.marginalCompoundLossWithIndependentComponents(
                                startScores, endScores, jointScore,
                                selectionScore, batchOnDevice.answersMask)
                        else:
                            loss = Reader.marginalCompoundLoss(
                                logSpanProb, batchOnDevice.answersMask)

                    if self.config[
                            "use_auxiliary_loss"] and batch.isGroundTruth:
                        # we must be sure that user wants it and that the true passage is ground truth
                        loss += Reader.auxiliarySelectedLoss(selectionScore)
                    loss_sum += loss.item()

                scaler.scale(loss).backward()

            # Catch out-of-memory errors
            except RuntimeError as e:
                if "CUDA out of memory." in str(e):
                    torch.cuda.empty_cache()
                    logging.error(e)
                    tb = traceback.format_exc()
                    logging.error(tb)
                    continue
                else:
                    raise e

            # update parameters

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                filter(lambda p: p.requires_grad, model.parameters()),
                self.config["max_grad_norm"])

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad()
            self.update_it += 1

            if math.isclose(lastScale, scaler.get_scale(),
                            rel_tol=1e-6) and scheduler is not None:
                # we should not perform scheduler step when the optimizer step was omitted due to the
                # change of scale factor
                scheduler.step()

            if self.update_it % self.config["validate_after_steps"] == 0:
                valLoss, exactMatch, passageMatch, samplesWithLoss = self.validate(
                    model, val)

                logging.info(
                    f"Steps:{self.update_it}, Training loss: {loss_sum / samples:.5f}, Validation loss: {valLoss} (samples with loss {samplesWithLoss} [{samplesWithLoss / len(val):.1%}]), Exact match: {exactMatch:.5f}, Passage match: {passageMatch:.5f}"
                )

                bestExactMatch = max(exactMatch, bestExactMatch)
                if self.update_it > self.config["first_save_after_updates_K"]:
                    checkpoint = Checkpoint(
                        model.module if isinstance(model, DataParallel) else
                        model, optimizer, scheduler, train.sampler.actPerm,
                        current_it + 1, self.config, self.update_it)
                    checkpoint.save(f"{self.config['save_dir']}/Reader_train"
                                    f"_{get_timestamp()}"
                                    f"_{socket.gethostname()}"
                                    f"_{valLoss}"
                                    f"_S_{self.update_it}"
                                    f"_E_{current_it}.pt")

                model.train()

            # statistics & logging
            total_tokens += batch.inputSequences.numel()
            if (self.n_iter + 1) % 50 == 0 or current_it == len(iterator) - 1:
                iterator.set_description(
                    f"Steps: {self.update_it} Tokens/s: {total_tokens / (time.time() - startTime)}, Training loss: {loss_sum / samples}"
                )

            if self.config["max_steps"] <= self.update_it:
                break

        logging.info(
            f"End of epoch training loss: {loss_sum / samples:.5f}, best validation exact match: {bestExactMatch}"
        )

        return bestExactMatch
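train_epoch above compares scaler.get_scale() before and after the update to decide whether the LR scheduler should advance, because GradScaler silently skips optimizer.step when it finds inf/NaN gradients. A common variant of that check in isolation (scaler, optimizer and scheduler are placeholders); since the scale only shrinks when a step is skipped, comparing against the previous value is enough.

# Step the LR scheduler only when the optimizer step was actually taken.
scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
if scaler.get_scale() >= scale_before:  # unchanged or grown scale => step was taken
    scheduler.step()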
Example #10
def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('T', ':.3f')
    data_time = AverageMeter('DT', ':.3f')
    losses = AverageMeter('Loss', ':.4f')
    top1 = AverageMeter('Acc@1', ':.2f')
    top5 = AverageMeter('Acc@5', ':.2f')

    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    if args.print_freq > 1:
        print_freq = (len(train_loader) + args.print_freq -
                      1) // args.print_freq
    else:
        print_freq = -args.print_freq // 1
    print_freq = max(print_freq, 1)

    # switch to train mode
    model.train()

    end = time.time()

    scaler = GradScaler() if args.mix else None

    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

        # compute output
        if args.mix:
            with autocast():
                output, target = model(im_q=images[0], im_k=images[1])
                loss = criterion(output, target)
        else:
            output, target = model(im_q=images[0], im_k=images[1])
            loss = criterion(output, target)

        # acc1/acc5 are (K+1)-way contrast classifier accuracy
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images[0].size(0))
        top1.update(acc1[0], images[0].size(0))
        top5.update(acc5[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if args.mix:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.rank == 0 and i % print_freq == 0:
            progress.display(i)

    return losses.avg, top1.avg
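The example above branches on args.mix twice, once for the forward pass and once for the backward/step. Since both autocast and GradScaler accept an enabled flag, the two paths can usually be collapsed into one; below is a sketch of that variant, keeping the argument names above and treating everything else as a placeholder.

# Single code path: AMP becomes a no-op when enabled=False.
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler(enabled=args.mix)
for images, _ in train_loader:
    optimizer.zero_grad()
    with autocast(enabled=args.mix):
        output, target = model(im_q=images[0], im_k=images[1])
        loss = criterion(output, target)
    scaler.scale(loss).backward()  # plain backward() when scaling is disabled
    scaler.step(optimizer)         # plain optimizer.step() when scaling is disabled
    scaler.update()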
Example #11
def do_train(cfg, arguments,
             train_data_loader, test_data_loader,
             model, criterion, optimizer, lr_scheduler,
             check_pointer, device):
    meters = MetricLogger()
    evaluator = train_data_loader.dataset.evaluator
    summary_writer = None
    use_tensorboard = cfg.TRAIN.USE_TENSORBOARD
    if is_master_proc() and use_tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        summary_writer = SummaryWriter(log_dir=os.path.join(cfg.OUTPUT_DIR, 'tf_logs'))

    log_step = cfg.TRAIN.LOG_STEP
    save_epoch = cfg.TRAIN.SAVE_EPOCH
    eval_epoch = cfg.TRAIN.EVAL_EPOCH
    max_epoch = cfg.TRAIN.MAX_EPOCH
    gradient_accumulate_step = cfg.TRAIN.GRADIENT_ACCUMULATE_STEP

    start_epoch = arguments['cur_epoch']
    epoch_iters = len(train_data_loader)
    max_iter = (max_epoch - start_epoch) * epoch_iters
    current_iterations = 0

    if cfg.TRAIN.HYBRID_PRECISION:
        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

    synchronize()
    model.train()
    logger.info("Start training ...")
    # Perform the training loop.
    logger.info("Start epoch: {}".format(start_epoch))
    start_training_time = time.time()
    end = time.time()
    for cur_epoch in range(start_epoch, max_epoch + 1):
        shuffle_dataset(train_data_loader, cur_epoch)
        data_loader = Prefetcher(train_data_loader, device) if cfg.DATALOADER.PREFETCHER else train_data_loader
        for iteration, (images, targets) in enumerate(data_loader):
            if not cfg.DATALOADER.PREFETCHER:
                images = images.to(device=device, non_blocking=True)
                targets = targets.to(device=device, non_blocking=True)

            if cfg.TRAIN.HYBRID_PRECISION:
                # Runs the forward pass with autocasting.
                with autocast():
                    output_dict = model(images)
                    loss_dict = criterion(output_dict, targets)
                    loss = loss_dict[KEY_LOSS] / gradient_accumulate_step

                current_iterations += 1
                if current_iterations % gradient_accumulate_step != 0:
                    if isinstance(model, DistributedDataParallel):
                        # multi-gpu distributed training
                        with model.no_sync():
                            scaler.scale(loss).backward()
                    else:
                        scaler.scale(loss).backward()
                else:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    current_iterations = 0
                    optimizer.zero_grad()
            else:
                output_dict = model(images)
                loss_dict = criterion(output_dict, targets)
                loss = loss_dict[KEY_LOSS] / gradient_accumulate_step

                current_iterations += 1
                if current_iterations % gradient_accumulate_step != 0:
                    if isinstance(model, DistributedDataParallel):
                        # multi-gpu distributed training
                        with model.no_sync():
                            loss.backward()
                    else:
                        loss.backward()
                else:
                    loss.backward()
                    optimizer.step()
                    current_iterations = 0
                    optimizer.zero_grad()

            acc_list = evaluator.evaluate_train(output_dict, targets)
            update_stats(cfg.NUM_GPUS, meters, loss_dict, acc_list)

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time)
            if (iteration + 1) % log_step == 0:
                logger.info(log_iter_stats(
                    iteration, epoch_iters, cur_epoch, max_epoch, optimizer.param_groups[0]['lr'], meters))
            if is_master_proc() and summary_writer:
                global_step = (cur_epoch - 1) * epoch_iters + (iteration + 1)
                for name, meter in meters.meters.items():
                    summary_writer.add_scalar('{}/avg'.format(name), float(meter.avg),
                                              global_step=global_step)
                    summary_writer.add_scalar('{}/global_avg'.format(name), meter.global_avg,
                                              global_step=global_step)
                summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], global_step=global_step)

        if not cfg.DATALOADER.PREFETCHER:
            data_loader.release()
        logger.info(log_epoch_stats(epoch_iters, cur_epoch, max_epoch, optimizer.param_groups[0]['lr'], meters))
        arguments["cur_epoch"] = cur_epoch
        lr_scheduler.step()
        if is_master_proc() and save_epoch > 0 and cur_epoch % save_epoch == 0 and cur_epoch != max_epoch:
            check_pointer.save("model_{:04d}".format(cur_epoch), **arguments)
        if eval_epoch > 0 and cur_epoch % eval_epoch == 0 and cur_epoch != max_epoch:
            if cfg.MODEL.NORM.PRECISE_BN:
                calculate_and_update_precise_bn(
                    train_data_loader,
                    model,
                    min(cfg.MODEL.NORM.NUM_BATCHES_PRECISE, len(train_data_loader)),
                    cfg.NUM_GPUS > 0,
                )

            eval_results = do_evaluation(cfg, model, test_data_loader, device, cur_epoch=cur_epoch)
            model.train()
            if is_master_proc() and summary_writer:
                for key, value in eval_results.items():
                    summary_writer.add_scalar(f'eval/{key}', value, global_step=cur_epoch + 1)

    if eval_epoch > 0:
        logger.info('Start final evaluating...')
        torch.cuda.empty_cache()  # speed up evaluating after training finished
        eval_results = do_evaluation(cfg, model, test_data_loader, device)

        if is_master_proc() and summary_writer:
            for key, value in eval_results.items():
                summary_writer.add_scalar(f'eval/{key}', value, global_step=arguments["cur_epoch"])
            summary_writer.close()
    if is_master_proc():
        check_pointer.save("model_final", **arguments)
    # compute training time
    total_training_time = int(time.time() - start_training_time)
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max_iter))
    return model
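do_train above wraps the intermediate backward passes in model.no_sync() so that DistributedDataParallel does not all-reduce gradients until the final micro-batch of each accumulated step. That pattern on its own, as a sketch with the DDP model, loader, criterion and accumulation count as placeholders:

# DDP gradient accumulation under AMP: delay the gradient all-reduce
# until the last micro-batch of each accumulation window.
from torch.cuda.amp import autocast

for step, (x, y) in enumerate(loader, start=1):
    with autocast():
        loss = criterion(ddp_model(x), y) / accum_steps
    if step % accum_steps != 0:
        with ddp_model.no_sync():         # skip all-reduce on intermediate micro-batches
            scaler.scale(loss).backward()
    else:
        scaler.scale(loss).backward()     # this backward triggers the all-reduce
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()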
Example #12
def do_train(cfg, model, train_loader, optimizer, scheduler, loss_fn):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    device = "cuda"
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info('start training')

    if device:
        model.to(device)
        if torch.cuda.device_count() > 1:
            print('Using {} GPUs for training'.format(
                torch.cuda.device_count()))
            model = nn.DataParallel(model)

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # train
    scaler = GradScaler()
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        loss_meter.reset()
        acc_meter.reset()

        model.train()
        for n_iter, (img, vid) in enumerate(train_loader):

            optimizer.zero_grad()
            if cfg.INPUT.AUGMIX:
                bs = img[0].size(0)
                images_cat = torch.cat(img, dim=0).to(
                    device)  # [3 * batch, 3, 32, 32]
                target = vid.to(device)
                with autocast():
                    logits, feat = model(images_cat, target)
                    logits_orig, logits_augmix1, logits_augmix2 = logits[:bs], logits[
                        bs:2 * bs], logits[2 * bs:]
                    loss = loss_fn(logits_orig, feat, target)
                    p_orig, p_augmix1, p_augmix2 = F.softmax(
                        logits_orig,
                        dim=-1), F.softmax(logits_augmix1,
                                           dim=-1), F.softmax(logits_augmix2,
                                                              dim=-1)

                    # Clamp mixture distribution to avoid exploding KL divergence
                    p_mixture = torch.clamp(
                        (p_orig + p_augmix1 + p_augmix2) / 3., 1e-7, 1).log()
                    loss += 12 * (
                        F.kl_div(p_mixture, p_orig, reduction='batchmean') +
                        F.kl_div(p_mixture, p_augmix1, reduction='batchmean') +
                        F.kl_div(p_mixture, p_augmix2,
                                 reduction='batchmean')) / 3.
                    # keep the logits of the original (non-augmented) images so the
                    # accuracy computation below also works in the AUGMIX branch
                    score = logits_orig
            else:
                img = img.to(device)
                target = vid.to(device)
                with autocast():
                    if cfg.MODEL.CHANNEL_HEAD:
                        score, feat, channel_head_feature = model(img, target)
                        #print(feat.shape, channel_head_feature.shape)
                        loss = loss_fn(score, feat, channel_head_feature,
                                       target)

                    else:
                        score, feat = model(img, target)
                        loss = loss_fn(score, feat, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            acc = (score.max(1)[1] == target).float().mean()
            loss_meter.update(loss.item(), img.shape[0])
            acc_meter.update(acc, 1)

            if (n_iter + 1) % log_period == 0:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                    .format(epoch, (n_iter + 1), len(train_loader),
                            loss_meter.avg, acc_meter.avg,
                            scheduler.get_lr()[0]))
        scheduler.step()
        end_time = time.time()
        time_per_batch = (end_time - start_time) / (n_iter + 1)
        logger.info(
            "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
            .format(epoch, time_per_batch,
                    train_loader.batch_size / time_per_batch))

        if epoch % checkpoint_period == 0:
            torch.save(
                model.state_dict(),
                os.path.join(cfg.OUTPUT_DIR,
                             cfg.MODEL.NAME + '_{}.pth'.format(epoch)))
Example #13
0
class DeepvacTrain(Deepvac):
    def __init__(self, deepvac_config):
        super(DeepvacTrain, self).__init__(deepvac_config)
        self.initTrainParameters()
        self.initTrainContext()

    def setTrainContext(self):
        self.is_train = True
        self.is_val = False
        self.phase = 'TRAIN'
        self.dataset = self.train_dataset
        self.loader = self.train_loader
        self.batch_size = self.conf.train.batch_size
        self.net.train()
        if self.qat_net_prepared:
            self.qat_net_prepared.train()

    def setValContext(self):
        self.is_train = False
        self.is_val = True
        self.phase = 'VAL'
        self.dataset = self.val_dataset
        self.loader = self.val_loader
        self.batch_size = self.conf.val.batch_size
        self.net.eval()
        if self.qat_net_prepared:
            self.qat_net_prepared.eval()

    def initTrainContext(self):
        self.scheduler = None
        self.initOutputDir()
        self.initSummaryWriter()
        self.initCriterion()
        self.initOptimizer()
        self.initScheduler()
        self.initCheckpoint()
        self.initTrainLoader()
        self.initValLoader()

    def initTrainParameters(self):
        self.dataset = None
        self.loader = None
        self.target = None
        self.epoch = 0
        self.step = 0
        self.iter = 0
        # Creates a GradScaler once at the beginning of training.
        self.scaler = GradScaler()
        self.train_time = AverageMeter()
        self.load_data_time = AverageMeter()
        self.data_cpu2gpu_time = AverageMeter()
        self._mandatory_member_name = [
            'train_dataset', 'val_dataset', 'train_loader', 'val_loader',
            'net', 'criterion', 'optimizer'
        ]

    def initOutputDir(self):
        if self.conf.output_dir not in ('output', './output'):
            LOG.logW(
                "According to the deepvac standard, you should save model files to the [output] directory."
            )

        self.output_dir = '{}/{}'.format(self.conf.output_dir, self.branch)
        LOG.logI('model save dir: {}'.format(self.output_dir))
        #for DDP race condition
        os.makedirs(self.output_dir, exist_ok=True)

    def initSummaryWriter(self):
        event_dir = "{}/{}".format(self.conf.log_dir, self.branch)
        self.writer = SummaryWriter(event_dir)
        if not self.conf.tensorboard_port:
            return
        from tensorboard import program
        tensorboard = program.TensorBoard()
        self.conf.tensorboard_ip = '0.0.0.0' if self.conf.tensorboard_ip is None else self.conf.tensorboard_ip
        tensorboard.configure(argv=[
            None, '--host',
            str(self.conf.tensorboard_ip), '--logdir', event_dir, "--port",
            str(self.conf.tensorboard_port)
        ])
        try:
            url = tensorboard.launch()
            LOG.logI('Tensorboard at {} '.format(url))
        except Exception as e:
            LOG.logE(str(e))

    def initCriterion(self):
        self.criterion = torch.nn.CrossEntropyLoss()
        LOG.logW(
            "You should reimplement initCriterion() to initialize self.criterion, unless CrossEntropyLoss() is exactly what you need"
        )

    def initCheckpoint(self):
        if not self.conf.checkpoint_suffix or self.conf.checkpoint_suffix == "":
            LOG.logI('Omitting checkpoint loading since no checkpoint_suffix was specified...')
            return
        LOG.logI('Load checkpoint from {} folder'.format(self.output_dir))
        self.net.load_state_dict(
            torch.load(self.output_dir +
                       '/model__{}'.format(self.conf.checkpoint_suffix),
                       map_location=self.device))
        state_dict = torch.load(
            self.output_dir +
            '/checkpoint__{}'.format(self.conf.checkpoint_suffix),
            map_location=self.device)
        self.optimizer.load_state_dict(state_dict['optimizer'])
        if self.scheduler:
            self.scheduler.load_state_dict(state_dict['scheduler'])
        if self.conf.amp:
            LOG.logI(
                "Will load scaler from checkpoint since you enabled amp, make sure the checkpoint was saved with amp enabled."
            )
            try:
                self.scaler.load_state_dict(state_dict["scaler"])
            except Exception:
                LOG.logI(
                    "checkpoint was saved without amp enabled, so use fresh GradScaler instead."
                )
                self.scaler = GradScaler()

        self.epoch = state_dict['epoch']

    def initScheduler(self):
        if isinstance(self.conf.lr_step, list):
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        elif isinstance(self.conf.lr_step, FunctionType):
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=self.conf.lr_step)
        else:
            self.scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, self.conf.lr_step, self.conf.lr_factor)
        LOG.logW(
            "You should reimplement initScheduler() to initialize self.scheduler, unless lr_scheduler.StepLR() or lr_scheduler.MultiStepLR() is exactly what you need"
        )

    def initTrainLoader(self):
        self.train_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.train_loader",
            exit=True)

    def initValLoader(self):
        self.val_loader = None
        LOG.logE(
            "You must reimplement initTrainLoader() to initialize self.val_loader",
            exit=True)

    def initOptimizer(self):
        self.initSgdOptimizer()
        LOG.logW(
            "You should reimplement initOptimizer() to initialize self.optimizer, unless SGD is exactly what you need"
        )

    def initSgdOptimizer(self):
        self.optimizer = optim.SGD(self.net.parameters(),
                                   lr=self.conf.lr,
                                   momentum=self.conf.momentum,
                                   weight_decay=self.conf.weight_decay,
                                   nesterov=self.conf.nesterov)

    def initAdamOptimizer(self):
        self.optimizer = optim.Adam(
            self.net.parameters(),
            lr=self.conf.lr,
        )
        for group in self.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])

    def initRmspropOptimizer(self):
        self.optimizer = optim.RMSprop(
            self.net.parameters(),
            lr=self.conf.lr,
            momentum=self.conf.momentum,
            weight_decay=self.conf.weight_decay,
            # alpha=self.conf.rmsprop_alpha,
            # centered=self.conf.rmsprop_centered
        )

    def addScalar(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)

    def addImage(self, tag, image, step):
        self.writer.add_image(tag, image, step)

    @syszux_once
    def addGraph(self, input):
        self.writer.add_graph(self.net, input)

    @syszux_once
    def smokeTestForExport3rd(self):
        #exportONNX must run before exportNCNN
        self.exportONNX()
        self.exportNCNN()
        self.exportCoreML()
        #export TorchScript via trace; only here can we get self.sample
        self.exportTorchViaTrace()
        #compile pytorch state dict to TorchScript
        self.exportTorchViaScript()
        self.exportDynamicQuant()
        self.exportStaticQuant(prepare=True)

    def earlyIter(self):
        start = time.time()
        self.sample = self.sample.to(self.device)
        self.target = self.target.to(self.device)
        if not self.is_train:
            return
        self.data_cpu2gpu_time.update(time.time() - start)
        try:
            self.addGraph(self.sample)
        except:
            LOG.logW(
                "Tensorboard addGraph failed. Your network forward may take more than one parameter."
            )
            LOG.logW("It seems you need to reimplement the preIter function.")

    def preIter(self):
        pass

    def postIter(self):
        pass

    def preEpoch(self):
        pass

    def postEpoch(self):
        pass

    def doForward(self):
        self.output = self.net(self.sample)

    def doCalibrate(self):
        if self.static_quantized_net_prepared is None:
            return
        self.static_quantized_net_prepared(self.sample)

    def doLoss(self):
        self.loss = self.criterion(self.output, self.target)

    def doBackward(self):
        if self.conf.amp:
            self.scaler.scale(self.loss).backward()
        else:
            self.loss.backward()

    def doOptimize(self):
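        # gradient accumulation: only step the optimizer every nominal_batch_factor iterations, letting gradients accumulate in between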
        if self.iter % self.conf.nominal_batch_factor != 0:
            return
        if self.conf.amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        self.optimizer.zero_grad()

    def doLog(self):
        if self.step % self.conf.log_every != 0:
            return
        self.addScalar('{}/Loss'.format(self.phase), self.loss.item(),
                       self.iter)
        self.addScalar('{}/LoadDataTime(secs/batch)'.format(self.phase),
                       self.load_data_time.val, self.iter)
        self.addScalar('{}/DataCpu2GpuTime(secs/batch)'.format(self.phase),
                       self.data_cpu2gpu_time.val, self.iter)
        self.addScalar('{}/TrainTime(secs/batch)'.format(self.phase),
                       self.train_time.val, self.iter)
        LOG.logI('{}: [{}][{}/{}] [Loss:{}  Lr:{}]'.format(
            self.phase, self.epoch, self.step, self.loader_len,
            self.loss.item(), self.optimizer.param_groups[0]['lr']))

    def saveState(self, current_time):
        file_partial_name = '{}__acc_{}__epoch_{}__step_{}__lr_{}'.format(
            current_time, self.accuracy, self.epoch, self.step,
            self.optimizer.param_groups[0]['lr'])
        state_file = '{}/model__{}.pth'.format(self.output_dir,
                                               file_partial_name)
        checkpoint_file = '{}/checkpoint__{}.pth'.format(
            self.output_dir, file_partial_name)
        output_trace_file = '{}/trace__{}.pt'.format(self.output_dir,
                                                     file_partial_name)
        output_script_file = '{}/script__{}.pt'.format(self.output_dir,
                                                       file_partial_name)
        output_onnx_file = '{}/onnx__{}.onnx'.format(self.output_dir,
                                                     file_partial_name)
        output_ncnn_file = '{}/ncnn__{}.bin'.format(self.output_dir,
                                                    file_partial_name)
        output_coreml_file = '{}/coreml__{}.mlmodel'.format(
            self.output_dir, file_partial_name)
        output_dynamic_quant_file = '{}/dquant__{}.pt'.format(
            self.output_dir, file_partial_name)
        output_static_quant_file = '{}/squant__{}.pt'.format(
            self.output_dir, file_partial_name)
        output_qat_file = '{}/qat__{}.pt'.format(self.output_dir,
                                                 file_partial_name)
        #save state_dict
        torch.save(self.net.state_dict(), state_file)
        #save checkpoint
        torch.save(
            {
                'optimizer': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'scheduler':
                self.scheduler.state_dict() if self.scheduler else None,
                'scaler': self.scaler.state_dict() if self.conf.amp else None
            }, checkpoint_file)

        #convert for quantization; must happen before trace and script!!!
        self.exportDynamicQuant(output_dynamic_quant_file)
        self.exportStaticQuant(output_quant_file=output_static_quant_file)
        self.exportQAT(output_quant_file=output_qat_file)
        #save pt via trace
        self.exportTorchViaTrace(self.sample, output_trace_file)
        #save pt via script
        self.exportTorchViaScript(output_script_file)
        #save onnx
        self.exportONNX(output_onnx_file)
        #save ncnn
        self.exportNCNN(output_ncnn_file)
        #save coreml
        self.exportCoreML(output_coreml_file)
        #tensorboard
        self.addScalar('{}/Accuracy'.format(self.phase), self.accuracy,
                       self.iter)

    def processTrain(self):
        self.setTrainContext()
        self.step = 0
        LOG.logI('Phase {} started...'.format(self.phase))
        self.loader_len = len(self.loader)
        save_every = self.loader_len // self.conf.save_num
        save_list = list(range(0, self.loader_len + 1, save_every))
        self.save_list = save_list[1:-1]
        LOG.logI('Model will be saved on step {} and the epoch end.'.format(
            self.save_list))
        self.addScalar('{}/LR'.format(self.phase),
                       self.optimizer.param_groups[0]['lr'], self.epoch)
        self.preEpoch()
        self.train_time.reset()
        self.load_data_time.reset()
        self.data_cpu2gpu_time.reset()

        start = time.time()
        for i, (sample, target) in enumerate(self.loader):
            self.load_data_time.update(time.time() - start)
            self.step = i
            self.target = target
            self.sample = sample
            self.preIter()
            self.earlyIter()
            with autocast(enabled=bool(self.conf.amp)):
                self.doForward()
                self.doLoss()
            self.doBackward()
            self.doOptimize()
            self.doLog()
            self.postIter()
            self.iter += 1
            self.train_time.update(time.time() - start)
            if self.step in self.save_list:
                self.processVal()
                self.setTrainContext()
            start = time.time()

        self.addScalar('{}/TrainTime(hours/epoch)'.format(self.phase),
                       round(self.train_time.sum / 3600, 2), self.epoch)
        self.addScalar(
            '{}/AverageBatchTrainTime(secs/epoch)'.format(self.phase),
            self.train_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchLoadDataTime(secs/epoch)'.format(self.phase),
            self.load_data_time.avg, self.epoch)
        self.addScalar(
            '{}/AverageBatchDataCpu2GpuTime(secs/epoch)'.format(self.phase),
            self.data_cpu2gpu_time.avg, self.epoch)

        self.postEpoch()
        if self.scheduler:
            self.scheduler.step()

    def processVal(self, smoke=False):
        self.setValContext()
        LOG.logI('Phase {} started...'.format(self.phase))
        with torch.no_grad():
            self.preEpoch()
            for i, (sample, target) in enumerate(self.loader):
                self.target = target
                self.sample = sample
                self.preIter()
                self.earlyIter()
                self.doForward()
                #calibrate only for quantization.
                self.doCalibrate()
                self.doLoss()
                self.smokeTestForExport3rd()
                LOG.logI('{}: [{}][{}/{}]'.format(self.phase, self.epoch, i,
                                                  len(self.loader)))
                self.postIter()
                if smoke:
                    break
            self.postEpoch()
        self.saveState(self.getTime())

    def processAccept(self):
        self.setValContext()

    def process(self):
        self.auditConfig()
        self.iter = 0
        epoch_start = self.epoch
        self.processVal(smoke=True)
        self.optimizer.zero_grad()
        for epoch in range(epoch_start, self.conf.epoch_num):
            self.epoch = epoch
            LOG.logI('Epoch {} started...'.format(self.epoch))
            self.processTrain()
            self.processVal()
            self.processAccept()

    def __call__(self):
        self.process()
Example #14
0
    with autocast():
        Y_attr = G.get_attr(Y)
        L_attr = 0
        for i in range(len(Xt_attr)):
            #L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_attr[i], 2).reshape(batch_size, -1), dim=1).mean()
            L_attr += torch.mean(torch.pow(Xt_attr[i] - Y_attr[i], 2))
        L_attr *= C_attr / 2.0

        #L_rec = torch.sum(0.5 * torch.mean(torch.pow(Y - Xt, 2).reshape(batch_size, -1), dim=1) * same_person) / (same_person.sum() + 1e-6)
        L_rec = MSE(
            Y[same_person],
            Xt[same_person]) * same_person.sum() * C_rec / (2.0 * batch_size)

        lossG = L_adv + L_attr + L_id + L_rec

    scaler.scale(lossG).backward()
    scaler.step(opt_G)

    # train D
    if niter % 2:  # Trying the best of both worlds
        Xr = Xt  # Xt achieve better L_rec convergence
    else:
        Xr = Xs  # Xs achieve better L_id convergence
    Xr.requires_grad = True
    D.requires_grad_(True)
    opt_D.zero_grad()
    Xf = Y.detach()
    Xf.requires_grad = True
    with autocast():
        fake_D = D(Xf)
        loss_fake = 0
Example #15
0
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          rank=0,
          world=1,
          no_apex=False,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None,
          rotate_augment=False,
          augment_brightness=0.0,
          augment_contrast=0.0,
          augment_hue=0.0,
          augment_saturation=0.0,
          regularization_l2=0.0001,
          rotated_bbox=False,
          absolute_angle=False):
    'Train the model on the given dataset'

    # Prepare model
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=regularization_l2,
                    momentum=0.9)

    is_master = rank == 0
    if not no_apex:
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2' if mixed_precision else 'O0',
            keep_batchnorm_fp32=True,
            loss_scale=loss_scale,
            verbosity=is_master)

    if world > 1:
        model = DDP(model, device_ids=[rank]) if no_apex else ADDP(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    if rotated_bbox:
        if use_dali:
            raise NotImplementedError(
                "This repo does not currently support DALI for rotated bbox detections."
            )
        data_iterator = RotatedDataIterator(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation,
            absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'GPU' if world == 1 else 'GPUs'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

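    # the native-AMP GradScaler is only exercised on the no_apex path below; otherwise apex.amp handles loss scaling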
    scaler = GradScaler()
    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            if not no_apex:
                cls_loss, box_loss = model([
                    data.contiguous(memory_format=torch.channels_last), target
                ])
            else:
                with autocast():
                    cls_loss, box_loss = model([
                        data.contiguous(memory_format=torch.channels_last),
                        target
                    ])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if not no_apex:
                with amp.scale_loss(cls_loss + box_loss,
                                    optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Reduce all losses
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean(
            ).clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration,
                                                 iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'],
                                                       batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size /
                                                  profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if is_master and logdir is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            if val_annotations and (iteration == iterations
                                    or iteration % val_iterations == 0):
                stats = infer(model,
                              val_path,
                              None,
                              resize,
                              max_size,
                              batch_size,
                              annotations=val_annotations,
                              mixed_precision=mixed_precision,
                              is_master=is_master,
                              world=world,
                              use_dali=use_dali,
                              no_apex=no_apex,
                              is_validation=True,
                              verbose=False,
                              rotated_bbox=rotated_bbox)
                model.train()
                if is_master and logdir is not None and stats is not None:
                    writer.add_scalar('Validation_Precision/mAP', stats[0],
                                      iteration)
                    writer.add_scalar('Validation_Precision/[email protected]',
                                      stats[1], iteration)
                    writer.add_scalar('Validation_Precision/[email protected]',
                                      stats[2], iteration)
                    writer.add_scalar('Validation_Precision/mAP (small)',
                                      stats[3], iteration)
                    writer.add_scalar('Validation_Precision/mAP (medium)',
                                      stats[4], iteration)
                    writer.add_scalar('Validation_Precision/mAP (large)',
                                      stats[5], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 1 Dets)',
                                      stats[6], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 10 Dets)',
                                      stats[7], iteration)
                    writer.add_scalar('Validation_Recall/mAR (max 100 Dets)',
                                      stats[8], iteration)
                    writer.add_scalar('Validation_Recall/mAR (small)',
                                      stats[9], iteration)
                    writer.add_scalar('Validation_Recall/mAR (medium)',
                                      stats[10], iteration)
                    writer.add_scalar('Validation_Recall/mAR (large)',
                                      stats[11], iteration)

            if (iteration == iterations
                    and not rotated_bbox) or (iteration > iterations
                                              and rotated_bbox):
                break

    if is_master and logdir is not None:
        writer.close()
Example #16
0
def train_schedule(writer,
                   loader,
                   validation_loader,
                   val_num_steps,
                   device,
                   criterion,
                   net,
                   optimizer,
                   lr_scheduler,
                   num_epochs,
                   is_mixed_precision,
                   input_sizes,
                   exp_name,
                   num_classes,
                   method='baseline'):
    # Should be the same as segmentation, given customized loss classes
    net.train()
    epoch = 0
    running_loss = 0.0
    loss_num_steps = int(len(loader) / 10) if len(loader) > 10 else 1
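    # the scaler only exists when mixed precision is enabled; the fp32 path below calls loss.backward()/optimizer.step() directly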
    if is_mixed_precision:
        scaler = GradScaler()

    # Training
    best_validation = 0
    while epoch < num_epochs:
        net.train()
        time_now = time.time()
        for i, data in enumerate(loader, 0):
            if method == 'lstr':
                inputs, labels = data
                inputs = inputs.to(device)
                labels = [{k: v.to(device)
                           for k, v in label.items()}
                          for label in labels]  # Seems slow
            else:
                inputs, labels, lane_existence = data
                inputs, labels, lane_existence = inputs.to(device), labels.to(
                    device), lane_existence.to(device)
            optimizer.zero_grad()

            with autocast(is_mixed_precision):
                # To support intermediate losses for SAD
                if method == 'lstr':
                    loss = criterion(inputs, labels, net)
                else:
                    loss = criterion(inputs, labels, lane_existence, net,
                                     input_sizes[0])

            if is_mixed_precision:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            lr_scheduler.step()
            running_loss += loss.item()
            current_step_num = int(epoch * len(loader) + i + 1)

            # Record losses
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar('training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0

            # Record checkpoints
            if validation_loader is not None:
                if current_step_num % val_num_steps == (val_num_steps - 1) or \
                        current_step_num == num_epochs * len(loader):
                    # save_checkpoint(net=net, optimizer=optimizer, lr_scheduler=lr_scheduler,
                    #                 filename=exp_name + '_' + str(current_step_num) + '.pt')

                    test_pixel_accuracy, test_mIoU = fast_evaluate(
                        loader=validation_loader,
                        device=device,
                        net=net,
                        num_classes=num_classes,
                        output_size=input_sizes[0],
                        is_mixed_precision=is_mixed_precision)
                    writer.add_scalar('test pixel accuracy',
                                      test_pixel_accuracy, current_step_num)
                    writer.add_scalar('test mIoU', test_mIoU, current_step_num)
                    net.train()

                    # Record best model (straight to disk)
                    if test_mIoU > best_validation:
                        best_validation = test_mIoU
                        save_checkpoint(net=net,
                                        optimizer=optimizer,
                                        lr_scheduler=lr_scheduler,
                                        filename=exp_name + '.pt')

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    # For no-evaluation mode
    if validation_loader is None:
        save_checkpoint(net=net,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        filename=exp_name + '.pt')
Example #17
0
def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler,
             device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                _, _, y_preds = model(images)
                loss = criterion(y_preds, labels)
                # record loss
                losses.update(loss.item(), batch_size)
                if CFG.gradient_accumulation_steps > 1:
                    loss = loss / CFG.gradient_accumulation_steps
                scaler.scale(loss).backward()
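                # note: gradients are still scaled here, so clip_grad_norm_ clips scaled values; calling scaler.unscale_(optimizer) first would clip the true gradients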
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), CFG.max_grad_norm)
                if (step + 1) % CFG.gradient_accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1

        elif CFG.device == 'TPU':
            _, _, y_preds = model(images)
            loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f}  '
                      #'LR: {lr:.6f}  '
                      .format(
                       epoch+1, step, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(train_loader)),
                       grad_norm=grad_norm,
                       #lr=scheduler.get_lr()[0],
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f}  '
                                #'LR: {lr:.6f}  '
                                .format(
                                epoch+1, step, len(train_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(train_loader)),
                                grad_norm=grad_norm,
                                #lr=scheduler.get_lr()[0],
                                ))
    return losses.avg
Example #18
0
class CNN(object):
    def __init__(self):
        self.model = None
        self.lr = 0.001
        self.epochs = 20
        self.train_batch_size = 100
        self.test_batch_size = 100
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = torch.cuda.is_available()
        self.train_loader = None
        self.test_loader = None
        self.scaler = None
    
    def load_data(self):
        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

    def load_model(self):
        if self.cuda:
            self.device = torch.device('cuda')
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        # self.model = LeNet().to(self.device)
        self.model = AlexNet().to(self.device)
        # self.model = VGG11().to(self.device)
        # self.model = VGG13().to(self.device)
        # self.model = VGG16().to(self.device)
        # self.model = VGG19().to(self.device)
        # self.model = GoogLeNet().to(self.device)
        # self.model = resnet18().to(self.device)
        # self.model = resnet34().to(self.device)
        # self.model = resnet50().to(self.device)
        # self.model = resnet101().to(self.device)
        # self.model = resnet152().to(self.device)
        # self.model = DenseNet121().to(self.device)
        # self.model = DenseNet161().to(self.device)
        # self.model = DenseNet169().to(self.device)
        # self.model = DenseNet201().to(self.device)
        # self.model = WideResNet(depth=28, num_classes=10).to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[75, 150], gamma=0.5)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.scaler = GradScaler()

    def train(self):
        print("train:")
        self.model.train()
        train_loss = 0
        train_correct = 0
        total = 0

        for batch_num, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()

            # output = self.model(data)
            # loss = self.criterion(output, target)
            # loss.backward()
            # self.optimizer.step()
            
            # Runs the forward pass with autocasting.
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            train_loss += loss.item()
            prediction = torch.max(output, 1)  # second param "1" represents the dimension to be reduced
            total += target.size(0)

            # train_correct incremented by one if predicted right
            train_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())

            # progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
            #             % (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))

            # print('Loss: %.4f | Acc: %.3f%% (%d/%d)'
            #         % (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))

        return train_loss, train_correct / total

    def test(self):
        print("test:")
        self.model.eval()
        test_loss = 0
        test_correct = 0
        total = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.test_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                test_loss += loss.item()
                prediction = torch.max(output, 1)
                total += target.size(0)
                test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())

                # progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'
                #              % (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))
                # print('Loss: %.4f | Acc: %.3f%% (%d/%d)'
                #     % (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))

        return test_loss, test_correct / total

    # def save(self):
    #     model_out_path = "result.txt"
    #     torch.save(self.model, model_out_path)
    #     print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        self.load_data()
        self.load_model()
        accuracy = 0
        for epoch in range(1, self.epochs + 1):
            start = time.time()
            # self.scheduler.step()
            print("\n===> epoch: %d/%d" % (epoch, self.epochs))
            train_result = self.train()
            self.scheduler.step()
            end = time.time()
            print('Loss: %.4f | Acc: %.3f%% | time: %.3fs'
                  % (train_result[0] / len(self.train_loader), 100. * train_result[1], end - start))
            if epoch == self.epochs:
                test_result = self.test()
                accuracy = max(accuracy, test_result[1])
                print("===> BEST ACC. PERFORMANCE: %.3f%%" % (accuracy * 100))
Example #19
0
def train(model: Model,
          state: dict,
          train_data_path: str,
          train_rgb_json: str,
          val_data_path: str,
          val_rgb_json: str,
          transform_file: str,
          growing_parameters: dict,
          lr: float,
          iterations: int,
          val_iterations: int,
          verbose: bool,
          train_segment_masks_path: str = '',
          val_segment_masks_path: str = '',
          lambda_ccl=0.0,
          loss_type='L2',
          ccl_version='linear',
          alpha=5,
          gamma=.5,
          regularization_l2: float = 0.,
          warmup=5000,
          milestones=[],
          optimizer_name: str = 'sgd',
          print_every: int = 250,
          debug=False):
    model.train()
    torch.backends.cudnn.benchmark = True

    if debug:
        print_every = 10

    sparse_growing_parameters = load_growing_parameters(growing_parameters)
    filled_growing_parameters = fill_growing_parameters(
        sparse_growing_parameters, iterations)

    assert os.path.isfile(transform_file)
    sys.path.insert(0, os.path.dirname(transform_file))
    transforms = __import__(
        os.path.splitext(os.path.basename(transform_file))[0])

    model_dir = os.path.dirname(state['path'])

    writer = SummaryWriter(log_dir=os.path.join(model_dir, 'logs'))

    if loss_type == 'L2':
        criterion = L2Loss(weighted=False)
    elif loss_type == 'L2W':
        criterion = L2Loss(weighted=True, alpha=alpha, gamma=gamma)
    elif loss_type == 'L1':
        criterion = L1Loss(weighted=False)
    elif loss_type == 'L1W':
        criterion = L1Loss(weighted=True, alpha=alpha, gamma=gamma)
    elif loss_type == 'L2+CCL':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version)
    elif loss_type == 'L2W+CCL':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    elif loss_type == 'L2+CCL-gt':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=False)
    elif loss_type == 'L2W+CCL-gt':
        criterion = L2CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    elif loss_type == 'L1+CCL':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl, ccl_version=ccl_version)
    elif loss_type == 'L1W+CCL':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    elif loss_type == 'L1+CCL-gt':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=False)
    elif loss_type == 'L1W+CCL-gt':
        criterion = L1CCLoss(lambda_ccl=lambda_ccl,
                             ccl_version=ccl_version,
                             ccl_target='gt',
                             weighted=True,
                             alpha=alpha,
                             gamma=gamma)
    else:
        raise NotImplementedError()

    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()

    if optimizer_name == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=lr,
                               weight_decay=regularization_l2)
    elif optimizer_name == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=lr,
                              weight_decay=regularization_l2,
                              momentum=0.9)
    else:
        raise NotImplementedError(f'Optimizer {optimizer_name} not available')
    if 'optimizer' in state:
        print('loading optimizer...')
        optimizer.load_state_dict(state['optimizer'])

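    # the GradScaler state (current scale and growth tracker) is checkpointed and restored together with the optimizer and scheduler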
    scaler = GradScaler(enabled=True)
    if 'scaler' in state:
        print('loading scaler...')
        scaler.load_state_dict(state['scaler'])

    def schedule(train_iter):
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return 0.1**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        print('loading scheduler...')
        scheduler.load_state_dict(state['scheduler'])
    iteration = state.get('iteration', 0)

    if iteration >= iterations:
        print('Training already done.')
        return

    if train_segment_masks_path or val_segment_masks_path:
        trainset = ImagenetColorSegmentData(train_data_path,
                                            train_segment_masks_path,
                                            rgb_json=train_rgb_json,
                                            transform=None,
                                            transform_l=to_tensor_l,
                                            transform_ab=to_tensor_ab)
        testset = ImagenetColorSegmentData(
            val_data_path,
            val_segment_masks_path,
            rgb_json=val_rgb_json,
            transform=transforms.get_val_transform(1024),
            transform_l=to_tensor_l,
            transform_ab=to_tensor_ab)
    else:
        trainset = ImagenetData(train_data_path,
                                rgb_json=train_rgb_json,
                                transform=None,
                                transform_l=to_tensor_l,
                                transform_ab=to_tensor_ab)
        testset = ImagenetData(val_data_path,
                               rgb_json=val_rgb_json,
                               transform=transforms.get_val_transform(1024),
                               transform_l=to_tensor_l,
                               transform_ab=to_tensor_ab)

    trainset_infer = ImagenetData(train_data_path,
                                  rgb_json=train_rgb_json,
                                  transform=transforms.get_val_transform(1024),
                                  transform_l=to_tensor_l,
                                  transform_ab=to_tensor_ab,
                                  training=False)
    testset_infer = ImagenetData(val_data_path,
                                 rgb_json=val_rgb_json,
                                 transform=transforms.get_val_transform(1024),
                                 transform_l=to_tensor_l,
                                 transform_ab=to_tensor_ab,
                                 training=False)

    sampler = SavableShuffleSampler(trainset, shuffle=not debug)
    if 'sampler' in state:
        print('loading sampler...')
        sampler.load_state_dict(state['sampler'])

    if len(sampler) > len(trainset):
        sampler = SavableShuffleSampler(trainset, shuffle=not debug)
        print('recreate the sampler, trainset changed...')

    print(f'        Loss: {loss_type}')
    print(criterion)
    print(
        f'   Optimizer: {optimizer.__class__.__name__} (LR:{optimizer.param_groups[0]["lr"]:.6f})'
    )
    print(f'   Iteration: {iteration}/{iterations}')
    print(f'      Warmup: {warmup}')
    print(f'  Milestones: {milestones}')
    print(f'     Growing: {sparse_growing_parameters}')
    print(f'   Traindata: {len(trainset)} images')
    print(f'    Testdata: {len(testset)} images')
    print(f' Sampler idx: {sampler.index}')
    print(f'Current step: {scheduler._step_count}')

    batch_size, input_size = filled_growing_parameters[iteration]
    trainset.transform = transforms.get_transform(input_size[0])
    trainloader = get_trainloader(trainset, batch_size, sampler)

    running_psnr, img_per_sec = 0.0, 0.0
    running_loss, avg_running_loss = defaultdict(float), defaultdict(float)
    tic = time.time()
    changed_batch_size = True
    psnr = PSNR()
    pbar = tqdm(total=iterations, initial=iteration)

    if iteration == 0:
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, global_step=iteration)

    while iteration < iterations:
        loss_str = ' - '.join(
            [f'{key}: {val:.5f} ' for key, val in avg_running_loss.items()])
        pbar.set_description(
            f'[Ep: {sampler.epoch} | B: {batch_size} | Im: {input_size[0]}x{input_size[1]}]  loss: {loss_str} - {img_per_sec:.2f} img/s'
        )
        for data in trainloader:
            if iteration in sparse_growing_parameters and not changed_batch_size:
                # change batch size and input size
                batch_size, input_size = sparse_growing_parameters[iteration]
                trainset.transform = transforms.get_transform(input_size[0])
                # recreate the loader, otherwise the transform is not propagated in multiprocessing to the workers
                trainloader = get_trainloader(trainset, batch_size, sampler)
                changed_batch_size = True
                break
            else:
                changed_batch_size = False

            if torch.cuda.is_available():
                data = tuple([el.cuda(non_blocking=True) for el in data])

            # get data
            if len(data) == 4:
                inputs, labels, segment_masks, _ = data

            else:
                inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
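            # only the forward pass and loss computation run under autocast; the scaled backward happens outside the context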
            with autocast():
                outputs = model(inputs)
                crit_labels = [labels, segment_masks
                               ] if train_segment_masks_path else [labels]
                loss, loss_dict = criterion(outputs, *crit_labels)
                _psnr = psnr(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            del outputs
            del inputs
            del labels
            del data

            # print statistics
            for k, v in loss_dict.items():
                running_loss[k] += v.item()
            running_psnr += _psnr.item()

            iteration += 1

            if iteration % print_every == 0 or iteration == iterations:
                img_per_sec = print_every * batch_size / (time.time() - tic)

                for k, v in running_loss.items():
                    avg_running_loss[k] = running_loss[k] / print_every
                    writer.add_scalar(f'train/{k}',
                                      avg_running_loss[k],
                                      global_step=iteration)
                avg_running_psnr = running_psnr / print_every

                writer.add_scalar('train/PSNR',
                                  avg_running_psnr,
                                  global_step=iteration)

                writer.add_scalar('Performance/Images per second',
                                  img_per_sec,
                                  global_step=iteration)
                writer.add_scalar('Learning rate',
                                  optimizer.param_groups[0]['lr'],
                                  global_step=iteration)
                if loss_type in ['L1+CCL', 'L2+CCL']:
                    writer.add_scalar('Parameters/lambda CCL',
                                      lambda_ccl,
                                      global_step=iteration)
                loss_str = ' - '.join([
                    f'{key}: {val:.5} '
                    for key, val in avg_running_loss.items()
                ])
                pbar.set_description(
                    f'[Ep: {sampler.epoch} | B: {batch_size} | Im: {input_size[0]}x{input_size[1]}] loss: {loss_str} - {img_per_sec:.2f} img/s'
                )

                running_loss = defaultdict(float)
                running_psnr = 0.0
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict(),
                    'sampler': sampler.state_dict()
                })

                model.save(state, iteration)
                delete_older_then_n(state['path'], 10)

                tic = time.time()
            if iteration == iterations or iteration % val_iterations == 0:
                # run validation
                torch.backends.cudnn.benchmark = False
                model = model.eval()
                test_loader = DataLoader(testset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True,
                                         prefetch_factor=1)
                with torch.no_grad():
                    metric_results = get_validation_metrics(
                        test_loader, model, criterion, ccl_version=ccl_version)
                for k, v in metric_results.items():
                    writer.add_scalar(f'validation/{k}',
                                      v,
                                      global_step=iteration)

                # images from validation
                predicted_images = infer(
                    model=model,
                    dataset=testset_infer,
                    target_path=os.path.join(model_dir,
                                             f'predictions-{iteration}'),
                    batch_size=1,
                    img_limit=20,
                    transform=transforms.get_val_transform(1024),
                    debug=True,
                    tensorboard=True)
                for i, img in enumerate(predicted_images):
                    writer.add_image(f'example-{i}',
                                     img,
                                     global_step=iteration,
                                     dataformats='HWC')

                # images from training
                predicted_images = infer(
                    model=model,
                    dataset=trainset_infer,
                    target_path=os.path.join(
                        model_dir, f'predictions-training-{iteration}'),
                    batch_size=1,
                    img_limit=20,
                    transform=transforms.get_val_transform(1024),
                    debug=True,
                    tensorboard=True)
                for i, img in enumerate(predicted_images):
                    writer.add_image(f'example-train-{i}',
                                     img,
                                     global_step=iteration,
                                     dataformats='HWC')

                for name, param in model.named_parameters():
                    writer.add_histogram(name, param, global_step=iteration)
                model = model.train()
                torch.backends.cudnn.benchmark = True
                tic = time.time()
            pbar.update(1)
            if iteration == iterations:
                break

    pbar.close()
    writer.close()
    print('Finished Training')
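The checkpoint written above bundles the iteration counter with the optimizer, scheduler, GradScaler and sampler state, so a stopped run can be resumed without resetting the loss scale or the LR schedule. A minimal resume sketch, assuming the same dictionary layout and a custom sampler that exposes a load_state_dict() counterpart; the 'model' key and the resume_training_state name are illustrative only:
import torch

def resume_training_state(checkpoint_path, model, optimizer, scheduler, scaler, sampler, device='cuda'):
    # Restore everything the loop above checkpoints so the LR schedule, AMP loss
    # scale and data sampling continue from the saved iteration instead of resetting.
    state = torch.load(checkpoint_path, map_location=device)
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])
    scaler.load_state_dict(state['scaler'])      # keeps the dynamic loss scale
    sampler.load_state_dict(state['sampler'])    # assumes the custom sampler supports this
    if 'model' in state:                         # weight key is an assumption; model.save() may store weights elsewhere
        model.load_state_dict(state['model'])
    return state['iteration']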
    def run(self):
        # Should be the same as segmentation, given customized loss classes
        self.model.train()
        epoch = 0
        running_loss = None  # Dict logging for every loss (too many losses in this task)
        loss_num_steps = int(len(self.dataloader) / 10) if len(self.dataloader) > 10 else 1
        if self._cfg['mixed_precision']:
            scaler = GradScaler()

        # Training
        best_validation = 0
        while epoch < self._cfg['num_epochs']:
            self.model.train()
            if self._cfg['distributed']:
                self.train_sampler.set_epoch(epoch)
            time_now = time.time()
            for i, data in enumerate(self.dataloader, 0):
                if self._cfg['seg']:
                    inputs, labels, existence = data
                    inputs, labels, existence = inputs.to(self.device), labels.to(self.device), existence.to(self.device)
                else:
                    inputs, labels = data
                    inputs = inputs.to(self.device)
                    if self._cfg['collate_fn'] is None:
                        labels = labels.to(self.device)
                    else:
                        labels = [{k: v.to(self.device) for k, v in label.items()} for label in labels]  # Seems slow
                self.optimizer.zero_grad()

                with autocast(self._cfg['mixed_precision']):
                    # To support intermediate losses for SAD
                    if self._cfg['seg']:
                        loss, log_dict = self.criterion(inputs, labels, existence,
                                                        self.model, self._cfg['input_size'])
                    else:
                        loss, log_dict = self.criterion(inputs, labels,
                                                        self.model)

                if self._cfg['mixed_precision']:
                    scaler.scale(loss).backward()
                    scaler.step(self.optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    self.optimizer.step()

                self.lr_scheduler.step()

                log_dict = reduce_dict(log_dict)
                if running_loss is None:  # Because different methods may have different values to log
                    running_loss = {k: 0.0 for k in log_dict.keys()}
                for k in log_dict.keys():
                    running_loss[k] += log_dict[k]
                current_step_num = int(epoch * len(self.dataloader) + i + 1)

                # Record losses
                if current_step_num % loss_num_steps == (loss_num_steps - 1):
                    for k in running_loss.keys():
                        print('[%d, %d] %s: %.4f' % (epoch + 1, i + 1, k, running_loss[k] / loss_num_steps))
                        # Logging only once
                        if is_main_process():
                            self.writer.add_scalar(k, running_loss[k] / loss_num_steps, current_step_num)
                        running_loss[k] = 0.0

                # Record checkpoints
                if self._cfg['validation']:
                    assert self._cfg['seg'], 'Only segmentation based methods can be fast evaluated!'
                    if current_step_num % self._cfg['val_num_steps'] == (self._cfg['val_num_steps'] - 1) or \
                            current_step_num == self._cfg['num_epochs'] * len(self.dataloader):
                        test_pixel_accuracy, test_mIoU = LaneDetTester.fast_evaluate(
                            loader=self.validation_loader,
                            device=self.device,
                            net=self.model,
                            num_classes=self._cfg['num_classes'],
                            output_size=self._cfg['input_size'],
                            mixed_precision=self._cfg['mixed_precision'])
                        if is_main_process():
                            self.writer.add_scalar('test pixel accuracy',
                                                   test_pixel_accuracy,
                                                   current_step_num)
                            self.writer.add_scalar('test mIoU',
                                                   test_mIoU,
                                                   current_step_num)
                        self.model.train()

                        # Record best model (straight to disk)
                        if test_mIoU > best_validation:
                            best_validation = test_mIoU
                            save_checkpoint(net=self.model.module if self._cfg['distributed'] else self.model,
                                            optimizer=None,
                                            lr_scheduler=None,
                                            filename=os.path.join(self._cfg['exp_dir'], 'model.pt'))

            epoch += 1
            print('Epoch time: %.2fs' % (time.time() - time_now))

        # For no-evaluation mode
        if not self._cfg['validation']:
            save_checkpoint(net=self.model.module if self._cfg['distributed'] else self.model,
                            optimizer=None,
                            lr_scheduler=None,
                            filename=os.path.join(self._cfg['exp_dir'], 'model.pt'))
Example #21
0
                avg_loss_batch /= config.num_sub_heads
                avg_loss_no_lamb_batch /= config.num_sub_heads
                two_head_loss_list.append(avg_loss_batch)

        status = {
            "epoch": e_i,
            "batch": b_i,
            "head_A_loss": two_head_loss_list[0].item(),
            "head_B_loss": two_head_loss_list[1].item(),
        }

        avg_loss += avg_loss_batch.item()
        avg_loss_no_lamb += avg_loss_no_lamb_batch.item()
        avg_loss_count += 1

        scaler.scale(sum(two_head_loss_list) / 2).backward()
        scaler.step(optimiser)
        scaler.update()
        indicator.set_postfix(status)
        b_i += 1

        avg_loss = float(avg_loss / avg_loss_count)
        avg_loss_no_lamb = float(avg_loss_no_lamb / avg_loss_count)

        epoch_loss.append(avg_loss)
        epoch_loss_no_lamb.append(avg_loss_no_lamb)
    indicator.close()

    # Eval -----------------------------------------------------------------------

    # Can also pick the subhead using the evaluation process (to do this,
Example #22
0
def train_one_epoch(model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    scaler: amp.GradScaler,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter(
        'class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq,
                                                    header):
        # import ipdb; ipdb.set_trace()
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # outputs = model(samples)
        with amp.autocast(enabled=scaler.is_enabled()):
            outputs = model(samples)
        outputs = to_fp32(outputs) if scaler.is_enabled() else outputs
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys()
                     if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {
            f'{k}_unscaled': v
            for k, v in loss_dict_reduced.items()
        }
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        # losses.backward()
        scaler.scale(losses).backward()
        if max_norm > 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # optimizer.step()
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss_value,
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
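A brief usage sketch of how this train_one_epoch is typically driven: the caller constructs the amp.GradScaler once (optionally disabled), and that same object toggles autocast, the fp32 cast, and gradient unscaling inside the loop. The args.amp / args.clip_max_norm flags and the epoch loop are assumptions for illustration, not part of the example above.
import torch
from torch.cuda import amp

def run_training(args, model, criterion, data_loader, optimizer, device):
    # With enabled=False the scaler is a no-op: autocast inside train_one_epoch
    # is disabled and scaler.step(optimizer) simply calls optimizer.step().
    scaler = amp.GradScaler(enabled=args.amp)
    for epoch in range(args.epochs):
        stats = train_one_epoch(model, criterion, scaler, data_loader,
                                optimizer, device, epoch,
                                max_norm=args.clip_max_norm)
        print(f"epoch {epoch} averaged stats:", stats)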
Example #23
0
File: __init__.py Project: isi-nlp/rtg
class DistribTorch:

    host_name: str = socket.gethostname()
    pid: int = os.getpid()
    global_rank: int = int(get_env('RANK', '-1'))
    local_rank: int = int(get_env('LOCAL_RANK', '-1'))
    world_size: int = int(get_env('WORLD_SIZE', '-1'))
    master_addr: str = get_env('MASTER_ADDR', '')
    master_port: int = int(get_env('MASTER_PORT', '-1'))

    gpu_count: int = torch.cuda.device_count()
    visible_devices: str = get_env('CUDA_VISIBLE_DEVICES', '')
    max_norm = 10
    fp16 = False  # Manually enable by calling enable_fp16()

    _scaler = None
    _is_backend_ready = False
    # singleton instance; lazy initialization
    _instance: ClassVar['DistribTorch'] = None
    _model: nn.Module = None

    def setup(self):
        log.info("DistribTorch setup()")
        if self.world_size > 1:
            assert self.global_rank >= 0
            assert self.local_rank >= 0
            assert self.master_addr
            assert self.master_port > 1024
            backend = 'nccl' if self.gpu_count > 0 else 'gloo'
            log.info(
                f"Initializing PyTorch distributed with '{backend}' backend:\n {self}"
            )
            torch.distributed.init_process_group(init_method='env://',
                                                 backend=backend)
            self._is_backend_ready = True
        return self

    def enable_fp16(self):
        if not self.fp16:  # conditional import
            self.fp16 = True
            self._scaler = GradScaler(enabled=self.fp16)
            log.info("Enabling FP16  /Automatic Mixed Precision training")
        else:
            log.warning(" fp16 is already enabled")

    def close(self):
        if self._is_backend_ready:
            log.warning("destroying distributed backend")
            torch.distributed.destroy_process_group()
            self._is_backend_ready = False

    @classmethod
    def instance(cls) -> 'DistribTorch':
        """
        :return: gets singleton instance of class, lazily initialized
        """
        if not cls._instance:
            cls._instance = cls()
        return cls._instance

    def maybe_distributed(self, module: nn.Module):
        if self.world_size > 1:
            if not self._is_backend_ready:
                self.setup()
            self._model = module
            #return torch.nn.parallel.DistributedDataParallel(module)
        return module  # don't wrap

    @property
    def is_distributed(self):
        return self.world_size > 1

    @property
    def is_global_main(self) -> bool:
        return self.global_rank <= 0

    @property
    def is_local_main(self) -> bool:
        return self.local_rank <= 0

    def barrier(self):
        if self.is_distributed:
            torch.distributed.barrier()
        # else we don't need it

    def backward(self, loss):
        if torch.isnan(loss):
            log.warning('loss is nan; backward() skipped')
            return
        if self.fp16:
            loss = self._scaler.scale(loss)
            # to apply norm: TODO: unscale gradients ; refer to docs
            # torch.nn.utils.clip_grad_norm_(self._amp.master_params(opt.optimizer), self.max_norm)
        loss.backward()

    def average_gradients(self, model):
        size = float(self.world_size)
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            # TODO: ring reduce https://pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce
            param.grad.data /= size

    def step(self, optimizer: Optimizer):
        if self.is_distributed:
            self.average_gradients(self._model)
            # TODO: Maybe we don't need to average every step?
        if self.fp16:
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()
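The TODO inside DistribTorch.backward() notes that gradient clipping under FP16 requires unscaling first. One way that TODO could be resolved, sketched here as a standalone helper rather than the project's actual implementation; it assumes the model was passed through maybe_distributed() and reuses the class's existing max_norm attribute:
import torch

def clipped_step(dtorch, model, optimizer):
    # Unscale, clip, then step: gradients are brought back to their true magnitude
    # so clip_grad_norm_ sees real values; scaler.step() skips the update on inf/nan.
    if dtorch.is_distributed:
        dtorch.average_gradients(model)
    if dtorch.fp16:
        dtorch._scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), dtorch.max_norm)
        dtorch._scaler.step(optimizer)
        dtorch._scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), dtorch.max_norm)
        optimizer.step()
    optimizer.zero_grad()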
Example #24
0
class BaseRunner(object):
    """
    Re-ID Base Runner
    """
    def __init__(
        self,
        cfg,
        model,
        optimizer,
        criterions,
        train_loader,
        train_sets=None,
        lr_scheduler=None,
        meter_formats=None,
        print_freq=10,
        reset_optim=True,
        label_generator=None,
    ):
        super(BaseRunner, self).__init__()
        # set_random_seed(cfg.TRAIN.seed, cfg.TRAIN.deterministic)

        if meter_formats is None:
            meter_formats = {"Time": ":.3f", "Acc@1": ":.2%"}

        self.cfg = cfg
        self.model = model
        self.optimizer = optimizer
        self.criterions = criterions
        self.lr_scheduler = lr_scheduler
        self.print_freq = print_freq
        self.reset_optim = reset_optim
        self.label_generator = label_generator

        self.is_pseudo = ("PSEUDO_LABELS" in self.cfg.TRAIN
                          and self.cfg.TRAIN.unsup_dataset_indexes is not None)
        if self.is_pseudo:
            if self.label_generator is None:
                self.label_generator = LabelGenerator(self.cfg, self.model)

        self._rank, self._world_size, self._is_dist = get_dist_info()
        self._epoch, self._start_epoch = 0, 0
        self._best_mAP = 0

        # build data loaders
        self.train_loader, self.train_sets = train_loader, train_sets
        if "val_dataset" in self.cfg.TRAIN:
            self.val_loader, self.val_set = build_val_dataloader(cfg)

        # save training variables
        for key in criterions.keys():
            meter_formats[key] = ":.3f"
        self.train_progress = Meters(meter_formats,
                                     self.cfg.TRAIN.iters,
                                     prefix="Train: ")

        # build mixed precision scaler
        if "amp" in cfg.TRAIN:
            global amp_support
            if cfg.TRAIN.amp and amp_support:
                assert not isinstance(model, DataParallel), \
                    "We do not support mixed precision training with DataParallel currently"
                self.scaler = GradScaler()
            else:
                if cfg.TRAIN.amp:
                    warnings.warn(
                        "Please update the PyTorch version (>=1.6) to support mixed precision training"
                    )
                self.scaler = None
        else:
            self.scaler = None

    def run(self):
        # the whole process for training
        for ep in range(self._start_epoch, self.cfg.TRAIN.epochs):
            self._epoch = ep

            # generate pseudo labels
            if self.is_pseudo:
                if (ep % self.cfg.TRAIN.PSEUDO_LABELS.freq == 0
                        or ep == self._start_epoch):
                    self.update_labels()
                    synchronize()

            # train
            self.train()
            synchronize()

            # validate
            if (ep + 1) % self.cfg.TRAIN.val_freq == 0 or (
                    ep + 1) == self.cfg.TRAIN.epochs:
                if "val_dataset" in self.cfg.TRAIN:
                    mAP = self.val()
                    self.save(mAP)
                else:
                    self.save()

            # update learning rate
            if self.lr_scheduler is not None:
                if isinstance(self.lr_scheduler, list):
                    for scheduler in self.lr_scheduler:
                        scheduler.step()
                elif isinstance(self.lr_scheduler, dict):
                    for key in self.lr_scheduler.keys():
                        self.lr_scheduler[key].step()
                else:
                    self.lr_scheduler.step()

            # synchronize distributed processes
            synchronize()

    def update_labels(self):
        sep = "*************************"
        print(
            f"\n{sep} Start updating pseudo labels on epoch {self._epoch} {sep}\n"
        )

        # generate pseudo labels
        pseudo_labels, label_centers = self.label_generator(
            self._epoch, print_freq=self.print_freq)

        # update train loader
        self.train_loader, self.train_sets = build_train_dataloader(
            self.cfg,
            pseudo_labels,
            self.train_sets,
            self._epoch,
        )

        # update criterions
        if "cross_entropy" in self.criterions.keys():
            self.criterions[
                "cross_entropy"].num_classes = self.train_loader.loader.dataset.num_pids

        # reset optim (optional)
        if self.reset_optim:
            self.optimizer.state = collections.defaultdict(dict)

        # update classifier centers
        start_cls_id = 0
        for idx in range(len(self.cfg.TRAIN.datasets)):
            if idx in self.cfg.TRAIN.unsup_dataset_indexes:
                labels = torch.arange(
                    start_cls_id, start_cls_id + self.train_sets[idx].num_pids)
                centers = label_centers[
                    self.cfg.TRAIN.unsup_dataset_indexes.index(idx)]
                if isinstance(self.model, list):
                    for model in self.model:
                        if isinstance(model,
                                      (DataParallel, DistributedDataParallel)):
                            model = model.module
                        model.initialize_centers(centers, labels)
                else:
                    model = self.model
                    if isinstance(model,
                                  (DataParallel, DistributedDataParallel)):
                        model = model.module
                    model.initialize_centers(centers, labels)
            start_cls_id += self.train_sets[idx].num_pids

        print(f"\n{sep} Finished updating pseudo label {sep}n")

    def train(self):
        # one loop for training
        if isinstance(self.model, list):
            for model in self.model:
                model.train()
        elif isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].train()
        else:
            self.model.train()

        self.train_progress.reset(prefix="Epoch: [{}]".format(self._epoch))

        if isinstance(self.train_loader, list):
            for loader in self.train_loader:
                loader.new_epoch(self._epoch)
        else:
            self.train_loader.new_epoch(self._epoch)

        end = time.time()
        for iter in range(self.cfg.TRAIN.iters):

            if isinstance(self.train_loader, list):
                batch = [loader.next() for loader in self.train_loader]
            else:
                batch = self.train_loader.next()
            # self.train_progress.update({'Data': time.time()-end})

            if self.scaler is None:
                loss = self.train_step(iter, batch)
                if (loss > 0):
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

            else:
                with autocast():
                    loss = self.train_step(iter, batch)
                if (loss > 0):
                    self.optimizer.zero_grad()
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)

            if self.scaler is not None:
                self.scaler.update()

            self.train_progress.update({"Time": time.time() - end})
            end = time.time()

            if iter % self.print_freq == 0:
                self.train_progress.display(iter)

    def train_step(self, iter, batch):
        # need to be re-written case by case
        assert not isinstance(
            self.model,
            list), "please re-write 'train_step()' to support list of models"

        data = batch_processor(batch, self.cfg.MODEL.dsbn)
        if len(data["img"]) > 1:
            warnings.warn(
                "please re-write the 'runner.train_step()' function to make use of "
                "mutual transformer.")

        inputs = data["img"][0].cuda()
        targets = data["id"].cuda()

        results = self.model(inputs)
        if "prob" in results.keys():
            results["prob"] = results["prob"][:, :self.train_loader.loader.
                                              dataset.num_pids]

        total_loss = 0
        meters = {}
        for key in self.criterions.keys():
            loss = self.criterions[key](results, targets)
            total_loss += loss * float(self.cfg.TRAIN.LOSS.losses[key])
            meters[key] = loss.item()

        if "prob" in results.keys():
            acc = accuracy(results["prob"].data, targets.data)
            meters["Acc@1"] = acc[0]

        self.train_progress.update(meters)

        return total_loss

    def val(self):
        if not isinstance(self.model, list):
            model_list = [self.model]
        else:
            model_list = self.model

        better_mAP = 0
        for idx in range(len(model_list)):
            if len(model_list) > 1:
                print("==> Val on the no.{} model".format(idx))
            cmc, mAP = val_reid(
                self.cfg,
                model_list[idx],
                self.val_loader[0],
                self.val_set[0],
                self._epoch,
                self.cfg.TRAIN.val_dataset,
                self._rank,
                print_freq=self.print_freq,
            )
            better_mAP = max(better_mAP, mAP)

        return better_mAP

    def save(self, mAP=None):
        if mAP is not None:
            is_best = mAP > self._best_mAP
            self._best_mAP = max(self._best_mAP, mAP)
            print(bcolors.OKGREEN +
                  "\n * Finished epoch {:3d}  mAP: {:5.1%}  best: {:5.1%}{}\n".
                  format(self._epoch, mAP, self._best_mAP,
                         " *" if is_best else "") + bcolors.ENDC)
        else:
            is_best = True
            print(bcolors.OKGREEN +
                  "\n * Finished epoch {:3d} \n".format(self._epoch) +
                  bcolors.ENDC)

        if self._rank == 0:
            # only on cuda:0
            self.save_model(is_best, self.cfg.work_dir)

    def save_model(self, is_best, fpath):
        if isinstance(self.model, list):
            state_dict = {}
            state_dict["epoch"] = self._epoch + 1
            state_dict["best_mAP"] = self._best_mAP
            for idx, model in enumerate(self.model):
                state_dict["state_dict_" + str(idx + 1)] = model.state_dict()
            save_checkpoint(state_dict,
                            is_best,
                            fpath=osp.join(fpath, "checkpoint.pth"))

        elif isinstance(self.model, dict):
            state_dict = {}
            state_dict["epoch"] = self._epoch + 1
            state_dict["best_mAP"] = self._best_mAP
            for key in self.model.keys():
                state_dict["state_dict"] = self.model[key].state_dict()
                save_checkpoint(state_dict,
                                False,
                                fpath=osp.join(fpath,
                                               "checkpoint_" + key + ".pth"))

        elif isinstance(self.model, nn.Module):
            state_dict = {}
            state_dict["epoch"] = self._epoch + 1
            state_dict["best_mAP"] = self._best_mAP
            state_dict["state_dict"] = self.model.state_dict()
            save_checkpoint(state_dict,
                            is_best,
                            fpath=osp.join(fpath, "checkpoint.pth"))

        else:
            assert "Unknown type of model for save_model()"

    def resume(self, path):
        # resume from a training checkpoint (not source pretrain)
        self.load_model(path)
        synchronize()

    def load_model(self, path):
        if isinstance(self.model, list):
            assert osp.isfile(path)
            state_dict = load_checkpoint(path)
            for idx, model in enumerate(self.model):
                copy_state_dict(state_dict["state_dict_" + str(idx + 1)],
                                model)

        elif isinstance(self.model, dict):
            assert osp.isdir(path)
            for key in self.model.keys():
                state_dict = load_checkpoint(
                    osp.join(path, "checkpoint_" + key + ".pth"))
                copy_state_dict(state_dict["state_dict"], self.model[key])

        elif isinstance(self.model, nn.Module):
            assert osp.isfile(path)
            state_dict = load_checkpoint(path)
            copy_state_dict(state_dict["state_dict"], self.model)

        self._start_epoch = state_dict["epoch"]
        self._best_mAP = state_dict["best_mAP"]

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def rank(self):
        """int: Rank of current process. (distributed training)"""
        return self._rank

    @property
    def world_size(self):
        """int: Number of processes participating in the job.
        (distributed training)"""
        return self._world_size
Example #25
0
File: utils.py Project: STomoya/animeface
def train(
    max_iters, dataset, sampler, latent_dim,
    G, G_ema, D, optimizer_G, optimizer_D,
    num_tags, ema_decay,
    recons_lambda, style_lambda, feat_lambda,
    amp, device, save
):
    
    status = Status(max_iters)
    scaler = GradScaler() if amp else None
    loss = LSGANLoss()
    _refs = [None]*len(num_tags)

    def check_d_output(output):
        if isinstance(output, tuple):
            return output[0], output[1]
        else:
            return output, None
    
    while status.batches_done < max_iters:
        i, j = random_ij(num_tags)
        _, j_ = random_ij(num_tags, (i, j))
        real = dataset.sample(i, j)
        real = real.to(device)
        z    = sampler((real.size(0), latent_dim))
        # print(real.size(0), z.size(0))

        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        '''generate images'''
        with autocast(amp):
            # reconstruct
            recons = G(real)
            # reconstruct with style code of itself
            refs = copy.copy(_refs)
            refs[i] = (real, j)
            recons_self_trans = G(real, refs)
            # fake
            refs[i] = (z, j_)
            fake = G(real, refs)
            # reconstruct from fake
            refs[i] = (real, j)
            recons_fake_trans = G(fake, refs)

            '''Discriminator'''
            # D(real)
            output = D(real, i, j)
            real_prob, real_feat = check_d_output(output)
            # D(fake)
            output = D(fake.detach(), i, j_)
            fake_prob, fake_feat = check_d_output(output)
            # D(G(fake))
            output = D(recons_fake_trans.detach(), i, j)
            recons_prob, recons_feat = check_d_output(output)

            # loss
            D_loss = loss.d_loss(real_prob[:, 0], fake_prob[:, 0]) \
                + loss.d_loss(real_prob[:, 1], recons_prob[:, 1])
            if feat_lambda > 0 and real_feat is not None:
                D_loss = D_loss \
                    + feature_matching(real_feat, fake_feat) * feat_lambda

        if scaler is not None:
            scaler.scale(D_loss).backward()
            scaler.step(optimizer_D)
        else:
            D_loss.backward()
            optimizer_D.step()

        '''Generator'''
        with autocast(amp):
            # D(fake)
            output = D(fake, i, j_)
            fake_prob, fake_feat = check_d_output(output)
            # D(G(fake))
            output = D(recons_fake_trans, i, j)
            recons_prob, recons_feat = check_d_output(output)
            # style codes
            style_j_ = G.category_modules[i].map(z, j_)
            style_fake = G.category_modules[i].extract(fake, j_)

            # loss
            G_loss = loss.g_loss(fake_prob[:, 0]) + loss.g_loss(recons_prob[:, 1])
            G_loss = G_loss \
                + l1_loss(style_j_, style_fake) * style_lambda \
                + (l1_loss(recons, real) \
                    + l1_loss(recons_self_trans, real) \
                    + l1_loss(recons_fake_trans, real)) * recons_lambda

        if scaler is not None:
            scaler.scale(G_loss).backward()
            scaler.step(optimizer_G)
        else:
            G_loss.backward()
            optimizer_G.step()

        update_ema(G, G_ema, ema_decay)

        # save
        if status.batches_done % save == 0:
            with torch.no_grad():
                refs[i] = (z, j_)
                fake = G_ema(real, refs)
            images = _image_grid(real, fake)
            save_image(images, f'implementations/HiSD/result/{status.batches_done}_tag{i}_{j}to{j_}.jpg',
                nrow=4, normalize=True, value_range=(-1, 1))
            torch.save(G.state_dict(), f'implementations/HiSD/result/G_{status.batches_done}.pt',)
        save_image(fake, f'running.jpg', nrow=4, normalize=True, value_range=(-1, 1))

        # updates
        loss_dict = dict(
            G=G_loss.item() if not torch.isnan(G_loss).any() else 0,
            D=D_loss.item() if not torch.isnan(D_loss).any() else 0
        )
        status.update(loss_dict)
        if scaler is not None:
            scaler.update()
        
    status.plot()
Example #26
0
def pruning():
    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'],
                          mode='RGB'),
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.4, 1),
                                                   shear=(-5, 5)),
                       A.PairRandomHorizontalFlip(),
                       A.PairRandomBoxBlur(0.1, 5),
                       A.PairRandomSharpen(0.1),
                       A.PairApplyOnlyAtIndices([1],
                                                T.ColorJitter(
                                                    0.15, 0.15, 0.15, 0.05)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 2),
                                                  shear=(-5, 5)),
                          T.RandomHorizontalFlip(),
                          A.RandomBoxBlur(0.1, 5),
                          A.RandomSharpen(0.1),
                          T.ColorJitter(0.15, 0.15, 0.15, 0.05),
                          T.ToTensor()
                      ])),
    ])
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    # Validation DataLoader
    dataset_valid = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'],
                          mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'],
                          mode='RGB')
        ],
                   transforms=A.PairCompose([
                       A.PairRandomAffineAndResize((512, 512),
                                                   degrees=(-5, 5),
                                                   translate=(0.1, 0.1),
                                                   scale=(0.3, 1),
                                                   shear=(-5, 5)),
                       A.PairApply(T.ToTensor())
                   ]),
                   assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['valid'],
                      mode='RGB',
                      transforms=T.Compose([
                          A.RandomAffineAndResize((512, 512),
                                                  degrees=(-5, 5),
                                                  translate=(0.1, 0.1),
                                                  scale=(1, 1.2),
                                                  shear=(-5, 5)),
                          T.ToTensor()
                      ])),
    ])
    dataset_valid = SampleDataset(dataset_valid, 50)
    dataloader_valid = DataLoader(dataset_valid,
                                  pin_memory=True,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers)

    # Model
    model = MattingBase(args.model_backbone).cuda()

    if args.model_last_checkpoint is not None:
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
    elif args.model_pretrain_initialization is not None:
        model.load_pretrained_deeplabv3_state_dict(
            torch.load(args.model_pretrain_initialization)['model_state'])

    # Print the initial sparsity
    # for name, module in model.named_modules():
    #     # prune 10% of connections in all 2D-conv layers
    #     if isinstance(module, torch.nn.Conv2d):
    #         # DNSUnst(module, name='weight')
    #         prune.l1_unstructured(module, name='weight', amount=0.4)
    #         prune.remove(module, 'weight')
    print("the original sparsity: ", get_sparsity(model))

    optimizer = Adam([{
        'params': model.backbone.parameters(),
        'lr': 1e-4
    }, {
        'params': model.aspp.parameters(),
        'lr': 5e-4
    }, {
        'params': model.decoder.parameters(),
        'lr': 5e-4
    }])
    scaler = GradScaler()

    # Logging and checkpoints
    if not os.path.exists(f'checkpoint/{args.model_name}'):
        os.makedirs(f'checkpoint/{args.model_name}')
    writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr),
                true_bgr) in enumerate(tqdm(dataloader_train)):

            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.cuda(non_blocking=True)
            true_fgr = true_fgr.cuda(non_blocking=True)
            true_bgr = true_bgr.cuda(non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr,
                                                       true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 *
                                                          random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5),
                                            translate=(0.2, 0.2),
                                            scale=(0.5, 1.5),
                                            shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(
                    aug_shadow, (random.choice(range(20, 40)), ) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(
                    aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(
                    torch.randn_like(true_src[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(
                    torch.randn_like(true_bgr[aug_noise_idx]).mul_(
                        0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(
                    0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(
                    degrees=(-1, 1),
                    translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha,
                                    true_fgr)

            scaler.scale(loss).backward()

            # Pruning
            best_c = np.zeros(187)
            if i == 0:
                ncs = NCS_C(model, true_src, true_bgr, true_pha, true_fgr)
                best_c = ncs.run(model, true_src, true_bgr, true_pha, true_fgr)
                PreTPUnst(model, best_c)
            else:
                # Adjust
                PreDNSUnst(model, best_c)

            scaler.step(optimizer)
            Pruned(model)

            scaler.update()
            optimizer.zero_grad()

            if (i + 1) % args.log_train_loss_interval == 0:
                writer.add_scalar('loss', loss, step)

            if (i + 1) % args.log_train_images_interval == 0:
                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5),
                                 step)
                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5),
                                 step)
                writer.add_image('train_pred_com',
                                 make_grid(pred_fgr * pred_pha, nrow=5), step)
                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5),
                                 step)
                writer.add_image('train_true_src', make_grid(true_src, nrow=5),
                                 step)
                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5),
                                 step)

            del true_pha, true_fgr, true_bgr, true_src
            del pred_pha, pred_fgr, pred_err
            del loss
            del best_c

            if (i + 1) % args.log_valid_interval == 0:
                valid(model, dataloader_valid, writer, step)

            if (step + 1) % args.checkpoint_interval == 0:
                torch.save(
                    model.state_dict(),
                    f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth'
                )

        print("the sparsity of epoch {} : {}".format(epoch,
                                                     get_sparsity(model)))
        torch.save(model.state_dict(),
                   f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Print the final sparsity
    print("the final sparsity: ", get_sparsity(model))
Example #27
0
    def train(self, train_loader):

        scaler = GradScaler(enabled=self.args.fp16_precision)

        # save config file
        save_config_file(self.writer.log_dir, self.args)

        n_iter = 0
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        logging.info(f"Training with gpu: {self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)

                images = images.to(self.args.device)

                with autocast(enabled=self.args.fp16_precision):
                    features = self.model(images)
                    logits, labels = self.info_nce_loss(features)
                    loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()

                scaler.scale(loss).backward()

                scaler.step(self.optimizer)
                scaler.update()

                if n_iter % self.args.log_every_n_steps == 0:
                    top1, top5 = accuracy(logits, labels, topk=(1, 5))
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1',
                                           top1[0],
                                           global_step=n_iter)
                    self.writer.add_scalar('acc/top5',
                                           top5[0],
                                           global_step=n_iter)
                    self.writer.add_scalar('learning_rate',
                                           self.scheduler.get_lr()[0],
                                           global_step=n_iter)

                n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                self.scheduler.step()
            logging.debug(
                f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}"
            )

        logging.info("Training has finished.")
        # save model checkpoints
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint(
            {
                'epoch': self.args.epochs,
                'arch': self.args.arch,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            },
            is_best=False,
            filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(
            f"Model checkpoint and metadata has been saved at {self.writer.log_dir}."
        )
Example #28
0
def train(args):
    # torch.multiprocessing.set_sharing_strategy('file_system')
    # too many barriers / one node data parallel and multiple node DDP
    os.environ['MASTER_ADDR'] = args["master_addr"]
    os.environ['MASTER_PORT'] = args["master_port"]
    os.environ['TOKENIZERS_PARALLELISM'] = "true"
    torch.backends.cudnn.benchmark = True
    rank = args["nr"]
    gpus = args["gpus_per_node"]
    if args["cpu"]:
        assert args["world_size"] == 1
        device = torch.device("cpu")
        barrier = get_barrier(False)
    else:
        dist.init_process_group(args["dist_backend"], rank=rank, world_size=args["world_size"])
        device = torch.device('cuda:0')  # Unique only on individual node.
        torch.cuda.set_device(device)
        barrier = get_barrier(True)

    set_seeds(args["seed"])
    mconf = model_config.to_dict()
    config = dict(md_config=md_config, sm_config=sm_config)[mconf.pop("model_size")]
    tokenizer = get_tokenizer(mconf.pop("tokenizer_name"))
    config.vocab_size = len(tokenizer) + 22
    config.tokenizer_length = 1024
    config.tokenizer_length = config.tokenizer_length - config.num_highway_cls_tokens
    config.max_position_embeddings = config.max_position_embeddings + config.num_highway_cls_tokens

    collate_fn = get_collate_fn(config.num_highway_cls_tokens, tokenizer.pad_token_id)

    model = FastFormerForFusedELECTRAPretraining(config, tokenizer=tokenizer, **mconf).to(device)
    print("Trainable Params = %s" % (numel(model) / 1_000_000))
    if args["pretrained_model"] is not None:
        model.load_state_dict(torch.load(args["pretrained_model"], map_location={'cuda:%d' % 0: 'cuda:%d' % 0}))
    model.data_parallel = True
    # Take model to local rank
    if args["cpu"]:
        ddp_model = model
    else:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        ddp_model = DDP(model, device_ids=[0], find_unused_parameters=True)
    all_params = list(filter(lambda p: p.requires_grad, ddp_model.parameters()))
    optc = optimizer_config.to_dict()
    optimizer = AdamW(all_params, lr=optc["lr"], eps=optc["eps"], weight_decay=optc["weight_decay"], betas=(optc["beta_1"], optc["beta_2"]))
    optimizer.zero_grad()
    scaler = GradScaler()

    model_save_dir = args["model_save_dir"]
    model_save_name = args["model_save_name"]
    if rank == 0:
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
    assert os.path.exists(model_save_dir)
    barrier()
    print("Optimizer Created for Rank = %s" % rank)
    shuffle_dataset = args["shuffle_dataset"]
    sampling_fraction = optc["sampling_fraction"]
    if not args["validate_only"] and not args["test_only"]:
        train_loader = build_dataloader(args["train_dataset"], shuffle_dataset, sampling_fraction, config, collate_fn, tokenizer, world_size=args["world_size"], num_workers=args["num_workers"])

    print("Data Loaded for Rank = %s" % rank)
    validate_every_steps = args["validate_every_steps"]
    log_every_steps = args["log_every_steps"]
    save_every_steps = args["save_every_steps"]
    scheduler = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"])
    gradient_clipping = optc["gradient_clipping"]
    _ = model.train()
    barrier()

    start_time = time.time()
    batch_times = []
    model_times = []
    full_times = []
    print("Start Training for Rank = %s" % rank)
    for step, batch in enumerate(train_loader):
        model.zero_grad()
        optimizer.zero_grad()
        if step == 0:
            print("First Batch Training for Rank = %s" % rank)
        # if step <= 39:
        #     continue
        gen_batch_time = time.time() - start_time
        batch_times.append(gen_batch_time)
        if (step + 1) % save_every_steps == 0:
            if rank == 0:
                torch.save(ddp_model.state_dict(), os.path.join(model_save_dir, model_save_name))
            barrier()
        if (step + 1) % validate_every_steps == 0:
            if rank == 0:
                val_results = LargeValidator(args["validation_dataset"], ddp_model, config, device, tokenizer)()
                print("Rank = %s, steps = %s, Val = %s" % (rank, step, val_results))
            barrier()
        record_accuracy = False
        if (step + 1) % log_every_steps == 0:
            record_accuracy = True

        batch["record_accuracy"] = record_accuracy
        labels = batch["label_mlm_input_ids"] if "label_mlm_input_ids" in batch else batch["input_ids"]
        labels = labels.to(device)
        model_start_time = time.time()
        if args["cpu"]:
            output = ddp_model(**batch, labels=labels)
            output = {key: [item[key] for item in output]
                      for key in list(functools.reduce(
                    lambda x, y: x.union(y),
                    (set(dicts.keys()) for dicts in output)
                ))
                      }
            output = {k: torch.mean(v) for k, v in output.items()}
            loss = output["loss"]
            loss_dict = output["loss_dict"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(all_params, gradient_clipping)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        else:
            with autocast():

                output = ddp_model(**batch, labels=labels)
                output = {key: [item[key] for item in output]
                          for key in list(functools.reduce(
                        lambda x, y: x.union(y),
                        (set(dicts.keys()) for dicts in output)
                    ))
                          }
                output = {k: torch.mean(v) for k, v in output.items()}
                loss = output["loss"]
                loss_dict = output["loss_dict"]
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(all_params, gradient_clipping)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
        model_end_time = time.time() - model_start_time
        model_times.append(model_end_time)
        full_time = time.time() - start_time
        full_times.append(full_time)
        start_time = time.time()
        if (step + 1) % log_every_steps == 0:
            print("Rank = %s, steps = %s, batch_size = %s, Loss = %s, Accuracy = %s" % (rank, step, batch["input_ids"].size(), loss_dict, output["accuracy_hist"]))
            print("Batch time = %s, Model Time = %s, Full time = %s" % (np.mean(batch_times), np.mean(model_times), np.mean(full_times)))
            batch_times = []
            model_times = []
            full_times = []
            clean_memory()
            barrier()



    # Take inputs to local_rank

    # TODO: validate on multigpu, sort the val datasets alphabetically and let the gpu with rank == dataset rank in sort pick up the dataset. GPUs with rank > len(datasetDict) stay idle.
    # TODO: select one dataset and make full batch from it, this way rebalancing can be easy.
    # TODO: dataset rebalancing.
    # TODO: save model only in local_rank == 0 process
    # TODO: Check if all initialised model weights are same??
    # I've been tracking an ema of sample training loss during training and using that to guide weighted data sampling (rather than the typical uniform sampling). Seems to help with a variety of real world datasets where the bulk of the data is often very similar and easy to learn but certain subpopulations are much more challenging.

    pass
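The closing comment above describes weighting the data sampler by an exponential moving average of per-sample training loss rather than sampling uniformly, so that hard subpopulations are seen more often. A rough sketch of that idea with torch.utils.data.WeightedRandomSampler; the class name, decay value, and the assumption that each batch carries its dataset indices are illustrative, not part of this repository:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

class LossEmaWeights:
    # Tracks an EMA of per-sample loss; samples with high recent loss are drawn more often.
    def __init__(self, num_samples, decay=0.9, init=1.0):
        self.decay = decay
        self.ema = torch.full((num_samples,), init)

    def update(self, indices, per_sample_loss):
        # per_sample_loss: unreduced loss for this batch, detached and moved to CPU
        self.ema[indices] = self.decay * self.ema[indices] + (1 - self.decay) * per_sample_loss

    def make_sampler(self, num_draws):
        weights = self.ema / self.ema.sum()
        return WeightedRandomSampler(weights, num_samples=num_draws, replacement=True)

# hypothetical per-epoch rebuild of the loader with the refreshed sampler:
# sampler = weights_tracker.make_sampler(len(dataset))
# train_loader = DataLoader(dataset, batch_size=32, sampler=sampler, collate_fn=collate_fn)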
Example #29
0
class Agent(object):
    def __init__(self,
                 input_dims,
                 n_actions,
                 env,
                 epsilon=1.0,
                 batch_size=32,
                 eps_dec=4.5e-7,
                 replace=1000,
                 nheads=4,
                 gamma=0.99,
                 capacity=100000,
                 transformer_layers=1,
                 lr=0.0003,
                 time_model=True):
        self.input_dims = input_dims
        self.gamma = gamma
        self.embed_len = env.observation_space.shape[1]
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.epsilon = epsilon
        self.eps_dec = eps_dec
        self.replace = replace
        self.eps_min = 0.01
        self.update_cntr = 0
        self.env = env
        self.scaler = GradScaler()
        if time_model:
            self.memory = TimeBuffer(capacity, input_dims, n_actions)
            self.q_eval = TimeModel(n_actions,
                                    nheads,
                                    transformer_layers,
                                    lr=lr)

            self.q_train = TimeModel(n_actions,
                                     nheads,
                                     transformer_layers,
                                     lr=lr)
        else:
            self.memory = ReplayBuffer(capacity=capacity,
                                       input_dims=self.input_dims,
                                       n_actions=self.n_actions,
                                       embed_len=self.embed_len)

            self.q_eval = Transformer(self.embed_len,
                                      nheads,
                                      n_actions,
                                      transformer_layers,
                                      network_name="q_eval",
                                      lr=lr)
            self.q_train = Transformer(self.embed_len,
                                       nheads,
                                       n_actions,
                                       transformer_layers,
                                       network_name="q_train",
                                       lr=lr)

    def pick_action(self, obs):
        if np.random.random() > self.epsilon:
            state = T.tensor([obs], dtype=T.float).to(self.q_eval.device)
            with autocast():
                output = self.q_eval.forward(state).sum(dim=0).mean(
                    dim=0).argmax(dim=0)
            action = output.item()
        else:
            action = self.env.action_space.sample()

        return action

    def save_script(self):
        print("Saving torch script...")
        q_eval_script = T.jit.script(self.q_eval)
        q_eval_script.save("q_eval_script.pt")
        print(q_eval_script)

    def update_target_network(self):
        if self.update_cntr % self.replace == 0:
            self.q_eval.load_state_dict(self.q_train.state_dict())

    # Store Experience
    def store_transition(self, state, action, reward, state_, done):
        self.memory.store_transition(state, action, reward, state_, done)

    def decrement_epsilon(self):
        self.epsilon = self.epsilon - self.eps_dec if self.epsilon > self.eps_min else self.eps_min

    # Agent's Learn Function
    def learn(self):
        if self.memory.mem_cntr < self.batch_size:
            return

        # Sample from memory
        states, actions, rewards, states_, dones = self.memory.sample_buffer(
            self.batch_size)
        # Numpy to Tensor
        states = T.tensor(states, dtype=T.float).to(self.q_eval.device)
        actions = T.tensor(actions, dtype=T.int64).to(self.q_eval.device)
        rewards = T.tensor(rewards, dtype=T.float).to(self.q_eval.device)
        states_ = T.tensor(states_, dtype=T.float).to(self.q_eval.device)
        done = T.tensor(dones, dtype=T.bool).to(self.q_eval.device)

        # zero gradients on the online network; setting .grad to None avoids the extra memset of optimizer.zero_grad()
        for param in self.q_train.parameters():
            param.grad = None

        self.update_target_network()

        indices = np.arange(self.batch_size)

        # Estimate Q
        with autocast():
            q_pred = self.q_train.forward(states).mean(dim=1)
            q_pred *= actions
            q_pred = q_pred.mean(dim=1)
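            # bootstrap values come from q_eval, while the greedy next action is selected with q_train (double-DQN-style target)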
            q_next = self.q_eval.forward(states_).mean(dim=1)
            q_train = self.q_train.forward(states_).mean(dim=1)

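            # terminal next-states contribute no bootstrapped value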
            q_next[done] = 0.0
            max_action = T.argmax(q_train, dim=1)

            y = rewards + self.gamma * q_next[indices, max_action]

            loss = self.q_train.loss(y, q_pred).to(self.q_eval.device)
        # scaled backward pass and optimizer step via GradScaler (replaces plain backward()/step())
        self.scaler.scale(loss).backward()
        self.scaler.step(self.q_train.optimizer)
        self.scaler.update()

        self.update_cntr += 1
        self.decrement_epsilon()

    # Save weights
    def save(self):
        print("Saving...")
        self.q_eval.save()
        self.q_train.save()

    # Load Weights
    def load(self):
        print("loading...")
        self.q_eval.load()
        self.q_train.load()

    def count_params(self):
        model = self.q_train
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
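
A minimal usage sketch (not part of the original example): the agent class above exposes pick_action, store_transition, learn and save, so a standard interaction loop can drive it. The helper name run_episodes and the gym-style env API (reset/step) are assumptions, not taken from the source.

def run_episodes(agent, env, n_episodes=100):
    # hypothetical driver loop for the agent defined above
    for episode in range(n_episodes):
        obs = env.reset()
        done = False
        while not done:
            action = agent.pick_action(obs)
            obs_, reward, done, info = env.step(action)
            agent.store_transition(obs, action, reward, obs_, done)
            agent.learn()  # no-op until the buffer holds at least batch_size transitions
            obs = obs_
    agent.save()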
Example #30
0
def train(args):

    model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt), strict=False)

    model.cuda()
    model.train()

    if args.stage != 'chairs':
        model.module.freeze_bn()

    train_loader = datasets.fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
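    # GradScaler is effectively a no-op when args.mixed_precision is False
    # (scale() returns the loss unchanged and step() calls optimizer.step() directly)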
    scaler = GradScaler(enabled=args.mixed_precision)
    logger = Logger(model, scheduler)

    VAL_FREQ = 5000

    should_keep_training = True
    while should_keep_training:

        for i_batch, data_blob in enumerate(train_loader):
            optimizer.zero_grad()
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            # show_image(image1[0])
            # show_image(image2[0])

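            # optional augmentation: add Gaussian noise with a random std in [0, 5) to both frames, then clamp to the valid pixel range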
            if args.add_noise:
                stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 +
                          stdv * torch.randn(*image1.shape).cuda()).clamp(
                              0.0, 255.0)
                image2 = (image2 +
                          stdv * torch.randn(*image2.shape).cuda()).clamp(
                              0.0, 255.0)

            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid)
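            # backward on the scaled loss; unscale_ before clip_grad_norm_ so clipping sees true gradient magnitudes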
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            scaler.step(optimizer)
            scheduler.step()
            scaler.update()

            logger.push(metrics)

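            # every VAL_FREQ steps: save a checkpoint and run the requested validation sets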
            if total_steps % VAL_FREQ == VAL_FREQ - 1:
                PATH = 'checkpoints/%d_%s.pth' % (total_steps + 1, args.name)
                torch.save(model.state_dict(), PATH)

                results = {}
                for val_dataset in args.validation:
                    if val_dataset == 'chairs':
                        results.update(evaluate.validate_chairs(model.module))
                    elif val_dataset == 'sintel':
                        results.update(evaluate.validate_sintel(model.module))
                    elif val_dataset == 'kitti':
                        results.update(evaluate.validate_kitti(model.module))

                logger.write_dict(results)

                model.train()
                if args.stage != 'chairs':
                    model.module.freeze_bn()

            total_steps += 1

            if total_steps > args.num_steps:
                should_keep_training = False
                break

    logger.close()
    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)

    return PATH
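
A hypothetical configuration sketch (not from the original source), listing only the args attributes this excerpt reads; datasets.fetch_dataloader and fetch_optimizer in the full code will need additional fields (batch size, learning rate, image size, etc.).

from types import SimpleNamespace

# field names below mirror the attribute accesses in train(); values are placeholders
args = SimpleNamespace(name='raft-demo',
                       stage='chairs',
                       restore_ckpt=None,
                       gpus=[0],
                       num_steps=100000,
                       mixed_precision=True,
                       iters=12,
                       clip=1.0,
                       add_noise=True,
                       validation=['chairs'])
# train(args)  # uncomment once the remaining fields required by the data/optimizer helpers are added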