예제 #1
0
def train(args, checkpoint, mid_checkpoint_location, final_checkpoint_location, best_checkpoint_location,
          actfun, curr_seed, outfile_path, filename, fieldnames, curr_sample_size, device, num_params,
          curr_k=2, curr_p=1, curr_g=1, perm_method='shuffle'):
    """
    Runs training session for a given randomized model
    :param args: arguments for this job
    :param checkpoint: checkpoint to resume from, or None to start fresh
    :param mid_checkpoint_location: output path for per-epoch resume checkpoints
    :param final_checkpoint_location: output path for the end-of-training checkpoint
    :param best_checkpoint_location: output path for the best-validation checkpoint
    :param actfun: activation function currently being used
    :param curr_seed: seed being used by current job
    :param outfile_path: path to save outputs from training session
    :param filename: output file name (not referenced in this function; kept for caller compatibility)
    :param fieldnames: column names for output file
    :param curr_sample_size: training sample size to request from the dataset loader
    :param device: reference to CUDA device for GPU support
    :param num_params: number of parameters in the network
    :param curr_k: k value for this iteration
    :param curr_p: p value for this iteration
    :param curr_g: g value for this iteration
    :param perm_method: permutation strategy for our network
    :return: None
    """

    resnet_ver = args.resnet_ver
    resnet_width = args.resnet_width
    num_epochs = args.num_epochs

    # 1-dimensional activation functions do not combine multiple pre-activations,
    # so k is forced to 1 regardless of the requested value.
    actfuns_1d = ['relu', 'abs', 'swish', 'leaky_relu', 'tanh']
    if actfun in actfuns_1d:
        curr_k = 1
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    if args.one_shot:
        # One-shot mode: build a throwaway model/dataset, run the LR range finder,
        # and report the discovered learning rate. No full training happens here.
        util.seed_all(curr_seed)
        model_temp, _ = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g, num_params=num_params,
                                   perm_method=perm_method, device=device, resnet_ver=resnet_ver,
                                   resnet_width=resnet_width, verbose=args.verbose)

        # Re-seed so dataset sampling is reproducible independently of model init.
        util.seed_all(curr_seed)
        dataset_temp = util.load_dataset(
            args,
            args.model,
            args.dataset,
            seed=curr_seed,
            validation=True,
            batch_size=args.batch_size,
            train_sample_size=curr_sample_size,
            kwargs=kwargs)

        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx, args.one_shot)
        # The LR finder sweeps the learning rate itself, so none is set here.
        optimizer = optim.Adam(model_temp.parameters(),
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd']
                               )

        start_time = time.time()
        # Only record finder output to the CSV when running a hyperparameter search.
        oneshot_fieldnames = fieldnames if args.search else None
        oneshot_outfile_path = outfile_path if args.search else None
        lr = util.run_lr_finder(
            args,
            model_temp,
            dataset_temp[0],
            optimizer,
            nn.CrossEntropyLoss(),
            val_loader=dataset_temp[3],
            show=False,
            device=device,
            fieldnames=oneshot_fieldnames,
            outfile_path=oneshot_outfile_path,
            hparams=curr_hparams
        )
        curr_hparams = {}
        print("Time to find LR: {}\n LR found: {:3e}".format(time.time() - start_time, lr))

    else:
        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx)
        lr = curr_hparams['max_lr']

        criterion = nn.CrossEntropyLoss()
        model, model_params = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g, num_params=num_params,
                                   perm_method=perm_method, device=device, resnet_ver=resnet_ver,
                                   resnet_width=resnet_width, verbose=args.verbose)

        util.seed_all(curr_seed)
        model.apply(util.weights_init)

        # Re-seed so dataset sampling is reproducible independently of weight init.
        util.seed_all(curr_seed)
        dataset = util.load_dataset(
            args,
            args.model,
            args.dataset,
            seed=curr_seed,
            validation=args.validation,
            batch_size=args.batch_size,
            train_sample_size=curr_sample_size,
            kwargs=kwargs)
        loaders = {
            'aug_train': dataset[0],
            'train': dataset[1],
            'aug_eval': dataset[2],
            'eval': dataset[3],
        }
        sample_size = dataset[4]
        batch_size = dataset[5]

        # NOTE: the original code branched on args.one_shot here, but this is the
        # `not args.one_shot` path, so that branch was unreachable dead code.
        # Only the hparams-driven optimizer/scheduler is needed.
        optimizer = optim.Adam(model_params,
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd']
                               )
        scheduler = OneCycleLR(optimizer,
                               max_lr=curr_hparams['max_lr'],
                               epochs=num_epochs,
                               steps_per_epoch=int(math.floor(sample_size / batch_size)),
                               pct_start=curr_hparams['cycle_peak'],
                               cycle_momentum=False
                               )

        epoch = 1
        if checkpoint is not None:
            # Resume model/optimizer/scheduler state and the epoch counter.
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            epoch = checkpoint['epoch']
            model.to(device)
            print("*** LOADED CHECKPOINT ***"
                  "\n{}"
                  "\nSeed: {}"
                  "\nEpoch: {}"
                  "\nActfun: {}"
                  "\nNum Params: {}"
                  "\nSample Size: {}"
                  "\np: {}"
                  "\nk: {}"
                  "\ng: {}"
                  "\nperm_method: {}".format(mid_checkpoint_location, checkpoint['curr_seed'],
                                             checkpoint['epoch'], checkpoint['actfun'],
                                             checkpoint['num_params'], checkpoint['sample_size'],
                                             checkpoint['p'], checkpoint['k'], checkpoint['g'],
                                             checkpoint['perm_method']))

        util.print_exp_settings(curr_seed, args.dataset, outfile_path, args.model, actfun,
                                util.get_model_params(model), sample_size, batch_size, model.k, model.p, model.g,
                                perm_method, resnet_ver, resnet_width, args.optim, args.validation, curr_hparams)

        best_val_acc = 0

        if args.mix_pre_apex:
            # NVIDIA Apex mixed precision (O2) — wraps model and optimizer.
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

        # ---- Start Training
        while epoch <= num_epochs:

            # Save a resume checkpoint at the top of every epoch so a preempted
            # job can pick up where it left off.
            if args.check_path != '':
                torch.save({'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'curr_seed': curr_seed,
                            'epoch': epoch,
                            'actfun': actfun,
                            'num_params': num_params,
                            'sample_size': sample_size,
                            'p': curr_p, 'k': curr_k, 'g': curr_g,
                            'perm_method': perm_method
                            }, mid_checkpoint_location)

            # Per-epoch seed derived from (seed, epoch) so resumed runs see the
            # same data order as uninterrupted ones.
            util.seed_all((curr_seed * args.num_epochs) + epoch)
            start_time = time.time()
            if args.mix_pre:
                # Native PyTorch AMP: fresh GradScaler each epoch.
                scaler = torch.cuda.amp.GradScaler()

            # ---- Training
            model.train()
            total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
                x, targetx = x.to(device), targetx.to(device)
                optimizer.zero_grad()
                if args.mix_pre:
                    # Native AMP path: autocast forward, scaled backward/step.
                    with torch.cuda.amp.autocast():
                        output = model(x)
                        train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    scaler.scale(train_loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                elif args.mix_pre_apex:
                    # Apex AMP path: loss scaling via amp.scale_loss.
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    optimizer.step()
                else:
                    # Full-precision path.
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    train_loss.backward()
                    optimizer.step()
                # OneCycle schedules step once per batch.
                if args.optim == 'onecycle' or args.optim == 'onecycle_sgd':
                    scheduler.step()
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targetx.data)
                num_total += len(prediction)
            epoch_aug_train_loss = total_train_loss / n
            epoch_aug_train_acc = num_correct * 1.0 / num_total

            # For combinact models, log the learned per-layer mixing weights.
            alpha_primes = []
            alphas = []
            if model.actfun == 'combinact':
                for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                    curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                    curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                    curr_alpha_primes = curr_alpha_primes.tolist()
                    alpha_primes.append(curr_alpha_primes)
                    alphas.append(curr_alphas)

            # ---- Evaluation: once on the augmented eval set, once on the clean one.
            model.eval()
            with torch.no_grad():
                total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (y, targety) in enumerate(loaders['aug_eval']):
                    y, targety = y.to(device), targety.to(device)
                    output = model(y)
                    val_loss = criterion(output, targety)
                    total_val_loss += val_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targety.data)
                    num_total += len(prediction)
                epoch_aug_val_loss = total_val_loss / n
                epoch_aug_val_acc = num_correct * 1.0 / num_total

                total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (y, targety) in enumerate(loaders['eval']):
                    y, targety = y.to(device), targety.to(device)
                    output = model(y)
                    val_loss = criterion(output, targety)
                    total_val_loss += val_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targety.data)
                    num_total += len(prediction)
                epoch_val_loss = total_val_loss / n
                epoch_val_acc = num_correct * 1.0 / num_total
            lr_curr = 0
            for param_group in optimizer.param_groups:
                lr_curr = param_group['lr']
            print(
                "    Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f}, aug {:1.4f} ||| "
                "aug_train_loss {:1.4f} | val_loss {:1.4f}, aug {:1.4f} ||| time = {:1.4f}"
                    .format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc, epoch_aug_val_acc,
                            epoch_aug_train_loss, epoch_val_loss, epoch_aug_val_loss, (time.time() - start_time)), flush=True
            )

            if args.hp_idx is None:
                hp_idx = -1
            else:
                hp_idx = args.hp_idx

            # Train-set metrics are only computed on the final epoch (expensive);
            # all other epochs record 0 for them.
            epoch_train_loss = 0
            epoch_train_acc = 0
            if epoch == num_epochs:
                with torch.no_grad():
                    total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                    for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
                        x, targetx = x.to(device), targetx.to(device)
                        output = model(x)
                        train_loss = criterion(output, targetx)
                        total_train_loss += train_loss
                        n += 1
                        _, prediction = torch.max(output.data, 1)
                        num_correct += torch.sum(prediction == targetx.data)
                        num_total += len(prediction)
                    epoch_aug_train_loss = total_train_loss / n
                    epoch_aug_train_acc = num_correct * 1.0 / num_total

                    total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                    for batch_idx, (x, targetx) in enumerate(loaders['train']):
                        x, targetx = x.to(device), targetx.to(device)
                        output = model(x)
                        train_loss = criterion(output, targetx)
                        total_train_loss += train_loss
                        n += 1
                        _, prediction = torch.max(output.data, 1)
                        num_correct += torch.sum(prediction == targetx.data)
                        num_total += len(prediction)
                    # BUG FIX: was `total_val_loss / n`, which reported the last
                    # validation accumulator as the train loss (and corrupted
                    # gen_gap in the CSV below).
                    epoch_train_loss = total_train_loss / n
                    epoch_train_acc = num_correct * 1.0 / num_total

            # Outputting data to CSV at end of epoch
            with open(outfile_path, mode='a') as out_file:
                writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
                writer.writerow({'dataset': args.dataset,
                                 'seed': curr_seed,
                                 'epoch': epoch,
                                 'time': (time.time() - start_time),
                                 'actfun': model.actfun,
                                 'sample_size': sample_size,
                                 'model': args.model,
                                 'batch_size': batch_size,
                                 'alpha_primes': alpha_primes,
                                 'alphas': alphas,
                                 'num_params': util.get_model_params(model),
                                 'var_nparams': args.var_n_params,
                                 'var_nsamples': args.var_n_samples,
                                 'k': curr_k,
                                 'p': curr_p,
                                 'g': curr_g,
                                 'perm_method': perm_method,
                                 'gen_gap': float(epoch_val_loss - epoch_train_loss),
                                 'aug_gen_gap': float(epoch_aug_val_loss - epoch_aug_train_loss),
                                 'resnet_ver': resnet_ver,
                                 'resnet_width': resnet_width,
                                 'epoch_train_loss': float(epoch_train_loss),
                                 'epoch_train_acc': float(epoch_train_acc),
                                 'epoch_aug_train_loss': float(epoch_aug_train_loss),
                                 'epoch_aug_train_acc': float(epoch_aug_train_acc),
                                 'epoch_val_loss': float(epoch_val_loss),
                                 'epoch_val_acc': float(epoch_val_acc),
                                 'epoch_aug_val_loss': float(epoch_aug_val_loss),
                                 'epoch_aug_val_acc': float(epoch_aug_val_acc),
                                 'hp_idx': hp_idx,
                                 'curr_lr': lr_curr,
                                 'found_lr': lr,
                                 'hparams': curr_hparams,
                                 'epochs': num_epochs
                                 })

            epoch += 1

            # Non-onecycle (rmsprop) schedules step once per epoch instead.
            if args.optim == 'rmsprop':
                scheduler.step()

            if args.checkpoints:
                # Keep the best-validation snapshot separately from the rolling
                # "latest" snapshot written below.
                if epoch_val_acc > best_val_acc:
                    best_val_acc = epoch_val_acc
                    torch.save({'state_dict': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'scheduler': scheduler.state_dict(),
                                'curr_seed': curr_seed,
                                'epoch': epoch,
                                'actfun': actfun,
                                'num_params': num_params,
                                'sample_size': sample_size,
                                'p': curr_p, 'k': curr_k, 'g': curr_g,
                                'perm_method': perm_method
                                }, best_checkpoint_location)

                torch.save({'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'curr_seed': curr_seed,
                            'epoch': epoch,
                            'actfun': actfun,
                            'num_params': num_params,
                            'sample_size': sample_size,
                            'p': curr_p, 'k': curr_k, 'g': curr_g,
                            'perm_method': perm_method
                            }, final_checkpoint_location)
예제 #2
0
class Maskv3Agent:
    def __init__(self, config):
        """
        Set up the full Maskv3 training pipeline from a configuration dict.

        Builds train/valid datasets and loaders, the model, the scaled anchor
        boxes, an AMP grad scaler, the Adam optimizer, a OneCycleLR scheduler,
        the loss function, and a TensorBoard writer.

        :param config: nested dict with 'dataset', 'dataloader', 'model',
                       'optimizer', 'train', and 'valid' sections
        """
        self.config = config

        # Train on device
        target_device = config['train']['device']
        if torch.cuda.is_available():
            # Let cuDNN auto-tune convolution kernels (fastest for fixed input sizes).
            torch.backends.cudnn.benchmark = True
            self.device = target_device
        else:
            # No CUDA available: fall back to CPU regardless of the configured device.
            self.device = "cpu"

        # Load dataset
        train_transform = get_yolo_transform(config['dataset']['size'],
                                             mode='train')
        valid_transform = get_yolo_transform(config['dataset']['size'],
                                             mode='test')
        train_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['train']['csv'],
            img_dir=config['dataset']['train']['img_root'],
            mask_dir=config['dataset']['train']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=train_transform)
        valid_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['valid']['csv'],
            img_dir=config['dataset']['valid']['img_root'],
            mask_dir=config['dataset']['valid']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=valid_transform)
        # DataLoader (only the training loader is shuffled)
        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=False,
            drop_last=False)
        # Model
        model = Maskv3(
            # Detection Branch
            in_channels=config['model']['in_channels'],
            num_classes=config['model']['num_classes'],
            # Prototype Branch
            num_masks=config['model']['num_masks'],
            num_features=config['model']['num_features'],
        )
        self.model = model.to(self.device)
        # Facilitated anchor boxes with model.
        # Scales are broadcast over their 3 anchor (w, h) pairs:
        # (3,) -> (3, 1, 1) -> repeat -> (3, 3, 2), then multiplied elementwise,
        # presumably mapping normalized anchors onto each scale's grid — TODO confirm.
        torch_anchors = torch.tensor(config['dataset']['anchors'])  # (3, 3, 2)
        torch_scales = torch.tensor(config['dataset']['scales'])  # (3,)
        scaled_anchors = (  # (3, 3, 2)
            torch_anchors *
            (torch_scales.unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)))
        self.scaled_anchors = scaled_anchors.to(self.device)

        # Optimizer (mixed-precision gradient scaler + Adam)
        self.scaler = torch.cuda.amp.GradScaler()
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay'],
        )
        # Scheduler: one LR cycle spanning the whole run, stepped per batch
        # (see _train_one_epoch).
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config['optimizer']['lr'],
            epochs=config['train']['n_epochs'],
            steps_per_epoch=len(self.train_loader),
        )
        # Loss function
        self.loss_fn = YOLOMaskLoss(num_classes=config['model']['num_classes'],
                                    num_masks=config['model']['num_masks'])

        # Tensorboard
        self.logdir = config['train']['logdir']
        self.board = SummaryWriter(logdir=config['train']['logdir'])

        # Training State
        self.current_epoch = 0   # last completed epoch (0 = not started)
        self.current_map = 0     # best mAP@50 seen so far (see train())

    def resume(self):
        """
        Restore model, optimizer, scheduler, and training-state counters from
        the 'best.pth' checkpoint in the log directory.
        """
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        # map_location lets a checkpoint saved on GPU be restored on the current
        # device (e.g. a CPU-only machine — __init__ explicitly supports that).
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.current_map = checkpoint['current_map']
        self.current_epoch = checkpoint['current_epoch']
        print("Restore checkpoint at '{}'".format(self.current_epoch))

    def train(self):
        """
        Run the full training schedule from the epoch after current_epoch up to
        the configured n_epochs: train one epoch, validate, and checkpoint.
        """
        for epoch in range(self.current_epoch + 1,
                           self.config['train']['n_epochs'] + 1):
            self.current_epoch = epoch
            self._train_one_epoch()
            self._validate()
            # Called for its side effects; the returned accuracies were
            # previously bound to an unused local.
            self._check_accuracy()

            # Before the configured 'when' epoch, always keep the latest weights.
            if self.current_epoch < self.config['valid']['when']:
                self._save_checkpoint()

            # From 'when' onwards, evaluate mAP@50 every 5 epochs and only keep
            # checkpoints that improve on the best score so far.
            if (self.current_epoch >= self.config['valid']['when']
                    and self.current_epoch % 5 == 0):
                mAP50 = self._check_map()
                if mAP50 > self.current_map:
                    self.current_map = mAP50
                    self._save_checkpoint()

    def finalize(self):
        """Run one last mAP evaluation after training completes."""
        self._check_map()

    def _train_one_epoch(self):
        """
        Train the model for one epoch over self.train_loader.

        Runs the forward pass and per-scale losses under autocast, updates
        parameters through the AMP GradScaler, steps the OneCycleLR scheduler
        once per batch, shows running mean losses on a tqdm bar, and logs
        per-epoch mean losses to TensorBoard.
        """
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.train_loader,
                    leave=True,
                    desc=(f"Train Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        # Per-batch loss histories: used both for the running means shown in
        # the progress bar and the epoch means logged below.
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.train()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model prediction: mixed-precision forward pass; the loss is
            # computed separately for each of the three detection scales.
            with torch.cuda.amp.autocast():
                outs, prototypes = self.model(imgs)
                s1_loss = self.loss_fn(
                    outs[0],
                    target_s1,
                    self.scaled_anchors[0],  # Detection Branch
                    prototypes,
                    masks,  # Prototype Branch
                )
                s2_loss = self.loss_fn(
                    outs[1],
                    target_s2,
                    self.scaled_anchors[1],  # Detection Branch
                    prototypes,
                    masks,  # Prototype Branch
                )
                s3_loss = self.loss_fn(
                    outs[2],
                    target_s3,
                    self.scaled_anchors[2],  # Detection Branch
                    prototypes,
                    masks,  # Prototype Branch
                )
            # Aggregate loss: sum each component across the three scales.
            obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                'obj_loss']
            box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                'box_loss']
            noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                'noobj_loss'] + s3_loss['noobj_loss']
            class_loss = s1_loss['class_loss'] + s2_loss[
                'class_loss'] + s3_loss['class_loss']
            segment_loss = s1_loss['segment_loss'] + s2_loss[
                'segment_loss'] + s3_loss['segment_loss']
            total_loss = s1_loss['total_loss'] + s2_loss[
                'total_loss'] + s3_loss['total_loss']
            # Moving average loss
            total_losses.append(total_loss.item())
            obj_losses.append(obj_loss.item())
            noobj_losses.append(noobj_loss.item())
            box_losses.append(box_loss.item())
            class_losses.append(class_loss.item())
            segment_losses.append(segment_loss.item())
            # Update parameters: scaled backward + step via GradScaler, then
            # advance the per-batch OneCycleLR schedule.
            self.optimizer.zero_grad()
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()
            # Update progress bar with running means over the epoch so far.
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch): mean of each loss component over all batches.
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Train Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _validate(self):
        """
        Compute validation losses for one epoch over self.valid_loader.

        Mirrors _train_one_epoch's loss computation but runs in eval mode
        under torch.no_grad() (no parameter updates, no scheduler stepping),
        and logs per-epoch mean losses to TensorBoard under 'Valid' tags.
        """
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.valid_loader,
                    leave=True,
                    desc=(f"Valid Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        # Per-batch loss histories for progress-bar means and epoch logging.
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.eval()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model prediction: gradient-free, mixed-precision forward pass,
            # with the loss computed per detection scale.
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
                    s1_loss = self.loss_fn(
                        outs[0],
                        target_s1,
                        self.scaled_anchors[0],  # Detection Branch
                        prototypes,
                        masks,  # Prototype Branch
                    )
                    s2_loss = self.loss_fn(
                        outs[1],
                        target_s2,
                        self.scaled_anchors[1],  # Detection Branch
                        prototypes,
                        masks,  # Prototype Branch
                    )
                    s3_loss = self.loss_fn(
                        outs[2],
                        target_s3,
                        self.scaled_anchors[2],  # Detection Branch
                        prototypes,
                        masks,  # Prototype Branch
                    )
            # Aggregate loss: sum each component across the three scales.
            obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                'obj_loss']
            box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                'box_loss']
            noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                'noobj_loss'] + s3_loss['noobj_loss']
            class_loss = s1_loss['class_loss'] + s2_loss[
                'class_loss'] + s3_loss['class_loss']
            segment_loss = s1_loss['segment_loss'] + s2_loss[
                'segment_loss'] + s3_loss['segment_loss']
            total_loss = s1_loss['total_loss'] + s2_loss[
                'total_loss'] + s3_loss['total_loss']
            # Moving average loss
            obj_losses.append(obj_loss.item())
            box_losses.append(box_loss.item())
            noobj_losses.append(noobj_loss.item())
            class_losses.append(class_loss.item())
            total_losses.append(total_loss.item())
            segment_losses.append(segment_loss.item())
            # Update progress bar with running means over the epoch so far.
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch): mean of each loss component over all batches.
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Valid Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _check_accuracy(self):
        """Compute objectness / no-objectness / class accuracy on the
        validation set.

        Runs the model in eval mode under AMP autocast, thresholds the
        objectness logits at ``valid.conf_threshold``, and compares the
        argmax class prediction against the target class at object cells,
        per prediction scale.

        Returns:
            dict: {'cls', 'obj', 'noobj'} accuracy percentages (floats).
        """
        tot_obj = 0
        tot_noobj = 0
        correct_obj = 0
        correct_noobj = 0
        correct_class = 0
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc="Check ACC")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device; targets hold one tensor per prediction scale and
            # channel 4 of the last dim is the objectness label.
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            targets = [t.to(self.device) for t in targets]
            # Model Prediction (inference only, mixed precision)
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
            for pred, target in zip(outs, targets):
                # Cells that do / do not contain an object
                obj_mask = target[..., 4] == 1
                noobj_mask = target[..., 4] == 0
                # Count objects once (original summed obj_mask twice)
                num_obj = torch.sum(obj_mask)
                tot_obj += num_obj
                tot_noobj += torch.sum(noobj_mask)
                # Thresholded objectness prediction, shared by both branches
                # (original duplicated this computation)
                obj_pred = torch.sigmoid(
                    pred[..., 4]) > self.config['valid']['conf_threshold']
                # No-objectness is scored whether or not this scale has objects
                correct_noobj += torch.sum(
                    obj_pred[noobj_mask] == target[..., 4][noobj_mask])
                if num_obj == 0:
                    # No ground-truth objects at this scale: nothing else to score
                    continue
                # Count number of correctly classified objects
                correct_class += torch.sum((torch.argmax(
                    pred[...,
                         5:5 + self.config['model']['num_classes']][obj_mask],
                    dim=-1) == target[..., 5][obj_mask]))
                # Count number of correct objectness predictions at object cells
                correct_obj += torch.sum(
                    obj_pred[obj_mask] == target[..., 4][obj_mask])
        # Aggregation Result (epsilon avoids division by zero)
        acc_obj = (correct_obj / (tot_obj + 1e-6)) * 100
        acc_cls = (correct_class / (tot_obj + 1e-6)) * 100
        acc_noobj = (correct_noobj / (tot_noobj + 1e-6)) * 100
        accs = {
            'cls': acc_cls.item(),
            'obj': acc_obj.item(),
            'noobj': acc_noobj.item()
        }
        print(f"Epoch {self.current_epoch} [Accs]: {accs}")
        return accs

    def _check_map(self):
        """Evaluate mean average precision on the validation set.

        Converts per-cell predictions to normalized boxes, applies
        class-wise NMS sample by sample, then scores mAP at IoU 0.5 and
        0.75 over the whole validation set.

        Returns:
            float: the mAP value at IoU threshold 0.5.
        """
        sample_idx = 0
        all_pred_bboxes = []
        all_true_bboxes = []
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc="Check mAP")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            batch_size = imgs.size(0)
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Forward (inference only, mixed precision)
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    preds, prototypes = self.model(imgs)
            # Convert cells to bboxes
            # =================================================================
            true_bboxes = [[] for _ in range(batch_size)]
            pred_bboxes = [[] for _ in range(batch_size)]
            for scale_idx, (pred, target) in enumerate(zip(preds, targets)):
                scale = pred.size(2)
                anchors = self.scaled_anchors[scale_idx]  # (3, 2)
                anchors = anchors.reshape(1, 3, 1, 1, 2)  # (1, 3, 1, 1, 2)
                # Convert prediction to correct format.
                # NOTE(review): these assignments mutate `pred` (and thus the
                # model output list `preds`) in place — intentional here since
                # the outputs are not reused, but worth confirming.
                pred[..., 0:2] = torch.sigmoid(pred[...,
                                                    0:2])  # (N, 3, S, S, 2)
                pred[..., 2:4] = torch.exp(
                    pred[..., 2:4]) * anchors  # (N, 3, S, S, 2)
                pred[..., 4:5] = torch.sigmoid(pred[...,
                                                    4:5])  # (N, 3, S, S, 1)
                pred_cls_probs = F.softmax(
                    pred[..., 5:5 + self.config['model']['num_classes']],
                    dim=-1)  # (N, 3, S, S, C)
                _, indices = torch.max(pred_cls_probs, dim=-1)  # (N, 3, S, S)
                indices = indices.unsqueeze(-1)  # (N, 3, S, S, 1)
                pred = torch.cat([pred[..., :5], indices],
                                 dim=-1)  # (N, 3, S, S, 6)
                # Convert coordinate system to normalized format (xywh)
                pboxes = cells_to_boxes(cells=pred,
                                        scale=scale)  # (N, 3, S, S, 6)
                tboxes = cells_to_boxes(cells=target,
                                        scale=scale)  # (N, 3, S, S, 6)
                # Keep predicted boxes above the confidence threshold
                for idx, cell_boxes in enumerate(pboxes):
                    obj_mask = cell_boxes[
                        ..., 4] > self.config['valid']['conf_threshold']
                    boxes = cell_boxes[obj_mask]
                    pred_bboxes[idx] += boxes.tolist()
                # Keep ground-truth boxes (objectness ~ 1.0)
                for idx, cell_boxes in enumerate(tboxes):
                    obj_mask = cell_boxes[..., 4] > 0.99
                    boxes = cell_boxes[obj_mask]
                    true_bboxes[idx] += boxes.tolist()
            # Perform NMS batch-by-batch
            # =================================================================
            # NOTE(review): this loop variable shadows the outer `batch_idx`
            # from the enumerate above; harmless as the outer value is not
            # used afterwards, but renaming would be clearer.
            for batch_idx in range(batch_size):
                pbboxes = torch.tensor(pred_bboxes[batch_idx])
                tbboxes = torch.tensor(true_bboxes[batch_idx])
                # Perform NMS class-by-class
                for c in range(self.config['model']['num_classes']):
                    # Filter pred boxes of specific class
                    nms_pred_boxes = nms_by_class(
                        target=c,
                        bboxes=pbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    nms_true_boxes = nms_by_class(
                        target=c,
                        bboxes=tbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    # Prefix each box with a global sample index so the mAP
                    # computation can match preds to truths per image.
                    all_pred_bboxes.extend([[sample_idx] + box
                                            for box in nms_pred_boxes])
                    all_true_bboxes.extend([[sample_idx] + box
                                            for box in nms_true_boxes])
                sample_idx += 1
        # Compute mAP@0.5 & mAP@0.75 (comment text was mangled by scraping)
        # =================================================================
        # The format of the bboxes is (idx, x1, y1, x2, y2, conf, class)
        all_pred_bboxes = torch.tensor(all_pred_bboxes)  # (J, 7)
        all_true_bboxes = torch.tensor(all_true_bboxes)  # (K, 7)
        eval50 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.5,
            n_classes=self.config['dataset']['n_classes'])
        eval75 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.75,
            n_classes=self.config['dataset']['n_classes'])
        print((
            f"Epoch {self.current_epoch}:\n"
            f"\t-[[email protected]]={eval50['mAP']:.3f}, [Recall]={eval50['recall']:.3f}, [Precision]={eval50['precision']:.3f}\n"
            f"\t-[[email protected]]={eval75['mAP']:.3f}, [Recall]={eval75['recall']:.3f}, [Precision]={eval75['precision']:.3f}\n"
        ))
        return eval50['mAP']

    def _save_checkpoint(self):
        """Persist model, optimizer, scheduler and progress counters to
        '<logdir>/best.pth'."""
        destination = osp.join(self.logdir, 'best.pth')
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'current_map': self.current_map,
            'current_epoch': self.current_epoch
        }
        torch.save(state, destination)
        print("Save checkpoint at '{}'".format(destination))
# ---- Example #3 (예제 #3) ----
# (scraper artifact: vote count "0")
class Trainer():
    """Config-driven OCR training loop.

    Builds the model, vocab, optimizer, scheduler and criterion from a
    config dict, wires LMDB-backed datasets, and exposes training,
    validation, prediction, visualization and checkpointing utilities.
    """

    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
        # NOTE(review): the `pretrained` parameter is never used (the
        # config['trainer']['pretrained'] entry is consulted instead), and
        # the `augmentor` default is a single shared instance created at
        # import time (mutable default) — confirm both are intentional.

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        # TensorBoard writer
        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        # Finetuning uses a fixed learning rate (no scheduler); training
        # from scratch uses AdamW + OneCycleLR over the iteration budget.
        if self.is_finetuning:
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)

        else:

            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        # CTC loss for CRNN heads, label smoothing for seq2seq heads
        if self.model.seq_modeling == 'crnn':
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Move optimizer state tensors to the target device after reload
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        # NOTE(review): this raises AttributeError when `valid_annotation`
        # is falsy, since self.valid_gen is then never created — confirm
        # validation is always configured.
        self.logger.info("Number batch samples of valid: %d" %
                         len(self.valid_gen))

        # Snapshot the config next to the logs for reproducibility
        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        """Run the main optimization loop for `self.num_iters` iterations.

        Cycles the training loader indefinitely, logs running loss every
        `print_every` iterations, and validates / checkpoints every
        `valid_every` iterations. The best model (by full-sequence
        accuracy) is kept at '<tensorboard_dir>/best.pt'.
        """
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1
            start = time.time()

            # Restart the loader once exhausted (epoch boundary)
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start
            start = time.time()

            # LOSS
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))

            total_gpu_time += time.time() - start

            if self.iter % self.print_every == 0:

                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                lastest_loss = total_loss / self.print_every
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)

                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)

                # Keep the best model by full-sequence accuracy
                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq

                self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                    self.iter, self.best_acc))

                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)

                # NOTE(review): `lastest_loss` is only assigned inside the
                # print_every branch above — UnboundLocalError here unless
                # valid_every is a multiple of print_every. Confirm the
                # configs keep them aligned.
                log_loss = {'train loss': lastest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        """Return the mean loss over the validation loader.

        The model is put in eval mode for the pass and restored to train
        mode before returning.
        """
        self.model.eval()

        total_loss = []

        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                img, tgt_input, tgt_output, tgt_padding_mask = batch[
                    'img'], batch['tgt_input'], batch['tgt_output'], batch[
                        'tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)
                #                loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

                if self.model.seq_modeling == 'crnn':
                    # CTC loss needs per-sample prediction lengths.
                    # NOTE(review): assumes every batch has exactly
                    # self.batch_size samples (data_gen sets drop_last=True
                    # for crnn, which makes this hold) — verify.
                    length = batch['labels_len']
                    preds_size = torch.autograd.Variable(
                        torch.IntTensor([outputs.size(0)] * self.batch_size))
                    loss = self.criterion(outputs, tgt_output, preds_size,
                                          length)
                else:
                    # Flatten (batch, time, vocab) -> (batch*time, vocab)
                    outputs = outputs.flatten(0, 1)
                    tgt_output = tgt_output.flatten()
                    loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                del outputs
                del loss

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss

    def predict(self, sample=None):
        """Decode the validation set and return predictions.

        Args:
            sample: optional cap on the number of decoded samples; decoding
                stops once more than `sample` predictions are collected.

        Returns:
            tuple: (pred_sents, actual_sents, img_files, probs_sents,
            imgs_sents), parallel lists over decoded samples.
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    # NOTE(review): beam search leaves `prob` as None, so
                    # probs_sents.extend(prob) below would raise TypeError —
                    # confirm beamsearch is exercised with this path.
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)

            # Visualize in tensorboard (first batch only; failures are
            # non-fatal and just logged to stdout)
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img], probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img]
                                   == preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })

                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    print(error)
                    continue

            if sample != None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        """Compute accuracy metrics over (a sample of) the validation set.

        Returns:
            tuple: (acc_full_seq, acc_per_char, wer).
        """
        t1 = time.time()
        pred_sents, actual_sents, _, _, _ = self.predict(sample=sample)
        time_predict = time.time() - t1

        sensitive_case = self.config['predictor']['sensitive_case']
        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        sensitive_case,
                                        mode='per_char')
        wer = compute_accuracy(actual_sents,
                               pred_sents,
                               sensitive_case,
                               mode='wer')

        if measure_time:
            # Average prediction time per sample
            print("Time: {:.4f}".format(time_predict / len(actual_sents)))
        return acc_full_seq, acc_per_char, wer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):
        """Plot up to `sample` predictions in a grid; optionally only the
        mispredicted cases, and optionally save to 'vis_prediction.png'."""

        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)

        if errorcase:
            # Keep only indices where prediction differs from ground truth
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}
        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15))

        for vis_idx in range(0, len(img_files)):
            row = vis_idx // ncols
            col = vis_idx % ncols

            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()

            ax[row, col].imshow(img)
            ax[row, col].set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    pred_sent, actual_sent, prob),
                fontname=fontname,
                color='r' if pred_sent != actual_sent else 'g')
            ax[row, col].get_xaxis().set_ticks([])
            ax[row, col].get_yaxis().set_ticks([])

        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        """Predict `sample` validation items and dump them to a CSV file."""
        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)
        save_predictions(csv_file, pred_sents, actual_sents, img_files)

    def vis_data(self, sample=20):
        """Plot `sample` training examples with their labels to
        'vis_dataset.png' (5 columns per row)."""

        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12))

        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            for vis_idx in range(self.batch_size):
                row = num_plots // ncols
                col = num_plots % ncols

                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())

                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')

                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])

                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return

    def load_checkpoint(self, filename):
        """Restore optimizer, model, scheduler and progress counters from a
        checkpoint produced by `save_checkpoint`."""
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])

        self.best_acc = checkpoint['best_acc']

    def save_checkpoint(self, filename):
        """Write the full training state (model, optimizer, scheduler,
        counters) to `filename`, creating parent directories as needed."""
        state = {
            'iter':
            self.iter,
            'state_dict':
            self.model.state_dict(),
            'optimizer':
            self.optimizer.state_dict(),
            'train_losses':
            self.train_losses,
            'scheduler':
            None if self.scheduler is None else self.scheduler.state_dict(),
            'best_acc':
            self.best_acc
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights from either a full checkpoint or a bare
        state_dict, skipping parameters with mismatching shapes."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
        else:

            # Drop shape-incompatible entries so strict=False can load the rest
            for name, param in self.model.named_parameters():
                if name not in state_dict:
                    print('{} not found'.format(name))
                elif state_dict[name].shape != param.shape:
                    print('{} missmatching shape, required {} but found {}'.
                          format(name, param.shape, state_dict[name].shape))
                    del state_dict[name]
            self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model weights to `filename`, creating parent
        directories as needed."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        """Return True if `checkpoint` looks like a full checkpoint dict
        (has a 'state_dict' key) rather than a bare state_dict."""
        # NOTE(review): bare `except` also swallows KeyboardInterrupt etc.;
        # `except (TypeError, KeyError, IndexError)` would be safer.
        try:
            checkpoint['state_dict']
        except:
            return False
        else:
            return True

    def batch_to_device(self, batch):
        """Move the tensor fields of a collated batch to `self.device`;
        'filenames' and 'labels_len' are passed through unchanged."""
        img = batch['img'].to(self.device, non_blocking=True)
        tgt_input = batch['tgt_input'].to(self.device, non_blocking=True)
        tgt_output = batch['tgt_output'].to(self.device, non_blocking=True)
        tgt_padding_mask = batch['tgt_padding_mask'].to(self.device,
                                                        non_blocking=True)

        batch = {
            'img': img,
            'tgt_input': tgt_input,
            'tgt_output': tgt_output,
            'tgt_padding_mask': tgt_padding_mask,
            'filenames': batch['filenames'],
            'labels_len': batch['labels_len']
        }

        return batch

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        """Build a DataLoader over one or more OCR LMDB datasets.

        NOTE(review): the ConcatDataset branch keys off len(self.train_lmdb)
        even when building the *validation* loader, and DataLoader receives
        both `sampler` and `shuffle=is_train` — PyTorch raises when a
        sampler is combined with shuffle=True. Confirm the is_padding /
        is_train combinations used in practice avoid those paths.
        """
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        if len(self.train_lmdb) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)

        if self.is_padding:
            sampler = None
        else:
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)

        collate_fn = Collator(masked_language_model)

        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])

        return gen

    def step(self, batch):
        """Run one optimization step on `batch` and return the loss value.

        Applies gradient clipping at norm 1 and, when not finetuning,
        advances the OneCycleLR scheduler.
        """
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

        if self.model.seq_modeling == 'crnn':
            # CTC path: constant prediction length per sample
            length = batch['labels_len']
            preds_size = torch.autograd.Variable(
                torch.IntTensor([outputs.size(0)] * self.batch_size))
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            outputs = outputs.view(
                -1, outputs.size(2))  # flatten(0, 1)    # B*S x N_class
            tgt_output = tgt_output.view(-1)  # flatten()    # B*S
            loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        # Clip gradients to stabilize training
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()

        if not self.is_finetuning:
            self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def count_parameters(self, model):
        """Return the number of trainable parameters in `model`."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def gen_pseudo_labels(self, outfile=None):
        """Decode the validation set and write '||||'-separated
        (filename, prediction, probability) pseudo-label lines to `outfile`.
        """
        pred_sents = []
        img_files = []
        probs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    # NOTE(review): `prob = None` breaks probs_sents.extend
                    # below, same issue as in predict() — verify.
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)
        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
# ---- Example #4 (예제 #4) ----
# (scraper artifact: vote count "0")
class Learner:
    """Per-seed / per-fold training harness.

    Runs mixed-precision-aware training epochs, validates with plain BCE,
    checkpoints the best model by validation loss and logs to both a file
    logger and TensorBoard.
    """

    def __init__(self, model, train_loader, valid_loader, fold, config, seed):
        """Set up losses, optimizer, scheduler and logging for one fold.

        :param model: nn.Module; moved to ``config.device`` here
        :param train_loader: DataLoader yielding
            (g_x, c_x, cate_x, labels, non_labels)
        :param valid_loader: validation DataLoader with the same item layout
        :param fold: CV fold index; fold 0 additionally dumps the config
        :param config: namespace providing device, lr, weight_decay,
            n_epochs, smoothing, fp16, log_dir, name
        :param seed: RNG seed, used here only for log/checkpoint file names
        """
        self.config = config
        self.seed = seed
        self.device = self.config.device
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model.to(self.device)

        self.fold = fold
        self.logger = init_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}.log')
        self.tb_logger = init_tb_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}')
        if self.fold == 0:
            self.log('\n'.join(
                [f"{k} = {v}" for k, v in self.config.__dict__.items()]))

        # Smoothed BCE for training; unsmoothed BCE as the validation metric.
        self.criterion = SmoothBCEwLogits(smoothing=self.config.smoothing)
        self.evaluator = nn.BCEWithLogitsLoss()
        self.summary_loss = AverageMeter()
        self.history = {'train': [], 'valid': []}

        self.optimizer = Adam(self.model.parameters(),
                              lr=config.lr,
                              weight_decay=self.config.weight_decay)
        # NOTE(review): max_lr is hard-coded to 1e-2 and ignores config.lr —
        # confirm this is intentional.
        self.scheduler = OneCycleLR(optimizer=self.optimizer,
                                    pct_start=0.1,
                                    div_factor=1e3,
                                    max_lr=1e-2,
                                    epochs=config.n_epochs,
                                    steps_per_epoch=len(train_loader))
        self.scaler = GradScaler() if config.fp16 else None

        self.epoch = 0
        self.best_epoch = 0
        self.best_loss = np.inf

    def train_one_epoch(self):
        """Train for one epoch; return the batch-size-weighted mean loss."""
        self.model.train()
        self.summary_loss.reset()
        iters = len(self.train_loader)
        for step, (g_x, c_x, cate_x, labels,
                   non_labels) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            # self.tb_logger.add_scalar('Train/lr', self.optimizer.param_groups[0]['lr'],
            #                           iters * self.epoch + step)
            labels = labels.to(self.device)
            non_labels = non_labels.to(self.device)
            g_x = g_x.to(self.device)
            c_x = c_x.to(self.device)
            cate_x = cate_x.to(self.device)
            batch_size = labels.shape[0]

            # ExitStack lets the forward pass run under autocast only when
            # fp16 is enabled, without duplicating the forward code.
            with ExitStack() as stack:
                if self.config.fp16:
                    auto = stack.enter_context(autocast())
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.criterion(outputs, labels)

            if self.config.fp16:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

            self.summary_loss.update(loss.item(), batch_size)
            # OneCycleLR is stepped once per batch, as it expects.
            if self.scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                self.scheduler.step()

        self.history['train'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def validation(self):
        """Evaluate on the validation set; return the mean BCE loss."""
        self.model.eval()
        self.summary_loss.reset()
        iters = len(self.valid_loader)
        for step, (g_x, c_x, cate_x, labels,
                   non_labels) in enumerate(self.valid_loader):
            with torch.no_grad():
                labels = labels.to(self.device)
                g_x = g_x.to(self.device)
                c_x = c_x.to(self.device)
                cate_x = cate_x.to(self.device)
                batch_size = labels.shape[0]
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.evaluator(outputs, labels)

                self.summary_loss.update(loss.detach().item(), batch_size)

        self.history['valid'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def fit(self, epochs):
        """Run *epochs* train/validate cycles; return the loss history dict."""
        self.log(f'Start training....')
        for e in range(epochs):
            t = time.time()
            loss = self.train_one_epoch()

            # self.log(f'[Train] \t Epoch: {self.epoch}, loss: {loss:.6f}, time: {(time.time() - t):.2f}')
            self.tb_logger.add_scalar('Train/Loss', loss, self.epoch)

            t = time.time()
            loss = self.validation()

            # self.log(f'[Valid] \t Epoch: {self.epoch}, loss: {loss:.6f}, time: {(time.time() - t):.2f}')
            self.tb_logger.add_scalar('Valid/Loss', loss, self.epoch)
            self.post_processing(loss)

            self.epoch += 1
        self.log(f'best epoch: {self.best_epoch}, best loss: {self.best_loss}')
        return self.history

    def post_processing(self, loss):
        """Checkpoint model/optimizer/scheduler when *loss* improves."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = self.epoch

            self.model.eval()
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_loss': self.best_loss,
                    'epoch': self.epoch,
                },
                f'{os.path.join(self.config.log_dir, f"{self.config.name}_seed{self.seed}_fold{self.fold}.pth")}'
            )
            self.log(f'best model: {self.epoch} epoch - loss: {loss:.6f}')

    def load(self, path):
        """Restore model/optimizer/scheduler state from a checkpoint at *path*."""
        # map_location keeps CPU-only loading possible for GPU-saved files.
        checkpoint = torch.load(path,
                                map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']
        self.epoch = checkpoint['epoch'] + 1

    def log(self, text):
        """Write *text* to the file logger at INFO level."""
        self.logger.info(text)
예제 #5
0
class Trainer():
    """Iteration-based trainer for a Seq2Seq spelling/auto-correct model.

    Builds synthetic noisy/clean n-gram pairs, trains with label smoothing
    and OneCycleLR, periodically validates, and saves both best weights and
    a resumable checkpoint.  Most hyper-parameters come from module-level
    constants (NUM_ITERS, BATCH_SIZE, MAX_LR, ...).
    """

    def __init__(self, alphabets_, list_ngram):
        """Build vocab, data generators, model, optimizer and scheduler.

        :param alphabets_: alphabet passed to ``Vocab``
        :param list_ngram: full list of n-grams; split 90/10 train/valid
        """

        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        INPUT_DIM = self.vocab.__len__()
        OUTPUT_DIM = self.vocab.__len__()

        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH

        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER

        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG

        # NOTE(review): self.logger only exists when LOG is truthy, but
        # train() calls self.logger.log() unconditionally — confirm LOG is
        # always set, otherwise train() raises AttributeError.
        if logger:
            self.logger = Logger(logger)

        self.iter = 0

        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)

        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        """Split *list_phrases* sequentially (no shuffle) into train/valid."""
        list_phrases = list_phrases
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        """Wrap the n-gram list in an AutoCorrectDataset DataLoader.

        Training loaders shuffle; validation loaders keep order.
        """
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)

        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)

        return gen

    def step(self, batch):
        """Run one optimization step on *batch*; return the loss value."""
        self.model.train()

        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch

        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab

        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)

        tgt_output = tgt.transpose(0, 1).reshape(
            -1)  # flatten()   # tgt: tgt_len xB , need convert to B x tgt_len

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        # Clip gradient norm to 1 for RNN stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def train(self):
        """Main loop: run num_iters steps, logging, validating and saving.

        Best-accuracy weights go to self.export_weights; a resumable
        checkpoint is written every validation interval.
        """
        print("Begin training from iter: ", self.iter)
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1

        # Re-create the iterator whenever the DataLoader is exhausted so the
        # loop can run for an arbitrary number of iterations.
        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)

                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        """Compute validation loss (teacher forcing off) plus sample decodes.

        Returns (mean_loss, preds[:3], actuals[:3], inp_sents[:3]).
        NOTE(review): preds/actuals/inp_sents are only bound inside the loop,
        so an empty valid_gen would raise NameError; also self.predict(5) is
        re-run on every batch, which is expensive — confirm intended.
        """
        self.model.eval()

        total_loss = []
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

                outputs = self.model(src, tgt, 0)  # turn off teaching force

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                preds, actuals, inp_sents, probs = self.predict(5)

                del outputs
                del loss
                if step > max_step:
                    break

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        """Decode validation batches until *sample* predictions collected.

        Returns (pred_sents, actual_sents, inp_sents, prob).
        NOTE(review): only the LAST batch's prob is returned (probabilities
        are not accumulated per sentence), and with beam search prob is
        None — callers indexing probs per sentence will misbehave.
        """
        pred_sents = []
        actual_sents = []
        inp_sents = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):
        """Return (full-sequence accuracy, per-char accuracy, CER) over
        up to *sample* validation predictions."""

        pred_sents, actual_sents, _, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')

        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Collect (optionally error-only) predictions for visualization.

        NOTE(review): this builds fontdict but never plots — the rendering
        code appears to have been removed; confirm whether this method is
        still needed.
        """

        pred_sents, actual_sents, img_files, probs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Walk up to *sample* training items.

        NOTE(review): batch['img'] and batch['tgt_input'] are not keys
        produced by batch_to_device ('src'/'tgt') — this looks copied from
        an OCR trainer and would raise KeyError here; verify before use.
        """
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        """Restore optimizer/scheduler/model/iter/losses from *filename*."""
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Write a resumable checkpoint to *filename*, creating parent dirs."""
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights, skipping missing or shape-mismatched tensors."""
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict to *filename*, creating dirs."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move the 'src'/'tgt' tensors of *batch* to the training device."""

        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)

        batch = {'src': src, 'tgt': tgt}

        return batch
예제 #6
0
def main(cfg):
    """Train an IQA model on AVA with EMD loss and periodic evaluation.

    Sets up workdir/logging, builds memmap-backed datasets, then alternates
    train_steps/evaluate every ``eval_interval`` steps, checkpointing each
    round and tracking the best validation loss.

    :param cfg: config object with workdir, model, ava paths, lr schedule
        and loader settings
    """
    workdir = Path(cfg.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    set_logger(workdir / 'log.txt')
    cfg.dump_to_file(workdir / 'config.yml')
    saver = Saver(workdir, keep_num=10)
    logging.info(f'config: \n{cfg}')
    logging.info(f'use device: {device}')

    model = iqa.__dict__[cfg.model.name](**cfg.model.kwargs)
    model = model.to(device)

    # Wrap in DataParallel only when multiple GPUs are present; 'model'
    # keeps the raw module so state_dict keys stay unprefixed.
    if torch.cuda.device_count() > 1:
        model_dp = nn.DataParallel(model)
    else:
        model_dp = model

    train_transform = Transform(
        transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]))

    val_transform = Transform(
        transforms.Compose([transforms.RandomCrop(224),
                            transforms.ToTensor()]))

    if not Path(cfg.ava.train_cache).exists():
        create_memmap(cfg.ava.train_labels, cfg.ava.images,
                      cfg.ava.train_cache, cfg.num_workers)
    if not Path(cfg.ava.val_cache).exists():
        # NOTE(review): the val cache is built from cfg.ava.train_labels —
        # looks like a copy-paste of the train branch; confirm whether
        # cfg.ava.val_labels exists and should be used here.
        create_memmap(cfg.ava.train_labels, cfg.ava.images, cfg.ava.val_cache,
                      cfg.num_workers)

    trainset = MemMap(cfg.ava.train_cache, train_transform)
    valset = MemMap(cfg.ava.val_cache, val_transform)

    total_steps = len(trainset) // cfg.batch_size * cfg.num_epochs
    eval_interval = len(trainset) // cfg.batch_size
    logging.info(f'total steps: {total_steps}, eval interval: {eval_interval}')
    model_dp.train()
    parameters = group_parameters(model)
    optimizer = SGD(parameters,
                    cfg.lr,
                    cfg.momentum,
                    weight_decay=cfg.weight_decay)

    lr_scheduler = OneCycleLR(optimizer,
                              max_lr=cfg.lr,
                              div_factor=cfg.lr / cfg.warmup_lr,
                              total_steps=total_steps,
                              pct_start=0.01,
                              final_div_factor=cfg.warmup_lr / cfg.final_lr)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=cfg.batch_size,
                                               shuffle=True,
                                               num_workers=cfg.num_workers,
                                               drop_last=True,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers,
                                             pin_memory=True)

    curr_loss = 1e9
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'step': 0,  # init step,
        'cfg': cfg,
        'loss': curr_loss
    }

    # Save an initial checkpoint so step-0 state is always recoverable.
    saver.save(0, state)

    trainloader = repeat_loader(train_loader)
    batch_processor = BatchProcessor(device)
    start = time.time()
    for step in range(0, total_steps, eval_interval):
        # num_steps may be smaller than eval_interval on the final round.
        num_steps = min(step + eval_interval, total_steps) - step
        step += num_steps
        trainmeter = train_steps(model_dp, trainloader, optimizer,
                                 lr_scheduler, emd_loss, batch_processor,
                                 num_steps)
        valmeter = evaluate(model_dp, val_loader, emd_loss, batch_processor)
        finish = time.time()
        img_s = cfg.batch_size * eval_interval / (finish - start)
        loss = valmeter.meters['loss'].global_avg

        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'step': step,  # init step,
            'cfg': cfg,
            'loss': loss
        }
        saver.save(step, state)

        if loss < curr_loss:
            curr_loss = loss
            saver.save_best(state)

        logging.info(
            f'step: [{step}/{total_steps}] img_s: {img_s:.2f} train: [{trainmeter}] eval:[{valmeter}]'
        )
        start = time.time()
예제 #7
0
파일: train.py 프로젝트: jimmysue/xvision
def main(args):
    """Train a WiderFace detector with periodic evaluation and checkpoints.

    Builds train/val loaders, a prior-wrapped detector, SGD + OneCycleLR,
    then runs a step-based loop: every ``args.eval_interval`` steps it
    evaluates, logs throughput, checkpoints, and tracks the best eval loss.

    :param args: config with workdir, dataset paths, dsize/augments, model
        name, anchor/prior settings, optimizer and loop hyper-parameters
    """
    # prepare workspace

    workdir = Path(args.workdir)
    workdir.mkdir(parents=True, exist_ok=True)
    logger = get_logger(workdir / 'log.txt')
    logger.info(f'config: \n{args}')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if args.device:
        logger.info(f'user specify device: {args.device}')
        device = torch.device(args.device)
    logger.info(f'use device: {device}')

    # dump all configues to later use, such as for testing
    with open(workdir / 'config.yml', 'wt') as f:
        args.dump(stream=f)

    saver = Saver(workdir, keep_num=10)

    # prepare dataset
    valtransform = ValTransform(dsize=args.dsize)
    traintransform = TrainTransform(dsize=args.dsize, **args.augments)
    trainset = WiderFace(args.train_label,
                         args.train_image,
                         min_face=1,
                         with_shapes=True,
                         transform=traintransform)
    valset = WiderFace(args.val_label,
                       args.val_image,
                       transform=valtransform,
                       min_face=1)

    trainloader = DataLoader(trainset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True,
                             collate_fn=wider_collate,
                             drop_last=True)
    valloader = DataLoader(valset,
                           batch_size=args.batch_size,
                           shuffle=False,
                           num_workers=args.num_workers,
                           pin_memory=True,
                           collate_fn=wider_collate)

    # model
    model = models.__dict__[args.model.name](phase='train').to(device)
    prior = BBoxShapePrior(args.num_classes, 5, args.anchors,
                           args.iou_threshold, args.encode_mean,
                           args.encode_std)

    model = Detector(prior, model)

    # optimizer and lr scheduler
    parameters = group_parameters(model, bias_decay=0)
    optimizer = SGD(parameters,
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    lr_scheduler = OneCycleLR(optimizer,
                              max_lr=args.lr,
                              div_factor=20,
                              total_steps=args.total_steps,
                              pct_start=0.1,
                              final_div_factor=100)
    trainloader = repeat_loader(trainloader)

    model.to(device)
    model.train()

    best_loss = 1e9
    state = {
        'model': model.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': 0,
        'loss': best_loss
    }
    saver.save(0, state)

    def reset_meter():
        # Fresh logger per eval window; lr meter keeps only the latest value.
        meter = MetricLogger()
        meter.add_meter('lr', SmoothedValue(1, fmt='{value:.5f}'))
        return meter

    train_meter = reset_meter()
    start = time.time()
    for step in range(args.start_step, args.total_steps):
        batch = next(trainloader)
        batch = batch_to(batch, device)
        image = batch['image']
        box = batch['bbox']
        point = batch['shape']
        mask = batch['mask']
        label = batch['label']

        score_loss, box_loss, point_loss = model(image,
                                                 targets=(label, box, point,
                                                          mask))
        # Box loss is weighted 2x relative to score and landmark losses.
        loss = score_loss + 2.0 * box_loss + point_loss

        train_meter.meters['score'].update(score_loss.item())
        train_meter.meters['box'].update(box_loss.item())
        train_meter.meters['shape'].update(point_loss.item())
        train_meter.meters['total'].update(loss.item())
        train_meter.meters['lr'].update(optimizer.param_groups[0]['lr'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        if (step + 1) % args.eval_interval == 0:
            duration = time.time() - start
            img_s = args.eval_interval * args.batch_size / duration
            eval_meter = evaluate(model, valloader, prior, device)

            logger.info(
                f'Step [{step + 1}/{args.total_steps}] img/s: {img_s:.2f} train: [{train_meter}] eval: [{eval_meter}]'
            )
            train_meter = reset_meter()
            start = time.time()
            curr_loss = eval_meter.meters['total'].global_avg
            # BUG FIX: the loss value was stored under 'step' and the real
            # step count (and 'loss' key) were dropped, unlike the initial
            # checkpoint above — resuming from these files was broken.
            state = {
                'model': model.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step + 1,
                'loss': curr_loss,
            }
            saver.save(step + 1, state)

            if (curr_loss < best_loss):
                best_loss = curr_loss
                saver.save_best(state)
예제 #8
0
def main():
    """Fine-tune an MLP head on top of a frozen pretrained backbone.

    Builds a pretrained feature extractor (final classifier stripped), stacks a
    small MLP on top, prepares the caltech101-style train/val split on disk if
    needed, then trains the MLP with a OneCycleLR schedule, checkpointing each
    epoch and appending per-epoch metrics to a CSV file.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    _logger.info('====================\n\n'
                 'Actfun: {}\n'
                 'LR: {}\n'
                 'Epochs: {}\n'
                 'p: {}\n'
                 'k: {}\n'
                 'g: {}\n'
                 'Extra channel multiplier: {}\n'
                 'Weight Init: {}\n'
                 '\n===================='.format(args.actfun, args.lr,
                                                 args.epochs, args.p, args.k,
                                                 args.g,
                                                 args.extra_channel_mult,
                                                 args.weight_init))

    # ================================================================================= Loading models
    pre_model = create_model(
        args.model,
        pretrained=True,
        actfun='swish',
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        p=args.p,
        k=args.k,
        g=args.g,
        extra_channel_mult=args.extra_channel_mult,
        weight_init_name=args.weight_init,
        partial_ho_actfun=args.partial_ho_actfun)
    # Drop the backbone's final classifier layer so it emits features only.
    pre_model_layers = list(pre_model.children())
    pre_model = torch.nn.Sequential(*pre_model_layers[:-1])
    pre_model.to(device)

    # Trainable head: input_dim=1280 matches the backbone's feature width
    # (EfficientNet-style) -- TODO confirm against the chosen args.model.
    model = MLP.MLP(actfun=args.actfun,
                    input_dim=1280,
                    output_dim=args.num_classes,
                    k=args.k,
                    p=args.p,
                    g=args.g,
                    num_params=1_000_000,
                    permute_type='shuffle')
    model.to(device)

    # ================================================================================= Loading dataset
    util.seed_all(args.seed)
    # One-time on-disk preparation: split 101_ObjectCategories into an
    # 80/10/10 train/val/test directory tree expected by the loaders below.
    if args.data == 'caltech101' and not os.path.exists('caltech101'):
        dir_root = r'101_ObjectCategories'
        dir_new = r'caltech101'
        dir_new_train = os.path.join(dir_new, 'train')
        dir_new_val = os.path.join(dir_new, 'val')
        dir_new_test = os.path.join(dir_new, 'test')
        if not os.path.exists(dir_new):
            os.mkdir(dir_new)
            os.mkdir(dir_new_train)
            os.mkdir(dir_new_val)
            os.mkdir(dir_new_test)

        for dir2 in os.listdir(dir_root):
            if dir2 != 'BACKGROUND_Google':
                curr_path = os.path.join(dir_root, dir2)
                new_path_train = os.path.join(dir_new_train, dir2)
                new_path_val = os.path.join(dir_new_val, dir2)
                new_path_test = os.path.join(dir_new_test, dir2)
                if not os.path.exists(new_path_train):
                    os.mkdir(new_path_train)
                if not os.path.exists(new_path_val):
                    os.mkdir(new_path_val)
                if not os.path.exists(new_path_test):
                    os.mkdir(new_path_test)

                train_upper = int(0.8 * len(os.listdir(curr_path)))
                val_upper = int(0.9 * len(os.listdir(curr_path)))
                curr_files_all = os.listdir(curr_path)
                curr_files_train = curr_files_all[:train_upper]
                curr_files_val = curr_files_all[train_upper:val_upper]
                curr_files_test = curr_files_all[val_upper:]

                for file in curr_files_train:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_train, file))
                for file in curr_files_val:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_val, file))
                for file in curr_files_test:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_test, file))
    # Give the filesystem a moment to settle after the bulk copy above.
    time.sleep(5)

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ================================================================================= Optimizer / scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=int(math.floor(len(dataset_train) / args.batch_size)),
        cycle_momentum=False)

    # ================================================================================= Save file / checkpoints
    fieldnames = [
        'dataset', 'seed', 'epoch', 'time', 'actfun', 'model', 'batch_size',
        'alpha_primes', 'alphas', 'num_params', 'k', 'p', 'g', 'perm_method',
        'gen_gap', 'epoch_train_loss', 'epoch_train_acc',
        'epoch_aug_train_loss', 'epoch_aug_train_acc', 'epoch_val_loss',
        'epoch_val_acc', 'curr_lr', 'found_lr', 'epochs'
    ]
    filename = 'out_{}_{}_{}_{}'.format(datetime.date.today(), args.actfun,
                                        args.data, args.seed)
    outfile_path = os.path.join(args.output, filename) + '.csv'
    checkpoint_path = os.path.join(args.check_path, filename) + '.pth'
    if not os.path.exists(outfile_path):
        with open(outfile_path, mode='w') as out_file:
            writer = csv.DictWriter(out_file,
                                    fieldnames=fieldnames,
                                    lineterminator='\n')
            writer.writeheader()

    # Resume from a previous checkpoint if one exists for this run.
    epoch = 1
    checkpoint = torch.load(checkpoint_path) if os.path.exists(
        checkpoint_path) else None
    if checkpoint is not None:
        pre_model.load_state_dict(checkpoint['pre_model_state_dict'])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        pre_model.to(device)
        model.to(device)
        print("*** LOADED CHECKPOINT ***"
              "\n{}"
              "\nSeed: {}"
              "\nEpoch: {}"
              "\nActfun: {}"
              "\np: {}"
              "\nk: {}"
              "\ng: {}"
              "\nperm_method: {}".format(checkpoint_path,
                                         checkpoint['curr_seed'],
                                         checkpoint['epoch'],
                                         checkpoint['actfun'], checkpoint['p'],
                                         checkpoint['k'], checkpoint['g'],
                                         checkpoint['perm_method']))

    args.mix_pre_apex = False
    if args.control_amp == 'apex':
        args.mix_pre_apex = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # ================================================================================= Training
    while epoch <= args.epochs:

        # Checkpoint at the start of every epoch so a crash resumes cleanly.
        if args.check_path != '':
            torch.save(
                {
                    'pre_model_state_dict': pre_model.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_seed': args.seed,
                    'epoch': epoch,
                    'actfun': args.actfun,
                    'p': args.p,
                    'k': args.k,
                    'g': args.g,
                    'perm_method': 'shuffle'
                }, checkpoint_path)

        util.seed_all((args.seed * args.epochs) + epoch)
        start_time = time.time()
        args.mix_pre = False
        if args.control_amp == 'native':
            args.mix_pre = True
            scaler = torch.cuda.amp.GradScaler()

        # ---- Training
        model.train()
        total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
        for batch_idx, (x, targetx) in enumerate(loader_train):
            x, targetx = x.to(device), targetx.to(device)
            optimizer.zero_grad()
            if args.mix_pre:
                # Native AMP path.
                with torch.cuda.amp.autocast():
                    with torch.no_grad():
                        x = pre_model(x)  # frozen backbone: features only
                    output = model(x)
                    train_loss = criterion(output, targetx)
                # .item() detaches the loss; accumulating the live tensor
                # would retain the autograd graph for the whole epoch.
                total_train_loss += train_loss.item()
                n += 1
                scaler.scale(train_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            elif args.mix_pre_apex:
                # Apex AMP path.
                with torch.no_grad():
                    x = pre_model(x)
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss.item()
                n += 1
                with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                # Full-precision path.
                with torch.no_grad():
                    x = pre_model(x)
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss.item()
                n += 1
                train_loss.backward()
                optimizer.step()
            scheduler.step()  # OneCycleLR steps per batch, not per epoch
            _, prediction = torch.max(output.data, 1)
            num_correct += torch.sum(prediction == targetx.data)
            num_total += len(prediction)
        epoch_aug_train_loss = total_train_loss / n
        epoch_aug_train_acc = num_correct * 1.0 / num_total

        # Record combinact mixing weights, if that activation is in use.
        alpha_primes = []
        alphas = []
        if model.actfun == 'combinact':
            for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                curr_alpha_primes = curr_alpha_primes.tolist()
                alpha_primes.append(curr_alpha_primes)
                alphas.append(curr_alphas)

        # ---- Validation
        model.eval()
        with torch.no_grad():
            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loader_eval):
                y, targety = y.to(device), targety.to(device)
                with torch.no_grad():
                    y = pre_model(y)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss.item()
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_val_loss = total_val_loss / n
            epoch_val_acc = num_correct * 1.0 / num_total
        lr_curr = 0
        for param_group in optimizer.param_groups:
            lr_curr = param_group['lr']
        print(
            "    Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f} ||| "
            "aug_train_loss {:1.4f} | val_loss {:1.4f} ||| time = {:1.4f}".
            format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc,
                   epoch_aug_train_loss, epoch_val_loss,
                   (time.time() - start_time)),
            flush=True)

        # ---- Final-epoch re-evaluation over both loaders (no gradient).
        epoch_train_loss = 0
        epoch_train_acc = 0
        if epoch == args.epochs:
            with torch.no_grad():
                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loader_train):
                    x, targetx = x.to(device), targetx.to(device)
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss.item()
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                epoch_aug_train_loss = total_train_loss / n
                epoch_aug_train_acc = num_correct * 1.0 / num_total

                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loader_eval):
                    x, targetx = x.to(device), targetx.to(device)
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss.item()
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                # BUG FIX: was `total_val_loss / n`, which reused the stale
                # validation accumulator instead of the loop's own total.
                epoch_train_loss = total_train_loss / n
                epoch_train_acc = num_correct * 1.0 / num_total

        # Outputting data to CSV at end of epoch
        with open(outfile_path, mode='a') as out_file:
            writer = csv.DictWriter(out_file,
                                    fieldnames=fieldnames,
                                    lineterminator='\n')
            writer.writerow({
                'dataset': args.data,
                'seed': args.seed,
                'epoch': epoch,
                'time': (time.time() - start_time),
                'actfun': model.actfun,
                'model': args.model,
                'batch_size': args.batch_size,
                'alpha_primes': alpha_primes,
                'alphas': alphas,
                'num_params': util.get_model_params(model),
                'k': args.k,
                'p': args.p,
                'g': args.g,
                'perm_method': 'shuffle',
                'gen_gap': float(epoch_val_loss - epoch_train_loss),
                'epoch_train_loss': float(epoch_train_loss),
                'epoch_train_acc': float(epoch_train_acc),
                'epoch_aug_train_loss': float(epoch_aug_train_loss),
                'epoch_aug_train_acc': float(epoch_aug_train_acc),
                'epoch_val_loss': float(epoch_val_loss),
                'epoch_val_acc': float(epoch_val_acc),
                'curr_lr': lr_curr,
                'found_lr': args.lr,
                'epochs': args.epochs
            })

        epoch += 1
def main():
    """Train a multi-head SpenceNet keypoint/classification model.

    Parses CLI hyperparameters, sets up WandB logging, builds per-letter
    datasets with custom batch sampling, and runs the train/test loop with
    AdamW + OneCycleLR, checkpointing the best model by test loss.
    """
    # Training settings and hyperparameters
    parser = argparse.ArgumentParser(description='SpenceNet Pytorch Training')
    parser.add_argument('--encoder', default='XResNet34', type=str,
                        choices=['XResNet18', 'XResNet34', 'XResNet50'],
                        help='encoder architecture (default: XResNet34)')
    parser.add_argument('--num_workers', default=2, type=int,
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=30, type=int,
                        help='number of total training epochs')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--use_grayscale', default=True,
                        help='turn input images to grayscale (default: True)')
    parser.add_argument('--img_size', type=int, default=300,
                        help='target image size for training (default: 300)')
    parser.add_argument('--max_lr', type=float, default=0.001,
                        help='maximum learning rate (default: 0.001)')
    parser.add_argument('--encoder_lr_mult', type=float, default=0.25,
                        help='encoder_lr = max_lr * this value (0.25 default)')
    parser.add_argument('--weight_decay', type=float, default=0.001,
                        help='weight decay (default: 0.001)')
    parser.add_argument('--sched_pct_start', type=float, default=0.3,
                        help='OneCycleLR pct_start parameter (default: 0.3)')
    parser.add_argument('--sched_div_factor', type=float, default=10.0,
                        help='OneCycleLR div factor (default: 10.0)')
    parser.add_argument('--wing_loss_e', type=float, default=2.0,
                        help='Wing Loss e parameter (default: 2.0)')
    parser.add_argument('--wing_loss_w', type=float, default=10.0,
                        help='Wing Loss w parameter (default: 10.0)')
    parser.add_argument('--use_cuda', default=True,
                        help='Enables CUDA training (default: True)')
    parser.add_argument('--seed', type=int, default=None,
                        help='fix random seed for training (default: None)')
    parser.add_argument('--wandb_project', default='multi-head-spencenet',
                        type=str, help='WandB project name')
    parser.add_argument('--save_dir', default='saved/', type=str,
                        help='directory to save outputs in (default: saved/)')
    parser.add_argument('--resume', default='', type=str,
                        help='path to checkpoint to optionally resume from')

    config = parser.parse_args()
    wandb_config = vars(config)  # WandB expects dictionary

    # Get timestamp
    today = datetime.now(tz=utc)
    today = today.astimezone(timezone('US/Pacific'))
    timestamp = today.strftime("%b_%d_%Y_%H_%M")

    wandb.init(config=wandb_config,
               project=config.wandb_project,
               dir=config.save_dir,
               name=timestamp,
               id=timestamp)

    use_cuda = config.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': config.num_workers,
              'pin_memory': True} if use_cuda else {}

    # Fix random seeds and deterministic pytorch for reproducibility
    # BUG FIX: `config` is an argparse.Namespace, not a dict -- subscripting
    # it (config['seed']) raised TypeError whenever --seed was supplied.
    # Also compare against None explicitly so seed 0 is honored.
    if config.seed is not None:
        torch.manual_seed(config.seed)  # pytorch random seed
        np.random.seed(config.seed)  # numpy random seed
        torch.backends.cudnn.deterministic = True

    # DATASET LOADING
    # Letters Dictionary >>> Class ID: [Letter Name, # of Coordinate Values]
    letter_dict = {0: ['alpha', 20], 1: ['beta', 28], 2: ['gamma', 16]}
    letter_ordered_dict = OrderedDict(sorted(letter_dict.items()))

    # Define the tranformations
    train_transforms = transforms.Compose([
                                lt.RandomCrop(10),
                                lt.RandomRotate(10),
                                lt.RandomLightJitter(0.2),
                                lt.RandomPerspective(0.5),
                                lt.Resize(config.img_size),
                                lt.ToNormalizedTensor()
                                ])
    test_transforms = transforms.Compose([
                                    lt.Resize(config.img_size),
                                    lt.ToNormalizedTensor()
                                    ])

    # Add grayscale transform (prepended so it runs before the augmentations)
    if config.use_grayscale:
        train_transforms.transforms.insert(0, lt.ToGrayscale())
        test_transforms.transforms.insert(0, lt.ToGrayscale())

    # Define separate datasets for each annotated class (skip letters with
    # zero coordinate values -- they have no keypoint annotations).
    letters = [key for key, val in letter_dict.items() if val[1] != 0]
    train_ds_list = []
    test_ds_list = []

    for letter in letters:
        train_ds_list.append(LetterDataset(f'./data/{letter_dict[letter][0]}_small_data.csv',
                                           num_coordinates=letter_dict[letter][1],
                                           transform=train_transforms))
        test_ds_list.append(LetterDataset(f'./data/{letter_dict[letter][0]}_small_data.csv',
                                          is_validation=True,
                                          num_coordinates=letter_dict[letter][1],
                                          transform=test_transforms))

    # Concatenated Datasets
    train_datasets = ConcatDataset(train_ds_list)
    test_datasets = ConcatDataset(test_ds_list)

    # Define Dataloaders with custom LetterBatchSampler
    train_loader = DataLoader(dataset=train_datasets,
                              sampler=LetterBatchSampler(
                                          dataset=train_datasets,
                                          batch_size=config.batch_size,
                                          drop_last=True),
                              batch_size=config.batch_size,
                              **kwargs)

    test_loader = DataLoader(dataset=test_datasets,
                             sampler=LetterBatchSampler(
                                        dataset=test_datasets,
                                        batch_size=config.batch_size,
                                        drop_last=True),
                             batch_size=config.batch_size,
                             **kwargs)

    # INITIALIZE MODEL
    model = SpenceNet(letter_ordered_dict,
                      backbone=config.encoder,
                      c_in=1 if config.use_grayscale else 3,
                      img_size=config.img_size).to(device)

    # Encoder gets a reduced LR; heads use the full max_lr.
    optimizer = optim.AdamW([
                            {'params': model.encoder.parameters(),
                             'lr': config.max_lr*config.encoder_lr_mult},
                            {'params': model.classification_head.parameters()},
                            {'params': model.keypoint_heads.parameters()}
                            ],
                            lr=config.max_lr, betas=(0.9, 0.99),
                            weight_decay=config.weight_decay)

    # Initialize Loss Function
    criterion = MultiLoss(e=config.wing_loss_e, w=config.wing_loss_w)

    # LR Scheduler
    scheduler = OneCycleLR(optimizer,
                           max_lr=config.max_lr,
                           pct_start=config.sched_pct_start,
                           div_factor=config.sched_div_factor,
                           steps_per_epoch=len(train_loader),
                           epochs=config.epochs)

    # Optionally resume from saved checkpoint
    if config.resume:
        model, optimizer, scheduler, curr_epoch, ckp_loss = load_checkpoint(config.resume, model, optimizer, scheduler)
        start_epoch = curr_epoch
        best_loss = ckp_loss
        print(f'Resuming from checkpoint... Epoch: {start_epoch} Loss: {best_loss:.4f}')
    else:
        start_epoch = 0
        best_loss = math.inf

    # Track all gradients/parameters with WandB
    wandb.watch(model, log='all')

    # Training start time
    training_start = time.time()

    for epoch in range(start_epoch, config.epochs):
        train_metrics = train(config, model, device, train_loader, optimizer, scheduler, criterion)
        test_metrics = test(config, model, device, test_loader, criterion, len(letters))

        # Log training data and metrics
        # TODO: in test, randomly return 4 img per class based on len(letters)
        log_metrics(timestamp,
                    training_start,
                    epoch,
                    config.epochs,
                    train_metrics,
                    test_metrics)

        # Checkpoint saving
        is_best = test_metrics['test_multi_loss'] < best_loss
        best_loss = min(test_metrics['test_multi_loss'], best_loss)
        save_checkpoint({'epoch': epoch,
                         'loss': test_metrics['test_multi_loss'],
                         'model_state': model.state_dict(),
                         'opt_state': optimizer.state_dict(),
                         'sched_state': scheduler.state_dict()},
                        is_best,
                        checkpoint_dir=f'saved/{timestamp}/')
# Example #10
# 0
# File: train.py  Project: jimmysue/xvision
def main(args):
    """Entry point: build the landmark model, optimizer and data pipeline,
    then alternate training and evaluation in chunks of ``eval_interval``
    steps until ``total_steps`` is reached, checkpointing along the way."""
    work_dir = Path(args.workdir)
    work_dir.mkdir(parents=True, exist_ok=True)

    log = get_logger(work_dir / 'log.txt')
    log.info(f'config:\n{args}')

    ckpt = Saver(work_dir, keep_num=10)
    # Persist the full configuration next to the logs for reproducibility.
    with open(work_dir / 'config.yml', 'wt') as cfg_file:
        args.dump(stream=cfg_file)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    log.info(f'use device: {device}')

    landmark_count = len(args.data.symmetry)
    net = models.__dict__[args.model.name](landmark_count)
    net.to(device)

    # Bias terms are excluded from weight decay.
    param_groups = group_parameters(net, bias_decay=0)
    opt = SGD(param_groups,
              args.lr,
              args.momentum,
              weight_decay=args.weight_decay)
    sched = OneCycleLR(opt,
                       max_lr=args.lr,
                       div_factor=20,
                       total_steps=args.total_steps,
                       pct_start=0.1,
                       final_div_factor=100)

    # Evaluation uses no augmentation; training adds symmetry-aware augments.
    eval_tf = Transform(args.dsize, args.padding, args.data.meanshape,
                        args.data.meanbbox)
    train_tf = Transform(args.dsize, args.padding, args.data.meanshape,
                         args.data.meanbbox, args.data.symmetry,
                         args.augments)

    train_set = datasets.__dict__[args.data.name](**args.data.train)
    eval_set = datasets.__dict__[args.data.name](**args.data.val)
    train_set.transform = train_tf
    eval_set.transform = eval_tf

    train_loader = DataLoader(train_set,
                              args.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    eval_loader = DataLoader(eval_set,
                             args.batch_size,
                             False,
                             num_workers=args.num_workers,
                             pin_memory=False)

    def endless(loader):
        # Cycle through the loader forever so training is step-driven,
        # not epoch-driven.
        while True:
            for batch in loader:
                yield batch

    def snapshot(step_value, loss_value):
        # Bundle everything needed to resume or deploy at this point.
        return {
            'model': net.state_dict(),
            'optimizer': opt.state_dict(),
            'lr_scheduler': sched.state_dict(),
            'step': step_value,
            'loss': loss_value,
            'cfg': args,
        }

    best_loss = 1e9
    score_fn = IbugScore(args.left_eye, args.right_eye)
    # Save an initial (untrained) checkpoint at step 0.
    ckpt.save(0, snapshot(0, best_loss))

    batches = endless(train_loader)
    start = time.time()
    for base in range(0, args.total_steps, args.eval_interval):
        # Last chunk may be shorter than eval_interval.
        num_steps = min(args.eval_interval, args.total_steps - base)
        step = base + num_steps
        trainmeter = train_steps(net, batches, opt, sched,
                                 score_fn, device, num_steps)
        evalmeter = evaluate(net, eval_loader, score_fn, device)
        curr_loss = evalmeter.meters['loss'].global_avg
        finish = time.time()
        # Throughput over the whole train+eval chunk.
        img_s = num_steps * args.batch_size / (finish - start)
        log.info(
            f'step: [{step}/{args.total_steps}] img/s: {img_s:.2f} train: [{trainmeter}] eval: [{evalmeter}]'
        )

        state = snapshot(step, curr_loss)
        ckpt.save(step, state)

        # Keep a separate copy of the best model by validation loss.
        if curr_loss < best_loss:
            ckpt.save_best(state)
            best_loss = curr_loss

        start = time.time()
Example #11
0
                loss_history = []
                for loss in train_epoch(model, train_loader, optimizer, scheduler, config.lambda_sparse, device):
                    loss_history.append(loss)
                logger.info(f"Epoch {ep:03d}/{num_epochs:03d}, train loss: {np.mean(loss_history):.6f}")

                ### Predict on validation ###
                logit_log_loss, auc = validation(model, val_loader, device)
                if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                    scheduler.step(logit_log_loss)
                if logit_log_loss < best_logit_log_loss:
                    best_logit_log_loss = logit_log_loss
                    # best_auc = auc
                    write_this = {
                        'model': model.state_dict(),
                        'optim': optimizer.state_dict(),
                        'sched': scheduler.state_dict(),
                        'epoch': ep,
                    }
                    torch.save(write_this, filepath)
                    logger.info(f" ** Updated the best weight, logit log loss: {logit_log_loss:.6f}, auc: {auc:.6f} **")
                else:
                    logger.info(f"Passed to save the weight, best: {best_logit_log_loss:.6f} / logit log loss: {logit_log_loss:.6f}, auc: {auc:.6f}")

            ### Save OOF for CV ###
            best_state_dict = torch.load(filepath)
            model.load_state_dict(best_state_dict['model'])
            val_loader = DataLoader(
                DatasetWithoutLabel(X_val),
                batch_size=config.batch_size,
                collate_fn=collate_fn_test,
                shuffle=False,