# Example 1
class Maskv3Agent:
    """Train and evaluate a Maskv3 model (YOLO detection branch plus
    prototype mask branch) as described by a nested ``config`` dict.

    The agent owns the data loaders, model, AMP gradient scaler,
    optimizer, OneCycle LR scheduler, loss function and TensorBoard
    writer, and drives the per-epoch train / validate / checkpoint loop.
    """

    def __init__(self, config):
        """Build every training component from ``config``.

        :param config: nested dict with 'dataset', 'dataloader', 'model',
            'optimizer', 'train' and 'valid' sections (keys used below).
        """
        self.config = config

        # Train on device; fall back to CPU when CUDA is unavailable.
        target_device = config['train']['device']
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.device = target_device
        else:
            self.device = "cpu"

        # Load dataset
        train_transform = get_yolo_transform(config['dataset']['size'],
                                             mode='train')
        valid_transform = get_yolo_transform(config['dataset']['size'],
                                             mode='test')
        train_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['train']['csv'],
            img_dir=config['dataset']['train']['img_root'],
            mask_dir=config['dataset']['train']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=train_transform)
        valid_dataset = YOLOMaskDataset(
            csv_file=config['dataset']['valid']['csv'],
            img_dir=config['dataset']['valid']['img_root'],
            mask_dir=config['dataset']['valid']['mask_root'],
            anchors=config['dataset']['anchors'],
            scales=config['dataset']['scales'],
            n_classes=config['dataset']['n_classes'],
            transform=valid_transform)
        # DataLoader
        self.train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            dataset=valid_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            collate_fn=maskv3_collate_fn,
            pin_memory=True,
            shuffle=False,
            drop_last=False)
        # Model
        model = Maskv3(
            # Detection Branch
            in_channels=config['model']['in_channels'],
            num_classes=config['model']['num_classes'],
            # Prototype Branch
            num_masks=config['model']['num_masks'],
            num_features=config['model']['num_features'],
        )
        self.model = model.to(self.device)
        # Pre-scale anchor boxes to each prediction scale so they can be
        # compared directly against the model's per-scale outputs.
        torch_anchors = torch.tensor(config['dataset']['anchors'])  # (3, 3, 2)
        torch_scales = torch.tensor(config['dataset']['scales'])  # (3,)
        scaled_anchors = (  # (3, 3, 2)
            torch_anchors *
            (torch_scales.unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)))
        self.scaled_anchors = scaled_anchors.to(self.device)

        # Optimizer (AMP gradient scaler for mixed-precision training)
        self.scaler = torch.cuda.amp.GradScaler()
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=config['optimizer']['lr'],
            weight_decay=config['optimizer']['weight_decay'],
        )
        # Scheduler: OneCycleLR is stepped once per batch, hence
        # steps_per_epoch = number of training batches.
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config['optimizer']['lr'],
            epochs=config['train']['n_epochs'],
            steps_per_epoch=len(self.train_loader),
        )
        # Loss function
        self.loss_fn = YOLOMaskLoss(num_classes=config['model']['num_classes'],
                                    num_masks=config['model']['num_masks'])

        # Tensorboard
        self.logdir = config['train']['logdir']
        self.board = SummaryWriter(logdir=config['train']['logdir'])

        # Training State
        self.current_epoch = 0
        self.current_map = 0

    def resume(self):
        """Restore model/optimizer/scheduler state from '<logdir>/best.pth'.

        ``map_location`` keeps the load working when the checkpoint was
        written on a different device (e.g. trained on GPU, resumed on CPU).
        """
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.current_map = checkpoint['current_map']
        self.current_epoch = checkpoint['current_epoch']
        print("Restore checkpoint at '{}'".format(self.current_epoch))

    def train(self):
        """Run the main loop: one train + validation pass per epoch.

        Before epoch config['valid']['when'] a checkpoint is saved every
        epoch; afterwards mAP@0.5 is evaluated every 5 epochs and the
        checkpoint is kept only when the score improves.
        """
        for epoch in range(self.current_epoch + 1,
                           self.config['train']['n_epochs'] + 1):
            self.current_epoch = epoch
            self._train_one_epoch()
            self._validate()
            self._check_accuracy()

            if self.current_epoch < self.config['valid']['when']:
                self._save_checkpoint()

            if (self.current_epoch >= self.config['valid']['when']
                    and self.current_epoch % 5 == 0):
                mAP50 = self._check_map()
                if mAP50 > self.current_map:
                    self.current_map = mAP50
                    self._save_checkpoint()

    def finalize(self):
        """Report final mAP once training has completed."""
        self._check_map()

    def _train_one_epoch(self):
        """Train for one epoch, showing running-mean losses on the progress
        bar and logging per-epoch means to TensorBoard."""
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.train_loader,
                    leave=True,
                    desc=(f"Train Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.train()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model prediction under autocast (mixed precision)
            with torch.cuda.amp.autocast():
                outs, prototypes = self.model(imgs)
                s1_loss = self.loss_fn(
                    outs[0],
                    target_s1,
                    self.scaled_anchors[0],  # Detection Branch
                    prototypes,
                    masks,  # Prototype Branch
                )
                s2_loss = self.loss_fn(
                    outs[1],
                    target_s2,
                    self.scaled_anchors[1],  # Detection Branch
                    prototypes,
                    masks,  # Prototype Branch
                )
                s3_loss = self.loss_fn(
                    outs[2],
                    target_s3,
                    self.scaled_anchors[2],  # Detection Branch
                    prototypes,
                    masks,  # Prototype Branch
                )
            # Aggregate loss over the three scales
            obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                'obj_loss']
            box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                'box_loss']
            noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                'noobj_loss'] + s3_loss['noobj_loss']
            class_loss = s1_loss['class_loss'] + s2_loss[
                'class_loss'] + s3_loss['class_loss']
            segment_loss = s1_loss['segment_loss'] + s2_loss[
                'segment_loss'] + s3_loss['segment_loss']
            total_loss = s1_loss['total_loss'] + s2_loss[
                'total_loss'] + s3_loss['total_loss']
            # Record losses for running means
            total_losses.append(total_loss.item())
            obj_losses.append(obj_loss.item())
            noobj_losses.append(noobj_loss.item())
            box_losses.append(box_loss.item())
            class_losses.append(class_loss.item())
            segment_losses.append(segment_loss.item())
            # Update parameters (scaled backward pass for AMP)
            self.optimizer.zero_grad()
            self.scaler.scale(total_loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()
            # Update progress bar
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch)
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Train Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Train SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _validate(self):
        """Evaluate the loss terms on the validation set (no gradient
        updates) and log per-epoch means to TensorBoard."""
        n_epochs = self.config['train']['n_epochs']
        current_epoch = self.current_epoch
        current_lr = self.optimizer.param_groups[0]['lr']
        loop = tqdm(self.valid_loader,
                    leave=True,
                    desc=(f"Valid Epoch:{current_epoch}/{n_epochs}"
                          f", LR: {current_lr:.5f}"))
        obj_losses = []
        box_losses = []
        noobj_losses = []
        class_losses = []
        total_losses = []
        segment_losses = []
        self.model.eval()
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            masks = [m.to(self.device) for m in masks]  # (nM_g, H, W)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            # Model prediction (inference only)
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
                    s1_loss = self.loss_fn(
                        outs[0],
                        target_s1,
                        self.scaled_anchors[0],  # Detection Branch
                        prototypes,
                        masks,  # Prototype Branch
                    )
                    s2_loss = self.loss_fn(
                        outs[1],
                        target_s2,
                        self.scaled_anchors[1],  # Detection Branch
                        prototypes,
                        masks,  # Prototype Branch
                    )
                    s3_loss = self.loss_fn(
                        outs[2],
                        target_s3,
                        self.scaled_anchors[2],  # Detection Branch
                        prototypes,
                        masks,  # Prototype Branch
                    )
            # Aggregate loss over the three scales
            obj_loss = s1_loss['obj_loss'] + s2_loss['obj_loss'] + s3_loss[
                'obj_loss']
            box_loss = s1_loss['box_loss'] + s2_loss['box_loss'] + s3_loss[
                'box_loss']
            noobj_loss = s1_loss['noobj_loss'] + s2_loss[
                'noobj_loss'] + s3_loss['noobj_loss']
            class_loss = s1_loss['class_loss'] + s2_loss[
                'class_loss'] + s3_loss['class_loss']
            segment_loss = s1_loss['segment_loss'] + s2_loss[
                'segment_loss'] + s3_loss['segment_loss']
            total_loss = s1_loss['total_loss'] + s2_loss[
                'total_loss'] + s3_loss['total_loss']
            # Record losses for running means
            obj_losses.append(obj_loss.item())
            box_losses.append(box_loss.item())
            noobj_losses.append(noobj_loss.item())
            class_losses.append(class_loss.item())
            total_losses.append(total_loss.item())
            segment_losses.append(segment_loss.item())
            # Update progress bar
            mean_total_loss = sum(total_losses) / len(total_losses)
            mean_obj_loss = sum(obj_losses) / len(obj_losses)
            mean_noobj_loss = sum(noobj_losses) / len(noobj_losses)
            mean_box_loss = sum(box_losses) / len(box_losses)
            mean_class_loss = sum(class_losses) / len(class_losses)
            mean_segment_loss = sum(segment_losses) / len(segment_losses)
            loop.set_postfix(
                loss=mean_total_loss,
                cls=mean_class_loss,
                box=mean_box_loss,
                obj=mean_obj_loss,
                noobj=mean_noobj_loss,
                segment=mean_segment_loss,
            )
        # Logging (epoch)
        epoch_total_loss = sum(total_losses) / len(total_losses)
        epoch_obj_loss = sum(obj_losses) / len(obj_losses)
        epoch_noobj_loss = sum(noobj_losses) / len(noobj_losses)
        epoch_box_loss = sum(box_losses) / len(box_losses)
        epoch_class_loss = sum(class_losses) / len(class_losses)
        epoch_segment_loss = sum(segment_losses) / len(segment_losses)
        self.board.add_scalar('Epoch Valid Loss',
                              epoch_total_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid BOX Loss',
                              epoch_box_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid OBJ Loss',
                              epoch_obj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid NOOBJ Loss',
                              epoch_noobj_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid CLASS Loss',
                              epoch_class_loss,
                              global_step=self.current_epoch)
        self.board.add_scalar('Epoch Valid SEGMENT Loss',
                              epoch_segment_loss,
                              global_step=self.current_epoch)

    def _check_accuracy(self):
        """Compute class / objectness / non-objectness accuracies (in %)
        over the validation set.

        :return: dict with keys 'cls', 'obj' and 'noobj'.
        """
        tot_obj = 0
        tot_noobj = 0
        correct_obj = 0
        correct_noobj = 0
        correct_class = 0
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc="Check ACC")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Prediction
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    outs, prototypes = self.model(imgs)
            for scale_idx in range(len(outs)):
                # Get output
                pred = outs[scale_idx]
                target = targets[scale_idx]
                # Get mask (target[..., 4] holds objectness: 1 obj, 0 no-obj)
                obj_mask = target[..., 4] == 1
                noobj_mask = target[..., 4] == 0
                # Count objects
                tot_obj += torch.sum(obj_mask)
                tot_noobj += torch.sum(noobj_mask)
                # No ground-truth object at this scale: only non-objectness
                # predictions can be scored.
                if torch.sum(obj_mask) == 0:
                    obj_pred = torch.sigmoid(
                        pred[..., 4]) > self.config['valid']['conf_threshold']
                    correct_noobj += torch.sum(
                        obj_pred[noobj_mask] == target[..., 4][noobj_mask])
                    continue
                # Count number of correct classified object
                correct_class += torch.sum((torch.argmax(
                    pred[...,
                         5:5 + self.config['model']['num_classes']][obj_mask],
                    dim=-1) == target[..., 5][obj_mask]))
                # Count number of correct objectness & non-objectness
                obj_pred = torch.sigmoid(
                    pred[..., 4]) > self.config['valid']['conf_threshold']
                correct_obj += torch.sum(
                    obj_pred[obj_mask] == target[..., 4][obj_mask])
                correct_noobj += torch.sum(
                    obj_pred[noobj_mask] == target[..., 4][noobj_mask])
        # Aggregation Result (epsilon avoids division by zero)
        acc_obj = (correct_obj / (tot_obj + 1e-6)) * 100
        acc_cls = (correct_class / (tot_obj + 1e-6)) * 100
        acc_noobj = (correct_noobj / (tot_noobj + 1e-6)) * 100
        accs = {
            'cls': acc_cls.item(),
            'obj': acc_obj.item(),
            'noobj': acc_noobj.item()
        }
        print(f"Epoch {self.current_epoch} [Accs]: {accs}")
        return accs

    def _check_map(self):
        """Compute mAP@0.5 and mAP@0.75 over the validation set.

        :return: mAP@0.5 as a float.
        """
        sample_idx = 0
        all_pred_bboxes = []
        all_true_bboxes = []
        self.model.eval()
        loop = tqdm(self.valid_loader, leave=True, desc="Check mAP")
        for batch_idx, (imgs, masks, targets) in enumerate(loop):
            batch_size = imgs.size(0)
            # Move device
            imgs = imgs.to(self.device)  # (N, 3, 416, 416)
            target_s1 = targets[0].to(self.device)  # (N, 3, 13, 13, 6)
            target_s2 = targets[1].to(self.device)  # (N, 3, 26, 26, 6)
            target_s3 = targets[2].to(self.device)  # (N, 3, 52, 52, 6)
            targets = [target_s1, target_s2, target_s3]
            # Model Forward
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    preds, prototypes = self.model(imgs)
            # Convert cells to bboxes
            # =================================================================
            true_bboxes = [[] for _ in range(batch_size)]
            pred_bboxes = [[] for _ in range(batch_size)]
            for scale_idx, (pred, target) in enumerate(zip(preds, targets)):
                scale = pred.size(2)
                anchors = self.scaled_anchors[scale_idx]  # (3, 2)
                anchors = anchors.reshape(1, 3, 1, 1, 2)  # (1, 3, 1, 1, 2)
                # Convert raw predictions in place (decoded xy/wh/conf)
                pred[..., 0:2] = torch.sigmoid(pred[...,
                                                    0:2])  # (N, 3, S, S, 2)
                pred[..., 2:4] = torch.exp(
                    pred[..., 2:4]) * anchors  # (N, 3, S, S, 2)
                pred[..., 4:5] = torch.sigmoid(pred[...,
                                                    4:5])  # (N, 3, S, S, 1)
                pred_cls_probs = F.softmax(
                    pred[..., 5:5 + self.config['model']['num_classes']],
                    dim=-1)  # (N, 3, S, S, C)
                _, indices = torch.max(pred_cls_probs, dim=-1)  # (N, 3, S, S)
                indices = indices.unsqueeze(-1)  # (N, 3, S, S, 1)
                pred = torch.cat([pred[..., :5], indices],
                                 dim=-1)  # (N, 3, S, S, 6)
                # Convert coordinate system to normalized format (xywh)
                pboxes = cells_to_boxes(cells=pred,
                                        scale=scale)  # (N, 3, S, S, 6)
                tboxes = cells_to_boxes(cells=target,
                                        scale=scale)  # (N, 3, S, S, 6)
                # Keep predicted boxes above the confidence threshold
                for idx, cell_boxes in enumerate(pboxes):
                    obj_mask = cell_boxes[
                        ..., 4] > self.config['valid']['conf_threshold']
                    boxes = cell_boxes[obj_mask]
                    pred_bboxes[idx] += boxes.tolist()
                # Keep ground-truth boxes (objectness ~ 1.0)
                for idx, cell_boxes in enumerate(tboxes):
                    obj_mask = cell_boxes[..., 4] > 0.99
                    boxes = cell_boxes[obj_mask]
                    true_bboxes[idx] += boxes.tolist()
            # Perform NMS sample-by-sample
            # =================================================================
            for sample_i in range(batch_size):
                pbboxes = torch.tensor(pred_bboxes[sample_i])
                tbboxes = torch.tensor(true_bboxes[sample_i])
                # Perform NMS class-by-class
                for c in range(self.config['model']['num_classes']):
                    # Filter pred boxes of specific class
                    nms_pred_boxes = nms_by_class(
                        target=c,
                        bboxes=pbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    nms_true_boxes = nms_by_class(
                        target=c,
                        bboxes=tbboxes,
                        iou_threshold=self.config['valid']
                        ['nms_iou_threshold'])
                    all_pred_bboxes.extend([[sample_idx] + box
                                            for box in nms_pred_boxes])
                    all_true_bboxes.extend([[sample_idx] + box
                                            for box in nms_true_boxes])
                sample_idx += 1
        # Compute mAP@0.5 & mAP@0.75
        # =================================================================
        # The format of the bboxes is (idx, x1, y1, x2, y2, conf, class)
        all_pred_bboxes = torch.tensor(all_pred_bboxes)  # (J, 7)
        all_true_bboxes = torch.tensor(all_true_bboxes)  # (K, 7)
        eval50 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.5,
            n_classes=self.config['dataset']['n_classes'])
        eval75 = mean_average_precision(
            all_pred_bboxes,
            all_true_bboxes,
            iou_threshold=0.75,
            n_classes=self.config['dataset']['n_classes'])
        print((
            f"Epoch {self.current_epoch}:\n"
            f"\t-[mAP@0.5]={eval50['mAP']:.3f}, [Recall]={eval50['recall']:.3f}, [Precision]={eval50['precision']:.3f}\n"
            f"\t-[mAP@0.75]={eval75['mAP']:.3f}, [Recall]={eval75['recall']:.3f}, [Precision]={eval75['precision']:.3f}\n"
        ))
        return eval50['mAP']

    def _save_checkpoint(self):
        """Write model/optimizer/scheduler state plus training progress to
        '<logdir>/best.pth' (overwrites any previous checkpoint)."""
        checkpoint = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'current_map': self.current_map,
            'current_epoch': self.current_epoch
        }
        checkpoint_path = osp.join(self.logdir, 'best.pth')
        torch.save(checkpoint, checkpoint_path)
        print("Save checkpoint at '{}'".format(checkpoint_path))
# Example 2
def train(args, checkpoint, mid_checkpoint_location, final_checkpoint_location, best_checkpoint_location,
          actfun, curr_seed, outfile_path, filename, fieldnames, curr_sample_size, device, num_params,
          curr_k=2, curr_p=1, curr_g=1, perm_method='shuffle'):
    """
    Runs training session for a given randomized model
    :param args: arguments for this job
    :param checkpoint: current checkpoint
    :param checkpoint_location: output directory for checkpoints
    :param actfun: activation function currently being used
    :param curr_seed: seed being used by current job
    :param outfile_path: path to save outputs from training session
    :param fieldnames: column names for output file
    :param device: reference to CUDA device for GPU support
    :param num_params: number of parameters in the network
    :param curr_k: k value for this iteration
    :param curr_p: p value for this iteration
    :param curr_g: g value for this iteration
    :param perm_method: permutation strategy for our network
    :return:
    """

    resnet_ver = args.resnet_ver
    resnet_width = args.resnet_width
    num_epochs = args.num_epochs

    actfuns_1d = ['relu', 'abs', 'swish', 'leaky_relu', 'tanh']
    if actfun in actfuns_1d:
        curr_k = 1
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    if args.one_shot:
        util.seed_all(curr_seed)
        model_temp, _ = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g, num_params=num_params,
                                   perm_method=perm_method, device=device, resnet_ver=resnet_ver,
                                   resnet_width=resnet_width, verbose=args.verbose)

        util.seed_all(curr_seed)
        dataset_temp = util.load_dataset(
            args,
            args.model,
            args.dataset,
            seed=curr_seed,
            validation=True,
            batch_size=args.batch_size,
            train_sample_size=curr_sample_size,
            kwargs=kwargs)

        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx, args.one_shot)
        optimizer = optim.Adam(model_temp.parameters(),
                               betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                               eps=curr_hparams['eps'],
                               weight_decay=curr_hparams['wd']
                               )

        start_time = time.time()
        oneshot_fieldnames = fieldnames if args.search else None
        oneshot_outfile_path = outfile_path if args.search else None
        lr = util.run_lr_finder(
            args,
            model_temp,
            dataset_temp[0],
            optimizer,
            nn.CrossEntropyLoss(),
            val_loader=dataset_temp[3],
            show=False,
            device=device,
            fieldnames=oneshot_fieldnames,
            outfile_path=oneshot_outfile_path,
            hparams=curr_hparams
        )
        curr_hparams = {}
        print("Time to find LR: {}\n LR found: {:3e}".format(time.time() - start_time, lr))

    else:
        curr_hparams = hparams.get_hparams(args.model, args.dataset, actfun, curr_seed,
                                           num_epochs, args.search, args.hp_idx)
        lr = curr_hparams['max_lr']

        criterion = nn.CrossEntropyLoss()
        model, model_params = load_model(args.model, args.dataset, actfun, curr_k, curr_p, curr_g, num_params=num_params,
                                   perm_method=perm_method, device=device, resnet_ver=resnet_ver,
                                   resnet_width=resnet_width, verbose=args.verbose)

        util.seed_all(curr_seed)
        model.apply(util.weights_init)

        util.seed_all(curr_seed)
        dataset = util.load_dataset(
            args,
            args.model,
            args.dataset,
            seed=curr_seed,
            validation=args.validation,
            batch_size=args.batch_size,
            train_sample_size=curr_sample_size,
            kwargs=kwargs)
        loaders = {
            'aug_train': dataset[0],
            'train': dataset[1],
            'aug_eval': dataset[2],
            'eval': dataset[3],
        }
        sample_size = dataset[4]
        batch_size = dataset[5]

        if args.one_shot:
            optimizer = optim.Adam(model_params)
            scheduler = OneCycleLR(optimizer,
                                   max_lr=lr,
                                   epochs=num_epochs,
                                   steps_per_epoch=int(math.floor(sample_size / batch_size)),
                                   cycle_momentum=False
                                   )
        else:
            optimizer = optim.Adam(model_params,
                                   betas=(curr_hparams['beta1'], curr_hparams['beta2']),
                                   eps=curr_hparams['eps'],
                                   weight_decay=curr_hparams['wd']
                                   )
            scheduler = OneCycleLR(optimizer,
                                   max_lr=curr_hparams['max_lr'],
                                   epochs=num_epochs,
                                   steps_per_epoch=int(math.floor(sample_size / batch_size)),
                                   pct_start=curr_hparams['cycle_peak'],
                                   cycle_momentum=False
                                   )

        epoch = 1
        if checkpoint is not None:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            epoch = checkpoint['epoch']
            model.to(device)
            print("*** LOADED CHECKPOINT ***"
                  "\n{}"
                  "\nSeed: {}"
                  "\nEpoch: {}"
                  "\nActfun: {}"
                  "\nNum Params: {}"
                  "\nSample Size: {}"
                  "\np: {}"
                  "\nk: {}"
                  "\ng: {}"
                  "\nperm_method: {}".format(mid_checkpoint_location, checkpoint['curr_seed'],
                                             checkpoint['epoch'], checkpoint['actfun'],
                                             checkpoint['num_params'], checkpoint['sample_size'],
                                             checkpoint['p'], checkpoint['k'], checkpoint['g'],
                                             checkpoint['perm_method']))

        util.print_exp_settings(curr_seed, args.dataset, outfile_path, args.model, actfun,
                                util.get_model_params(model), sample_size, batch_size, model.k, model.p, model.g,
                                perm_method, resnet_ver, resnet_width, args.optim, args.validation, curr_hparams)

        best_val_acc = 0

        if args.mix_pre_apex:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

        # ---- Start Training
        while epoch <= num_epochs:

            if args.check_path != '':
                torch.save({'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'curr_seed': curr_seed,
                            'epoch': epoch,
                            'actfun': actfun,
                            'num_params': num_params,
                            'sample_size': sample_size,
                            'p': curr_p, 'k': curr_k, 'g': curr_g,
                            'perm_method': perm_method
                            }, mid_checkpoint_location)

            util.seed_all((curr_seed * args.num_epochs) + epoch)
            start_time = time.time()
            if args.mix_pre:
                scaler = torch.cuda.amp.GradScaler()

            # ---- Training
            # One pass over the augmented training set. Three precision paths:
            # native AMP (args.mix_pre, GradScaler), NVIDIA apex AMP
            # (args.mix_pre_apex), or plain fp32.
            model.train()
            total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
                x, targetx = x.to(device), targetx.to(device)
                optimizer.zero_grad()
                if args.mix_pre:
                    # Native mixed precision: forward under autocast, then
                    # scale the loss before backward to avoid fp16 underflow.
                    with torch.cuda.amp.autocast():
                        output = model(x)
                        train_loss = criterion(output, targetx)
                    # NOTE(review): accumulating the loss *tensor* (not
                    # train_loss.item()) keeps each step's autograd graph
                    # alive until epoch end — consider .item()/.detach().
                    total_train_loss += train_loss
                    n += 1
                    scaler.scale(train_loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                elif args.mix_pre_apex:
                    # Apex AMP path: amp.scale_loss handles loss scaling.
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    optimizer.step()
                else:
                    # Plain fp32 training step.
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    train_loss.backward()
                    optimizer.step()
                # OneCycle schedulers step per batch, not per epoch.
                if args.optim == 'onecycle' or args.optim == 'onecycle_sgd':
                    scheduler.step()
                # Running accuracy over the epoch (argmax over class dim).
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targetx.data)
                num_total += len(prediction)
            epoch_aug_train_loss = total_train_loss / n
            epoch_aug_train_acc = num_correct * 1.0 / num_total

            alpha_primes = []
            alphas = []
            if model.actfun == 'combinact':
                for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                    curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                    curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                    curr_alpha_primes = curr_alpha_primes.tolist()
                    alpha_primes.append(curr_alpha_primes)
                    alphas.append(curr_alphas)

            model.eval()
            with torch.no_grad():
                total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (y, targety) in enumerate(loaders['aug_eval']):
                    y, targety = y.to(device), targety.to(device)
                    output = model(y)
                    val_loss = criterion(output, targety)
                    total_val_loss += val_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targety.data)
                    num_total += len(prediction)
                epoch_aug_val_loss = total_val_loss / n
                epoch_aug_val_acc = num_correct * 1.0 / num_total

                total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (y, targety) in enumerate(loaders['eval']):
                    y, targety = y.to(device), targety.to(device)
                    output = model(y)
                    val_loss = criterion(output, targety)
                    total_val_loss += val_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targety.data)
                    num_total += len(prediction)
                epoch_val_loss = total_val_loss / n
                epoch_val_acc = num_correct * 1.0 / num_total
            lr_curr = 0
            for param_group in optimizer.param_groups:
                lr_curr = param_group['lr']
            print(
                "    Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f}, aug {:1.4f} ||| "
                "aug_train_loss {:1.4f} | val_loss {:1.4f}, aug {:1.4f} ||| time = {:1.4f}"
                    .format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc, epoch_aug_val_acc,
                            epoch_aug_train_loss, epoch_val_loss, epoch_aug_val_loss, (time.time() - start_time)), flush=True
            )

            if args.hp_idx is None:
                hp_idx = -1
            else:
                hp_idx = args.hp_idx

            # Final-epoch only: re-evaluate the model on the (augmented and
            # clean) training sets without gradients, so the CSV row for the
            # last epoch carries train-side metrics for the generalization gap.
            epoch_train_loss = 0
            epoch_train_acc = 0
            if epoch == num_epochs:
                with torch.no_grad():
                    # Pass 1: augmented train loader.
                    total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                    for batch_idx, (x, targetx) in enumerate(loaders['aug_train']):
                        x, targetx = x.to(device), targetx.to(device)
                        output = model(x)
                        train_loss = criterion(output, targetx)
                        total_train_loss += train_loss
                        n += 1
                        _, prediction = torch.max(output.data, 1)
                        num_correct += torch.sum(prediction == targetx.data)
                        num_total += len(prediction)
                    epoch_aug_train_loss = total_train_loss / n
                    epoch_aug_train_acc = num_correct * 1.0 / num_total

                    # Pass 2: clean (non-augmented) train loader.
                    total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                    for batch_idx, (x, targetx) in enumerate(loaders['train']):
                        x, targetx = x.to(device), targetx.to(device)
                        output = model(x)
                        train_loss = criterion(output, targetx)
                        total_train_loss += train_loss
                        n += 1
                        _, prediction = torch.max(output.data, 1)
                        num_correct += torch.sum(prediction == targetx.data)
                        num_total += len(prediction)
                    # BUG FIX: previously divided total_val_loss (stale value
                    # left over from the validation loop above) instead of the
                    # total_train_loss accumulated in this pass.
                    epoch_train_loss = total_train_loss / n
                    epoch_train_acc = num_correct * 1.0 / num_total

            # Outputting data to CSV at end of epoch
            with open(outfile_path, mode='a') as out_file:
                writer = csv.DictWriter(out_file, fieldnames=fieldnames, lineterminator='\n')
                writer.writerow({'dataset': args.dataset,
                                 'seed': curr_seed,
                                 'epoch': epoch,
                                 'time': (time.time() - start_time),
                                 'actfun': model.actfun,
                                 'sample_size': sample_size,
                                 'model': args.model,
                                 'batch_size': batch_size,
                                 'alpha_primes': alpha_primes,
                                 'alphas': alphas,
                                 'num_params': util.get_model_params(model),
                                 'var_nparams': args.var_n_params,
                                 'var_nsamples': args.var_n_samples,
                                 'k': curr_k,
                                 'p': curr_p,
                                 'g': curr_g,
                                 'perm_method': perm_method,
                                 'gen_gap': float(epoch_val_loss - epoch_train_loss),
                                 'aug_gen_gap': float(epoch_aug_val_loss - epoch_aug_train_loss),
                                 'resnet_ver': resnet_ver,
                                 'resnet_width': resnet_width,
                                 'epoch_train_loss': float(epoch_train_loss),
                                 'epoch_train_acc': float(epoch_train_acc),
                                 'epoch_aug_train_loss': float(epoch_aug_train_loss),
                                 'epoch_aug_train_acc': float(epoch_aug_train_acc),
                                 'epoch_val_loss': float(epoch_val_loss),
                                 'epoch_val_acc': float(epoch_val_acc),
                                 'epoch_aug_val_loss': float(epoch_aug_val_loss),
                                 'epoch_aug_val_acc': float(epoch_aug_val_acc),
                                 'hp_idx': hp_idx,
                                 'curr_lr': lr_curr,
                                 'found_lr': lr,
                                 'hparams': curr_hparams,
                                 'epochs': num_epochs
                                 })

            epoch += 1

            if args.optim == 'rmsprop':
                scheduler.step()

            if args.checkpoints:
                if epoch_val_acc > best_val_acc:
                    best_val_acc = epoch_val_acc
                    torch.save({'state_dict': model.state_dict(),
                                'optimizer': optimizer.state_dict(),
                                'scheduler': scheduler.state_dict(),
                                'curr_seed': curr_seed,
                                'epoch': epoch,
                                'actfun': actfun,
                                'num_params': num_params,
                                'sample_size': sample_size,
                                'p': curr_p, 'k': curr_k, 'g': curr_g,
                                'perm_method': perm_method
                                }, best_checkpoint_location)

                torch.save({'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'scheduler': scheduler.state_dict(),
                            'curr_seed': curr_seed,
                            'epoch': epoch,
                            'actfun': actfun,
                            'num_params': num_params,
                            'sample_size': sample_size,
                            'p': curr_p, 'k': curr_k, 'g': curr_g,
                            'perm_method': perm_method
                            }, final_checkpoint_location)
# Ejemplo n.º 3
# 0
    def train(
            self,
            base_path: Union[Path, str],
            learning_rate: float = 0.1,
            mini_batch_size: int = 32,
            mini_batch_chunk_size: Optional[int] = None,
            max_epochs: int = 100,
            train_with_dev: bool = False,
            train_with_test: bool = False,
            monitor_train: bool = False,
            monitor_test: bool = False,
            main_evaluation_metric: Tuple[str, str] = ("micro avg", 'f1-score'),
            scheduler=AnnealOnPlateau,
            anneal_factor: float = 0.5,
            patience: int = 3,
            min_learning_rate: float = 0.0001,
            initial_extra_patience: int = 0,
            optimizer: torch.optim.Optimizer = SGD,
            cycle_momentum: bool = False,
            warmup_fraction: float = 0.1,
            embeddings_storage_mode: str = "cpu",
            checkpoint: bool = False,
            save_final_model: bool = True,
            anneal_with_restarts: bool = False,
            anneal_with_prestarts: bool = False,
            anneal_against_dev_loss: bool = False,
            batch_growth_annealing: bool = False,
            shuffle: bool = True,
            param_selection_mode: bool = False,
            write_weights: bool = False,
            num_workers: int = 6,
            sampler=None,
            use_amp: bool = False,
            amp_opt_level: str = "O1",
            eval_on_train_fraction: float = 0.0,
            eval_on_train_shuffle: bool = False,
            save_model_each_k_epochs: int = 0,
            tensorboard_comment: str = '',
            use_swa: bool = False,
            use_final_model_for_eval: bool = False,
            gold_label_dictionary_for_eval: Optional[Dictionary] = None,
            create_file_logs: bool = True,
            create_loss_file: bool = True,
            epoch: int = 0,
            use_tensorboard: bool = False,
            tensorboard_log_dir=None,
            metrics_for_tensorboard=[],
            optimizer_state_dict: Optional = None,
            scheduler_state_dict: Optional = None,
            save_optimizer_state: bool = False,
            **kwargs,
    ) -> dict:
        """
        Trains any class that implements the flair.nn.Model interface.
        :param base_path: Main path to which all output during training is logged and models are saved
        :param learning_rate: Initial learning rate (or max, if scheduler is OneCycleLR)
        :param mini_batch_size: Size of mini-batches during training
        :param mini_batch_chunk_size: If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes
        :param max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed.
        :param scheduler: The learning rate scheduler to use
        :param checkpoint: If True, a full checkpoint is saved at end of each epoch
        :param cycle_momentum: If scheduler is OneCycleLR, whether the scheduler should cycle also the momentum
        :param anneal_factor: The factor by which the learning rate is annealed
        :param patience: Patience is the number of epochs with no improvement the Trainer waits
         until annealing the learning rate
        :param min_learning_rate: If the learning rate falls below this threshold, training terminates
        :param warmup_fraction: Fraction of warmup steps if the scheduler is LinearSchedulerWithWarmup
        :param train_with_dev:  If True, the data from dev split is added to the training data
        :param train_with_test: If True, the data from test split is added to the training data
        :param monitor_train: If True, training data is evaluated at end of each epoch
        :param monitor_test: If True, test data is evaluated at end of each epoch
        :param embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed),
        'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU)
        :param save_final_model: If True, final model is saved
        :param anneal_with_restarts: If True, the last best model is restored when annealing the learning rate
        :param shuffle: If True, data is shuffled during training
        :param param_selection_mode: If True, testing is performed against dev data. Use this mode when doing
        parameter selection.
        :param num_workers: Number of workers in your data loader.
        :param sampler: You can pass a data sampler here for special sampling of data.
        :param eval_on_train_fraction: the fraction of train data to do the evaluation on,
        if 0. the evaluation is not performed on fraction of training data,
        if 'dev' the size is determined from dev set size
        :param eval_on_train_shuffle: if True the train data fraction is determined on the start of training
        and kept fixed during training, otherwise it's sampled at beginning of each epoch
        :param save_model_each_k_epochs: Each k epochs, a model state will be written out. If set to '5', a model will
        be saved each 5 epochs. Default is 0 which means no model saving.
        :param main_evaluation_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model
        :param tensorboard_comment: Comment to use for tensorboard logging
        :param create_file_logs: If True, the logs will also be stored in a file 'training.log' in the model folder
        :param create_loss_file: If True, the loss will be writen to a file 'loss.tsv' in the model folder
        :param optimizer: The optimizer to use (typically SGD or Adam)
        :param epoch: The starting epoch (normally 0 but could be higher if you continue training model)
        :param use_tensorboard: If True, writes out tensorboard information
        :param tensorboard_log_dir: Directory into which tensorboard log files will be written
        :param metrics_for_tensorboard: List of tuples that specify which metrics (in addition to the main_score) shall be plotted in tensorboard, could be [("macro avg", 'f1-score'), ("macro avg", 'precision')] for example
        :param kwargs: Other arguments for the Optimizer
        :return:
        """

        # create a model card for this model with Flair and PyTorch version
        model_card = {'flair_version': flair.__version__, 'pytorch_version': torch.__version__}

        # also record Transformers version if library is loaded
        try:
            import transformers
            model_card['transformers_version'] = transformers.__version__
        except:
            pass

        # remember all parameters used in train() call
        local_variables = locals()
        training_parameters = {}
        for parameter in signature(self.train).parameters:
            training_parameters[parameter] = local_variables[parameter]
        model_card['training_parameters'] = training_parameters

        # add model card to model
        self.model.model_card = model_card

        if use_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter

                if tensorboard_log_dir is not None and not os.path.exists(tensorboard_log_dir):
                    os.mkdir(tensorboard_log_dir)
                writer = SummaryWriter(log_dir=tensorboard_log_dir, comment=tensorboard_comment)
                log.info(f"tensorboard logging path is {tensorboard_log_dir}")

            except:
                log_line(log)
                log.warning("ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!")
                log_line(log)
                use_tensorboard = False
                pass

        if use_amp:
            if sys.version_info < (3, 0):
                raise RuntimeError("Apex currently only supports Python 3. Aborting.")
            if amp is None:
                raise RuntimeError(
                    "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                    "to enable mixed-precision training."
                )

        if mini_batch_chunk_size is None:
            mini_batch_chunk_size = mini_batch_size
        if learning_rate < min_learning_rate:
            min_learning_rate = learning_rate / 10

        initial_learning_rate = learning_rate

        # cast string to Path
        if type(base_path) is str:
            base_path = Path(base_path)
        base_path.mkdir(exist_ok=True, parents=True)

        if create_file_logs:
            log_handler = add_file_handler(log, base_path / "training.log")
        else:
            log_handler = None

        log_line(log)
        log.info(f'Model: "{self.model}"')
        log_line(log)
        log.info(f'Corpus: "{self.corpus}"')
        log_line(log)
        log.info("Parameters:")
        log.info(f' - learning_rate: "{learning_rate}"')
        log.info(f' - mini_batch_size: "{mini_batch_size}"')
        log.info(f' - patience: "{patience}"')
        log.info(f' - anneal_factor: "{anneal_factor}"')
        log.info(f' - max_epochs: "{max_epochs}"')
        log.info(f' - shuffle: "{shuffle}"')
        log.info(f' - train_with_dev: "{train_with_dev}"')
        log.info(f' - batch_growth_annealing: "{batch_growth_annealing}"')
        log_line(log)
        log.info(f'Model training base path: "{base_path}"')
        log_line(log)
        log.info(f"Device: {flair.device}")
        log_line(log)
        log.info(f"Embeddings storage mode: {embeddings_storage_mode}")
        if isinstance(self.model, SequenceTagger) and self.model.weight_dict and self.model.use_crf:
            log_line(log)
            log.warning(f'WARNING: Specified class weights will not take effect when using CRF')

        # check for previously saved best models in the current training folder and delete them
        self.check_for_and_delete_previous_best_models(base_path)

        # determine what splits (train, dev, test) to evaluate and log
        log_train = True if monitor_train else False
        log_test = True if (not param_selection_mode and self.corpus.test and monitor_test) else False
        log_dev = False if train_with_dev or not self.corpus.dev else True
        log_train_part = True if (eval_on_train_fraction == "dev" or eval_on_train_fraction > 0.0) else False

        if log_train_part:
            train_part_size = len(self.corpus.dev) if eval_on_train_fraction == "dev" \
                else int(len(self.corpus.train) * eval_on_train_fraction)

            assert train_part_size > 0
            if not eval_on_train_shuffle:
                train_part_indices = list(range(train_part_size))
                train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)

        # prepare loss logging file and set up header
        loss_txt = init_output_file(base_path, "loss.tsv") if create_loss_file else None

        weight_extractor = WeightExtractor(base_path)

        # if optimizer class is passed, instantiate:
        if inspect.isclass(optimizer):
            optimizer: torch.optim.Optimizer = optimizer(self.model.parameters(), lr=learning_rate, **kwargs)

        if use_swa:
            import torchcontrib
            optimizer = torchcontrib.optim.SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=learning_rate)

        if use_amp:
            self.model, optimizer = amp.initialize(
                self.model, optimizer, opt_level=amp_opt_level
            )

        # load existing optimizer state dictionary if it exists
        if optimizer_state_dict:
            optimizer.load_state_dict(optimizer_state_dict)

        # minimize training loss if training with dev data, else maximize dev score
        anneal_mode = "min" if train_with_dev or anneal_against_dev_loss else "max"
        best_validation_score = 100000000000 if train_with_dev or anneal_against_dev_loss else 0.

        dataset_size = len(self.corpus.train)
        if train_with_dev:
            dataset_size += len(self.corpus.dev)

        # if scheduler is passed as a class, instantiate
        if inspect.isclass(scheduler):
            if scheduler == OneCycleLR:
                scheduler = OneCycleLR(optimizer,
                                       max_lr=learning_rate,
                                       steps_per_epoch=dataset_size // mini_batch_size + 1,
                                       epochs=max_epochs - epoch,
                                       # if we load a checkpoint, we have already trained for epoch
                                       pct_start=0.0,
                                       cycle_momentum=cycle_momentum)
            elif scheduler == LinearSchedulerWithWarmup:
                steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size
                num_train_steps = int(steps_per_epoch * max_epochs)
                num_warmup_steps = int(num_train_steps * warmup_fraction)

                scheduler = LinearSchedulerWithWarmup(optimizer,
                                                      num_train_steps=num_train_steps,
                                                      num_warmup_steps=num_warmup_steps)
            else:
                scheduler = scheduler(
                    optimizer,
                    factor=anneal_factor,
                    patience=patience,
                    initial_extra_patience=initial_extra_patience,
                    mode=anneal_mode,
                    verbose=True,
                )

        # load existing scheduler state dictionary if it exists
        if scheduler_state_dict:
            scheduler.load_state_dict(scheduler_state_dict)

        # update optimizer and scheduler in model card
        model_card['training_parameters']['optimizer'] = optimizer
        model_card['training_parameters']['scheduler'] = scheduler

        if isinstance(scheduler, OneCycleLR) and batch_growth_annealing:
            raise ValueError("Batch growth with OneCycle policy is not implemented.")

        train_data = self.corpus.train

        # if training also uses dev/train data, include in training set
        if train_with_dev or train_with_test:

            parts = [self.corpus.train]
            if train_with_dev: parts.append(self.corpus.dev)
            if train_with_test: parts.append(self.corpus.test)

            train_data = ConcatDataset(parts)

        # initialize sampler if provided
        if sampler is not None:
            # init with default values if only class is provided
            if inspect.isclass(sampler):
                sampler = sampler()
            # set dataset to sample from
            sampler.set_dataset(train_data)
            shuffle = False

        dev_score_history = []
        dev_loss_history = []
        train_loss_history = []

        micro_batch_size = mini_batch_chunk_size

        # At any point you can hit Ctrl + C to break out of training early.
        try:
            previous_learning_rate = learning_rate
            momentum = 0
            for group in optimizer.param_groups:
                if "momentum" in group:
                    momentum = group["momentum"]

            for epoch in range(epoch + 1, max_epochs + 1):
                log_line(log)

                # update epoch in model card
                self.model.model_card['training_parameters']['epoch'] = epoch

                if anneal_with_prestarts:
                    last_epoch_model_state_dict = copy.deepcopy(self.model.state_dict())

                # Re-sample the train-evaluation subset at the start of each
                # epoch when eval_on_train_shuffle is enabled.
                if eval_on_train_shuffle:
                    # BUG FIX: range() needs the dataset *length*, not the
                    # dataset object itself — range(self.corpus.train) raises
                    # TypeError. Matches train_part_size, which is computed
                    # from len(self.corpus.train) above.
                    train_part_indices = list(range(len(self.corpus.train)))
                    random.shuffle(train_part_indices)
                    train_part_indices = train_part_indices[:train_part_size]
                    train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)

                # get new learning rate
                for group in optimizer.param_groups:
                    learning_rate = group["lr"]

                if learning_rate != previous_learning_rate and batch_growth_annealing:
                    mini_batch_size *= 2

                # reload last best model if annealing with restarts is enabled
                if (
                        (anneal_with_restarts or anneal_with_prestarts)
                        and learning_rate != previous_learning_rate
                        and os.path.exists(base_path / "best-model.pt")
                ):
                    if anneal_with_restarts:
                        log.info("resetting to best model")
                        self.model.load_state_dict(
                            self.model.load(base_path / "best-model.pt").state_dict()
                        )
                    if anneal_with_prestarts:
                        log.info("resetting to pre-best model")
                        self.model.load_state_dict(
                            self.model.load(base_path / "pre-best-model.pt").state_dict()
                        )

                previous_learning_rate = learning_rate
                if use_tensorboard:
                    writer.add_scalar("learning_rate", learning_rate, epoch)

                # stop training if learning rate becomes too small
                if ((not isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)) and
                     learning_rate < min_learning_rate)):
                    log_line(log)
                    log.info("learning rate too small - quitting training!")
                    log_line(log)
                    break

                batch_loader = DataLoader(
                    train_data,
                    batch_size=mini_batch_size,
                    shuffle=shuffle if epoch > 1 else False,  # never shuffle the first epoch
                    num_workers=num_workers,
                    sampler=sampler,
                )

                self.model.train()

                train_loss: float = 0

                seen_batches = 0
                total_number_of_batches = len(batch_loader)

                modulo = max(1, int(total_number_of_batches / 10))

                # process mini-batches
                batch_time = 0
                average_over = 0
                for batch_no, batch in enumerate(batch_loader):

                    start_time = time.time()

                    # zero the gradients on the model and optimizer
                    self.model.zero_grad()
                    optimizer.zero_grad()

                    # if necessary, make batch_steps
                    batch_steps = [batch]
                    if len(batch) > micro_batch_size:
                        batch_steps = [batch[x: x + micro_batch_size] for x in range(0, len(batch), micro_batch_size)]

                    # forward and backward for batch
                    for batch_step in batch_steps:

                        # forward pass
                        loss = self.model.forward_loss(batch_step)

                        if isinstance(loss, Tuple):
                            average_over += loss[1]
                            loss = loss[0]

                        # Backward
                        if use_amp:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()
                        train_loss += loss.item()

                    # do the optimizer step
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                    optimizer.step()

                    # do the scheduler step if one-cycle or linear decay
                    if isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)):
                        scheduler.step()
                        # get new learning rate
                        for group in optimizer.param_groups:
                            learning_rate = group["lr"]
                            if "momentum" in group:
                                momentum = group["momentum"]
                            if "betas" in group:
                                momentum, _ = group["betas"]

                    seen_batches += 1

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(batch, embeddings_storage_mode)

                    batch_time += time.time() - start_time
                    if seen_batches % modulo == 0:
                        momentum_info = f' - momentum: {momentum:.4f}' if cycle_momentum else ''
                        intermittent_loss = train_loss / average_over if average_over > 0 else train_loss / seen_batches
                        log.info(
                            f"epoch {epoch} - iter {seen_batches}/{total_number_of_batches} - loss "
                            f"{intermittent_loss:.8f} - samples/sec: {mini_batch_size * modulo / batch_time:.2f}"
                            f" - lr: {learning_rate:.6f}{momentum_info}"
                        )
                        batch_time = 0
                        iteration = epoch * total_number_of_batches + batch_no
                        if not param_selection_mode and write_weights:
                            weight_extractor.extract_weights(self.model.state_dict(), iteration)

                if average_over != 0:
                    train_loss /= average_over

                self.model.eval()

                log_line(log)
                log.info(f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}")

                if use_tensorboard:
                    writer.add_scalar("train_loss", train_loss, epoch)

                # evaluate on train / dev / test split depending on training settings
                result_line: str = ""

                if log_train:
                    train_eval_result = self.model.evaluate(
                        self.corpus.train,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{train_eval_result.log_line}"

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(self.corpus.train, embeddings_storage_mode)

                if log_train_part:
                    train_part_eval_result = self.model.evaluate(
                        train_part,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{train_part_eval_result.loss}\t{train_part_eval_result.log_line}"

                    log.info(
                        f"TRAIN_SPLIT : loss {train_part_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(train_part_eval_result.main_score, 4)}"
                    )
                if use_tensorboard:
                    for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
                        writer.add_scalar(
                            f"train_{metric_class_avg_type}_{metric_type}",
                            train_part_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
                        )

                if log_dev:
                    dev_eval_result = self.model.evaluate(
                        self.corpus.dev,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        out_path=base_path / "dev.tsv",
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{dev_eval_result.loss}\t{dev_eval_result.log_line}"
                    log.info(
                        f"DEV : loss {dev_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]})  {round(dev_eval_result.main_score, 4)}"
                    )
                    # calculate scores using dev data if available
                    # append dev score to score history
                    dev_score_history.append(dev_eval_result.main_score)
                    dev_loss_history.append(dev_eval_result.loss)

                    dev_score = dev_eval_result.main_score

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(self.corpus.dev, embeddings_storage_mode)

                    if use_tensorboard:
                        writer.add_scalar("dev_loss", dev_eval_result.loss, epoch)
                        writer.add_scalar("dev_score", dev_eval_result.main_score, epoch)
                        for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
                            writer.add_scalar(
                                f"dev_{metric_class_avg_type}_{metric_type}",
                                dev_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
                            )

                if log_test:
                    test_eval_result = self.model.evaluate(
                        self.corpus.test,
                        gold_label_type=self.model.label_type,
                        mini_batch_size=mini_batch_chunk_size,
                        num_workers=num_workers,
                        out_path=base_path / "test.tsv",
                        embedding_storage_mode=embeddings_storage_mode,
                        main_evaluation_metric=main_evaluation_metric,
                        gold_label_dictionary=gold_label_dictionary_for_eval,
                    )
                    result_line += f"\t{test_eval_result.loss}\t{test_eval_result.log_line}"
                    log.info(
                        f"TEST : loss {test_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]})  {round(test_eval_result.main_score, 4)}"
                    )

                    # depending on memory mode, embeddings are moved to CPU, GPU or deleted
                    store_embeddings(self.corpus.test, embeddings_storage_mode)

                    if use_tensorboard:
                        writer.add_scalar("test_loss", test_eval_result.loss, epoch)
                        writer.add_scalar("test_score", test_eval_result.main_score, epoch)
                        for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
                            writer.add_scalar(
                                f"test_{metric_class_avg_type}_{metric_type}",
                                test_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
                            )

                # determine if this is the best model or if we need to anneal
                current_epoch_has_best_model_so_far = False
                # default mode: anneal against dev score
                if not train_with_dev and not anneal_against_dev_loss:
                    if dev_score > best_validation_score:
                        current_epoch_has_best_model_so_far = True
                        best_validation_score = dev_score

                    if isinstance(scheduler, AnnealOnPlateau):
                        scheduler.step(dev_score, dev_eval_result.loss)

                # alternative: anneal against dev loss
                if not train_with_dev and anneal_against_dev_loss:
                    if dev_eval_result.loss < best_validation_score:
                        current_epoch_has_best_model_so_far = True
                        best_validation_score = dev_eval_result.loss

                    if isinstance(scheduler, AnnealOnPlateau):
                        scheduler.step(dev_eval_result.loss)

                # alternative: anneal against train loss
                if train_with_dev:
                    if train_loss < best_validation_score:
                        current_epoch_has_best_model_so_far = True
                        best_validation_score = train_loss

                    if isinstance(scheduler, AnnealOnPlateau):
                        scheduler.step(train_loss)

                train_loss_history.append(train_loss)

                # determine bad epoch number
                try:
                    bad_epochs = scheduler.num_bad_epochs
                except:
                    bad_epochs = 0
                for group in optimizer.param_groups:
                    new_learning_rate = group["lr"]
                if new_learning_rate != previous_learning_rate:
                    bad_epochs = patience + 1
                    if previous_learning_rate == initial_learning_rate: bad_epochs += initial_extra_patience

                # log bad epochs
                log.info(f"BAD EPOCHS (no improvement): {bad_epochs}")

                if create_loss_file:
                    # output log file
                    with open(loss_txt, "a") as f:

                        # make headers on first epoch
                        if epoch == 1:
                            f.write(f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS")

                            if log_train:
                                f.write("\tTRAIN_" + "\tTRAIN_".join(train_eval_result.log_header.split("\t")))

                            if log_train_part:
                                f.write("\tTRAIN_PART_LOSS\tTRAIN_PART_" + "\tTRAIN_PART_".join(
                                    train_part_eval_result.log_header.split("\t")))

                            if log_dev:
                                f.write("\tDEV_LOSS\tDEV_" + "\tDEV_".join(dev_eval_result.log_header.split("\t")))

                            if log_test:
                                f.write("\tTEST_LOSS\tTEST_" + "\tTEST_".join(test_eval_result.log_header.split("\t")))

                        f.write(
                            f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
                        )
                        f.write(result_line)

                # if checkpoint is enabled, save model at each epoch
                if checkpoint and not param_selection_mode:
                    self.model.save(base_path / "checkpoint.pt", checkpoint=True)

                # Check whether to save best model
                if (
                        (not train_with_dev or anneal_with_restarts or anneal_with_prestarts)
                        and not param_selection_mode
                        and current_epoch_has_best_model_so_far
                        and not use_final_model_for_eval
                ):
                    log.info("saving best model")
                    self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)

                    if anneal_with_prestarts:
                        current_state_dict = self.model.state_dict()
                        self.model.load_state_dict(last_epoch_model_state_dict)
                        self.model.save(base_path / "pre-best-model.pt")
                        self.model.load_state_dict(current_state_dict)

                if save_model_each_k_epochs > 0 and not epoch % save_model_each_k_epochs:
                    print("saving model of current epoch")
                    model_name = "model_epoch_" + str(epoch) + ".pt"
                    self.model.save(base_path / model_name, checkpoint=save_optimizer_state)

            if use_swa:
                optimizer.swap_swa_sgd()

            # if we do not use dev data for model selection, save final model
            if save_final_model and not param_selection_mode:
                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)

        except KeyboardInterrupt:
            log_line(log)
            log.info("Exiting from training early.")

            if use_tensorboard:
                writer.close()

            if not param_selection_mode:
                log.info("Saving model ...")
                self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
                log.info("Done.")

        # test best model if test data is present
        if self.corpus.test and not train_with_test:
            final_score = self.final_test(
                base_path=base_path,
                eval_mini_batch_size=mini_batch_chunk_size,
                num_workers=num_workers,
                main_evaluation_metric=main_evaluation_metric,
                gold_label_dictionary_for_eval=gold_label_dictionary_for_eval,
            )
        else:
            final_score = 0
            log.info("Test data not provided setting final score to 0")

        if create_file_logs:
            log_handler.close()
            log.removeHandler(log_handler)

        if use_tensorboard:
            writer.close()

        return {
            "test_score": final_score,
            "dev_score_history": dev_score_history,
            "train_loss_history": train_loss_history,
            "dev_loss_history": dev_loss_history,
        }
Ejemplo n.º 4
0
class Learner:
    """Per-fold training driver: one model, one train/valid loader pair.

    Handles optional mixed-precision (fp16) training, per-batch OneCycleLR
    scheduling, best-model checkpointing and file/TensorBoard logging for a
    single (seed, fold) combination.
    """

    def __init__(self, model, train_loader, valid_loader, fold, config, seed):
        self.config = config
        self.seed = seed
        self.device = self.config.device
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model.to(self.device)

        self.fold = fold
        self.logger = init_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}.log')
        self.tb_logger = init_tb_logger(
            config.log_dir, f'train_seed{self.seed}_fold{self.fold}')
        # Dump the full config once (fold 0 only) to avoid duplicate logs.
        if self.fold == 0:
            self.log('\n'.join(
                f"{k} = {v}" for k, v in self.config.__dict__.items()))

        # Train with label smoothing, but score validation with plain BCE.
        self.criterion = SmoothBCEwLogits(smoothing=self.config.smoothing)
        self.evaluator = nn.BCEWithLogitsLoss()
        self.summary_loss = AverageMeter()
        self.history = {'train': [], 'valid': []}

        self.optimizer = Adam(self.model.parameters(),
                              lr=config.lr,
                              weight_decay=self.config.weight_decay)
        self.scheduler = OneCycleLR(optimizer=self.optimizer,
                                    pct_start=0.1,
                                    div_factor=1e3,
                                    max_lr=1e-2,
                                    epochs=config.n_epochs,
                                    steps_per_epoch=len(train_loader))
        self.scaler = GradScaler() if config.fp16 else None

        self.epoch = 0
        self.best_epoch = 0
        self.best_loss = np.inf

    def train_one_epoch(self):
        """Run one pass over the training loader; return the mean train loss."""
        self.model.train()
        self.summary_loss.reset()
        for g_x, c_x, cate_x, labels, non_labels in self.train_loader:
            self.optimizer.zero_grad()
            labels = labels.to(self.device)
            non_labels = non_labels.to(self.device)
            g_x = g_x.to(self.device)
            c_x = c_x.to(self.device)
            cate_x = cate_x.to(self.device)
            batch_size = labels.shape[0]

            # autocast is entered only in fp16 mode; ExitStack keeps the
            # forward pass code identical in both cases.
            with ExitStack() as stack:
                if self.config.fp16:
                    stack.enter_context(autocast())
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.criterion(outputs, labels)

            if self.config.fp16:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss.backward()
                self.optimizer.step()

            self.summary_loss.update(loss.item(), batch_size)
            # OneCycleLR steps per batch; plateau schedulers step per epoch.
            if self.scheduler.__class__.__name__ != 'ReduceLROnPlateau':
                self.scheduler.step()

        self.history['train'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def validation(self):
        """Evaluate on the validation loader with plain BCE; return mean loss."""
        self.model.eval()
        self.summary_loss.reset()
        for g_x, c_x, cate_x, labels, non_labels in self.valid_loader:
            with torch.no_grad():
                labels = labels.to(self.device)
                g_x = g_x.to(self.device)
                c_x = c_x.to(self.device)
                cate_x = cate_x.to(self.device)
                batch_size = labels.shape[0]
                outputs = self.model(g_x, c_x, cate_x)
                loss = self.evaluator(outputs, labels)

                self.summary_loss.update(loss.detach().item(), batch_size)

        self.history['valid'].append(self.summary_loss.avg)
        return self.summary_loss.avg

    def fit(self, epochs):
        """Train for *epochs* epochs, checkpointing on improved valid loss.

        Returns the accumulated train/valid loss history dict.
        """
        self.log('Start training....')
        for _ in range(epochs):
            train_loss = self.train_one_epoch()
            self.tb_logger.add_scalar('Train/Loss', train_loss, self.epoch)

            valid_loss = self.validation()
            self.tb_logger.add_scalar('Valid/Loss', valid_loss, self.epoch)
            self.post_processing(valid_loss)

            self.epoch += 1
        self.log(f'best epoch: {self.best_epoch}, best loss: {self.best_loss}')
        return self.history

    def post_processing(self, loss):
        """Persist a checkpoint if *loss* beats the best validation loss so far."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = self.epoch

            self.model.eval()
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_loss': self.best_loss,
                    'epoch': self.epoch,
                },
                os.path.join(
                    self.config.log_dir,
                    f"{self.config.name}_seed{self.seed}_fold{self.fold}.pth"))
            self.log(f'best model: {self.epoch} epoch - loss: {loss:.6f}')

    def load(self, path):
        """Restore model/optimizer/scheduler state from a checkpoint file."""
        checkpoint = torch.load(path,
                                map_location=lambda storage, loc: storage)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']
        self.epoch = checkpoint['epoch'] + 1

    def log(self, text):
        """Write *text* to this fold's file logger."""
        self.logger.info(text)
Ejemplo n.º 5
0
class Trainer():
    def __init__(self, config, pretrained=True, augmentor=ImgAugTransform()):
        """Wire up model, optimizer, loss, data generators and logging.

        NOTE(review): *pretrained* is never read; pretrained weights are
        controlled by config['trainer']['pretrained'] instead.  The mutable
        default *augmentor* is shared across instances — kept as-is for
        backward compatibility.
        """

        self.config = config
        self.model, self.vocab = build_model(config)

        self.device = config['device']
        self.num_iters = config['trainer']['iters']
        self.beamsearch = config['predictor']['beamsearch']

        # Dataset locations and annotation files.
        self.data_root = config['dataset']['data_root']
        self.train_annotation = config['dataset']['train_annotation']
        self.valid_annotation = config['dataset']['valid_annotation']
        self.train_lmdb = config['dataset']['train_lmdb']
        self.valid_lmdb = config['dataset']['valid_lmdb']
        self.dataset_name = config['dataset']['name']

        self.batch_size = config['trainer']['batch_size']
        self.print_every = config['trainer']['print_every']
        self.valid_every = config['trainer']['valid_every']

        self.image_aug = config['aug']['image_aug']
        self.masked_language_model = config['aug']['masked_language_model']
        self.metrics = config['trainer']['metrics']
        self.is_padding = config['dataset']['is_padding']

        self.tensorboard_dir = config['monitor']['log_dir']
        if not os.path.exists(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir, exist_ok=True)
        self.writer = SummaryWriter(self.tensorboard_dir)

        # LOGGER
        self.logger = Logger(config['monitor']['log_dir'])
        self.logger.info(config)

        self.iter = 0
        self.best_acc = 0
        self.scheduler = None
        self.is_finetuning = config['trainer']['is_finetuning']

        if self.is_finetuning:
            # Finetuning: fixed low learning rate, no LR scheduler.
            self.logger.info("Finetuning model ---->")
            if self.model.seq_modeling == 'crnn':
                self.optimizer = Adam(lr=0.0001,
                                      params=self.model.parameters(),
                                      betas=(0.5, 0.999))
            else:
                self.optimizer = AdamW(lr=0.0001,
                                       params=self.model.parameters(),
                                       betas=(0.9, 0.98),
                                       eps=1e-09)

        else:
            # Training from scratch: AdamW with a one-cycle LR schedule
            # stepped once per iteration in train().
            self.optimizer = AdamW(self.model.parameters(),
                                   betas=(0.9, 0.98),
                                   eps=1e-09)
            self.scheduler = OneCycleLR(self.optimizer,
                                        total_steps=self.num_iters,
                                        **config['optimizer'])

        if self.model.seq_modeling == 'crnn':
            # NOTE(review): CTCLoss's first positional argument is the blank
            # index; the vocab pad index is used as the blank here — confirm
            # that is intended.
            self.criterion = torch.nn.CTCLoss(self.vocab.pad,
                                              zero_infinity=True)
        else:
            self.criterion = LabelSmoothingLoss(len(self.vocab),
                                                padding_idx=self.vocab.pad,
                                                smoothing=0.1)

        # Pretrained model
        if config['trainer']['pretrained']:
            self.load_weights(config['trainer']['pretrained'])
            self.logger.info("Loaded trained model from: {}".format(
                config['trainer']['pretrained']))

        # Resume
        elif config['trainer']['resume_from']:
            self.load_checkpoint(config['trainer']['resume_from'])
            # Move any optimizer state tensors onto the training device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(torch.device(self.device))

            self.logger.info("Resume training from {}".format(
                config['trainer']['resume_from']))

        # DATASET
        transforms = None
        if self.image_aug:
            transforms = augmentor

        train_lmdb_paths = [
            os.path.join(self.data_root, lmdb_path)
            for lmdb_path in self.train_lmdb
        ]

        self.train_gen = self.data_gen(
            lmdb_paths=train_lmdb_paths,
            data_root=self.data_root,
            annotation=self.train_annotation,
            masked_language_model=self.masked_language_model,
            transform=transforms,
            is_train=True)

        if self.valid_annotation:
            self.valid_gen = self.data_gen(
                lmdb_paths=[os.path.join(self.data_root, self.valid_lmdb)],
                data_root=self.data_root,
                annotation=self.valid_annotation,
                masked_language_model=False)

        self.train_losses = []
        self.logger.info("Number batch samples of training: %d" %
                         len(self.train_gen))
        # Fix: self.valid_gen only exists when valid_annotation is set;
        # logging it unconditionally raised AttributeError otherwise.
        if self.valid_annotation:
            self.logger.info("Number batch samples of valid: %d" %
                             len(self.valid_gen))

        config_savepath = os.path.join(self.tensorboard_dir, "config.yml")
        if not os.path.exists(config_savepath):
            self.logger.info("Saving config file at: %s" % config_savepath)
            Cfg(config).save(config_savepath)

    def train(self):
        """Main optimization loop: self.num_iters iterations with periodic
        console logging (print_every) and validation/checkpointing
        (valid_every)."""
        total_loss = 0
        # Fix: initialise up front — the tensorboard logging in the
        # validation branch used this name, which was only assigned inside
        # the print_every branch, raising NameError whenever valid_every
        # fired before the first print_every.
        latest_loss = 0.0

        total_loader_time = 0
        total_gpu_time = 0
        data_iter = iter(self.train_gen)
        for _ in range(self.num_iters):
            self.iter += 1
            start = time.time()

            # Re-create the iterator when the dataloader is exhausted so
            # training can run for more iterations than one epoch.
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start
            start = time.time()

            # LOSS
            loss = self.step(batch)
            total_loss += loss
            self.train_losses.append((self.iter, loss))

            total_gpu_time += time.time() - start

            if self.iter % self.print_every == 0:
                latest_loss = total_loss / self.print_every
                info = 'Iter: {:06d} - Train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, latest_loss,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)
                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                self.logger.info(info)

            if self.valid_annotation and self.iter % self.valid_every == 0:
                val_time = time.time()
                val_loss = self.validate()
                acc_full_seq, acc_per_char, wer = self.precision(self.metrics)

                self.logger.info("Iter: {:06d}, start validating".format(
                    self.iter))
                info = 'Iter: {:06d} - Valid loss: {:.3f} - Acc full seq: {:.4f} - Acc per char: {:.4f} - WER: {:.4f} - Time: {:.4f}'.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, wer,
                    time.time() - val_time)
                self.logger.info(info)

                # Track the best full-sequence accuracy seen so far.
                if acc_full_seq > self.best_acc:
                    self.save_weights(self.tensorboard_dir + "/best.pt")
                    self.best_acc = acc_full_seq

                self.logger.info("Iter: {:06d} - Best acc: {:.4f}".format(
                    self.iter, self.best_acc))

                filename = 'last.pt'
                filepath = os.path.join(self.tensorboard_dir, filename)
                self.logger.info("Save checkpoint %s" % filename)
                self.save_checkpoint(filepath)

                log_loss = {'train loss': latest_loss, 'val loss': val_loss}
                self.writer.add_scalars('Loss', log_loss, self.iter)
                self.writer.add_scalar('WER', wer, self.iter)

    def validate(self):
        """Compute the mean criterion loss over the whole validation set.

        Leaves the model back in train mode on return.
        """
        self.model.eval()
        losses = []

        with torch.no_grad():
            for raw_batch in self.valid_gen:
                moved = self.batch_to_device(raw_batch)
                img = moved['img']
                tgt_input = moved['tgt_input']
                tgt_output = moved['tgt_output']
                tgt_padding_mask = moved['tgt_padding_mask']

                outputs = self.model(img, tgt_input, tgt_padding_mask)

                if self.model.seq_modeling == 'crnn':
                    # CTC needs per-sample prediction lengths alongside
                    # the label lengths.
                    lengths = moved['labels_len']
                    preds_size = torch.autograd.Variable(
                        torch.IntTensor([outputs.size(0)] * self.batch_size))
                    batch_loss = self.criterion(outputs, tgt_output,
                                                preds_size, lengths)
                else:
                    batch_loss = self.criterion(outputs.flatten(0, 1),
                                                tgt_output.flatten())

                losses.append(batch_loss.item())

                del outputs
                del batch_loss

        self.model.train()

        return np.mean(losses)

    def predict(self, sample=None):
        """Run inference over the validation generator.

        Returns five parallel lists: predicted sentences, ground-truth
        sentences, image filenames, per-sentence probabilities (None entries
        when beam search is used) and image tensors.  Stops early once more
        than *sample* predictions have been collected.
        """
        pred_sents = []
        actual_sents = []
        img_files = []
        probs_sents = []
        imgs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            actual_sent = self.vocab.batch_decode(batch['tgt_output'].tolist())
            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)

            imgs_sents.extend(batch['img'])
            img_files.extend(batch['filenames'])
            # Fix: beam search yields prob=None; extend(None) raised
            # TypeError.  Pad with Nones so probs_sents stays aligned
            # with pred_sents.
            if prob is not None:
                probs_sents.extend(prob)
            else:
                probs_sents.extend([None] * len(pred_sent))

            # Visualize the first batch in tensorboard
            if idx == 0:
                try:
                    num_samples = self.config['monitor']['num_samples']
                    fig = plt.figure(figsize=(12, 15))
                    imgs_samples = imgs_sents[:num_samples]
                    preds_samples = pred_sents[:num_samples]
                    actuals_samples = actual_sents[:num_samples]
                    probs_samples = probs_sents[:num_samples]
                    for id_img in range(len(imgs_samples)):
                        img = imgs_samples[id_img]
                        img = img.permute(1, 2, 0)
                        img = img.cpu().detach().numpy()
                        ax = fig.add_subplot(num_samples,
                                             1,
                                             id_img + 1,
                                             xticks=[],
                                             yticks=[])
                        plt.imshow(img)
                        ax.set_title(
                            "LB: {} \n Pred: {:.4f}-{}".format(
                                actuals_samples[id_img], probs_samples[id_img],
                                preds_samples[id_img]),
                            color=('green' if actuals_samples[id_img]
                                   == preds_samples[id_img] else 'red'),
                            fontdict={
                                'fontsize': 18,
                                'fontweight': 'medium'
                            })

                    self.writer.add_figure('predictions vs. actuals',
                                           fig,
                                           global_step=self.iter)
                except Exception as error:
                    # Visualization is best-effort; never abort inference.
                    print(error)
                    continue

            if sample is not None and len(pred_sents) > sample:
                break

        return pred_sents, actual_sents, img_files, probs_sents, imgs_sents

    def precision(self, sample=None, measure_time=True):
        """Score predictions on the validation set.

        Returns (full-sequence accuracy, per-character accuracy, WER);
        optionally prints the mean prediction time per sentence.
        """
        start = time.time()
        predictions, references, _, _, _ = self.predict(sample=sample)
        elapsed = time.time() - start

        case_sensitive = self.config['predictor']['sensitive_case']
        # One compute_accuracy call per metric mode.
        scores = {
            mode: compute_accuracy(references, predictions, case_sensitive,
                                   mode=mode)
            for mode in ('full_sequence', 'per_char', 'wer')
        }

        if measure_time:
            print("Time: {:.4f}".format(elapsed / len(references)))
        return scores['full_sequence'], scores['per_char'], scores['wer']

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16,
                             save_fig=False):
        """Plot up to *sample* validation predictions in a 5-column grid.

        With errorcase=True, only mispredicted items are shown; titles are
        colored green/red for correct/incorrect predictions.  Optionally
        saves the figure to 'vis_prediction.png'.

        NOTE(review): *fontsize* is currently unused by the plotting calls.
        """

        pred_sents, actual_sents, img_files, probs, imgs = self.predict(sample)

        if errorcase:
            # Keep only indices where the prediction differs from the label.
            wrongs = [i for i in range(len(img_files))
                      if pred_sents[i] != actual_sents[i]]

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]
            imgs = [imgs[i] for i in wrongs]

        img_files = img_files[:sample]

        ncols = 5
        nrows = int(math.ceil(len(img_files) / ncols))
        # Fix: squeeze=False keeps *ax* 2-D even when nrows == 1, so the
        # ax[row, col] indexing below cannot fail for small sample counts.
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 15), squeeze=False)

        for vis_idx in range(0, len(img_files)):
            row = vis_idx // ncols
            col = vis_idx % ncols

            pred_sent = pred_sents[vis_idx]
            actual_sent = actual_sents[vis_idx]
            prob = probs[vis_idx]
            img = imgs[vis_idx].permute(1, 2, 0).cpu().detach().numpy()

            ax[row, col].imshow(img)
            ax[row, col].set_title(
                "Pred: {: <2} \n Actual: {} \n prob: {:.2f}".format(
                    pred_sent, actual_sent, prob),
                fontname=fontname,
                color='r' if pred_sent != actual_sent else 'g')
            ax[row, col].get_xaxis().set_ticks([])
            ax[row, col].get_yaxis().set_ticks([])

        plt.subplots_adjust()
        if save_fig:
            fig.savefig('vis_prediction.png')
        plt.show()

    def log_prediction(self, sample=16, csv_file='model.csv'):
        """Predict ``sample`` items and dump the results to ``csv_file``."""
        predictions, ground_truths, file_names, _probs, _imgs = self.predict(
            sample)
        save_predictions(csv_file, predictions, ground_truths, file_names)

    def vis_data(self, sample=20):
        """Save a grid of up to ``sample`` training images with their labels
        to ``vis_dataset.png``.
        """
        ncols = 5
        nrows = int(math.ceil(sample / ncols))
        # squeeze=False keeps `ax` 2-D even for a single row, so the
        # ax[row, col] indexing below never breaks.
        fig, ax = plt.subplots(nrows, ncols, figsize=(12, 12), squeeze=False)

        num_plots = 0
        for idx, batch in enumerate(self.train_gen):
            for vis_idx in range(self.batch_size):
                row = num_plots // ncols
                col = num_plots % ncols

                # CHW tensor -> HWC numpy array for imshow.
                img = batch['img'][vis_idx].numpy().transpose(1, 2, 0)
                # tgt_input is (seq, batch); transpose picks this sample's row.
                sent = self.vocab.decode(
                    batch['tgt_input'].T[vis_idx].tolist())

                ax[row, col].imshow(img)
                ax[row, col].set_title("Label: {: <2}".format(sent),
                                       fontsize=16,
                                       color='g')

                ax[row, col].get_xaxis().set_ticks([])
                ax[row, col].get_yaxis().set_ticks([])

                num_plots += 1
                if num_plots >= sample:
                    plt.subplots_adjust()
                    fig.savefig('vis_dataset.png')
                    return

        # Previously the figure was only written when `sample` plots were
        # reached; also save when the dataset is exhausted first.
        plt.subplots_adjust()
        fig.savefig('vis_dataset.png')

    def load_checkpoint(self, filename):
        """Restore model, optimizer, scheduler and bookkeeping state from a
        checkpoint written by :meth:`save_checkpoint`.

        Args:
            filename: path to the checkpoint file.
        """
        # map_location keeps loading working across devices (e.g. a GPU
        # checkpoint restored on CPU), matching load_weights() below.
        checkpoint = torch.load(filename,
                                map_location=torch.device(self.device))

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']
        self.train_losses = checkpoint['train_losses']
        # The scheduler entry may be None (see save_checkpoint).
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler'])

        self.best_acc = checkpoint['best_acc']

    def save_checkpoint(self, filename):
        """Write model, optimizer, scheduler and bookkeeping state to disk.

        Args:
            filename: destination path; parent directories are created
                as needed.
        """
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            # Scheduler is optional; store None so load_checkpoint can skip it.
            'scheduler':
                None if self.scheduler is None else self.scheduler.state_dict(),
            'best_acc': self.best_acc,
        }

        path, _ = os.path.split(filename)
        # os.makedirs('') raises FileNotFoundError when filename has no
        # directory component, so only create directories when one exists.
        if path:
            os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load model weights from either a full checkpoint or a bare
        state_dict file.

        Entries missing from the model or with mismatched shapes are
        reported and skipped (``strict=False``).

        Args:
            filename: path to a checkpoint or state_dict file.
        """
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))
        if self.is_checkpoint(state_dict):
            self.model.load_state_dict(state_dict['state_dict'])
        else:
            # Bare state_dict: drop incompatible entries before loading.
            for name, param in self.model.named_parameters():
                if name not in state_dict:
                    print('{} not found'.format(name))
                elif state_dict[name].shape != param.shape:
                    print('{} mismatching shape, required {} but found {}'.
                          format(name, param.shape, state_dict[name].shape))
                    del state_dict[name]
            self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict (no optimizer/scheduler state).

        Args:
            filename: destination path; parent directories are created
                as needed.
        """
        path, _ = os.path.split(filename)
        # os.makedirs('') raises FileNotFoundError when filename has no
        # directory component, so only create directories when one exists.
        if path:
            os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def is_checkpoint(self, checkpoint):
        """Return True when ``checkpoint`` is a full checkpoint dict
        (contains a ``'state_dict'`` entry) rather than a bare state_dict.
        """
        try:
            checkpoint['state_dict']
        # Narrow the bare ``except:``: a missing key raises KeyError, a
        # non-mapping (e.g. a tensor list) raises TypeError/IndexError.
        except (KeyError, TypeError, IndexError):
            return False
        else:
            return True

    def batch_to_device(self, batch):
        """Return a copy of ``batch`` with its tensor entries moved onto
        ``self.device`` (non-blocking); filenames and label lengths are
        passed through unchanged.
        """
        device = self.device
        moved = {
            key: batch[key].to(device, non_blocking=True)
            for key in ('img', 'tgt_input', 'tgt_output', 'tgt_padding_mask')
        }
        moved['filenames'] = batch['filenames']
        moved['labels_len'] = batch['labels_len']
        return moved

    def data_gen(self,
                 lmdb_paths,
                 data_root,
                 annotation,
                 masked_language_model=True,
                 transform=None,
                 is_train=False):
        """Build a DataLoader over one or more LMDB-backed OCR datasets.

        Args:
            lmdb_paths: iterable of LMDB database paths to load.
            data_root: image root directory passed to OCRDataset.
            annotation: annotation file path passed to OCRDataset.
            masked_language_model: forwarded to the Collator.
            transform: optional image transform.
            is_train: enables shuffling (only when no custom sampler is used).

        Returns:
            A configured ``torch.utils.data.DataLoader``.
        """
        datasets = []
        for lmdb_path in lmdb_paths:
            dataset = OCRDataset(
                lmdb_path=lmdb_path,
                root_dir=data_root,
                annotation_path=annotation,
                vocab=self.vocab,
                transform=transform,
                image_height=self.config['dataset']['image_height'],
                image_min_width=self.config['dataset']['image_min_width'],
                image_max_width=self.config['dataset']['image_max_width'],
                separate=self.config['dataset']['separate'],
                batch_size=self.batch_size,
                is_padding=self.is_padding)
            datasets.append(dataset)
        # Concatenate only when several LMDBs were actually given. Keying
        # this off len(datasets) (not self.train_lmdb) keeps the method
        # correct when it is used for validation paths as well.
        if len(datasets) > 1:
            dataset = torch.utils.data.ConcatDataset(datasets)
        else:
            dataset = datasets[0]

        if self.is_padding:
            sampler = None
        else:
            sampler = ClusterRandomSampler(dataset, self.batch_size, True)

        collate_fn = Collator(masked_language_model)

        # DataLoader forbids shuffle=True together with a custom sampler,
        # so only request shuffling when no sampler is installed.
        gen = DataLoader(dataset,
                         batch_size=self.batch_size,
                         sampler=sampler,
                         collate_fn=collate_fn,
                         shuffle=is_train and sampler is None,
                         drop_last=self.model.seq_modeling == 'crnn',
                         **self.config['dataloader'])

        return gen

    def step(self, batch):
        """Run one optimization step on ``batch`` and return the loss value.

        Handles both CTC-style training (seq_modeling == 'crnn') and the
        cross-entropy path used by the attention/transformer models.
        """
        self.model.train()

        batch = self.batch_to_device(batch)
        img, tgt_input, tgt_output, tgt_padding_mask = batch['img'], batch[
            'tgt_input'], batch['tgt_output'], batch['tgt_padding_mask']

        outputs = self.model(img,
                             tgt_input,
                             tgt_key_padding_mask=tgt_padding_mask)
        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))

        if self.model.seq_modeling == 'crnn':
            # CTC loss: every sample claims the full output length.
            # NOTE(review): assumes outputs is (T, B, C) and that the batch
            # is exactly self.batch_size items — confirm drop_last covers this.
            length = batch['labels_len']
            preds_size = torch.autograd.Variable(
                torch.IntTensor([outputs.size(0)] * self.batch_size))
            loss = self.criterion(outputs, tgt_output, preds_size, length)
        else:
            # Cross-entropy path: flatten (B, S, V) -> (B*S, V) and targets
            # (B, S) -> (B*S,) before computing the loss.
            outputs = outputs.view(
                -1, outputs.size(2))  # flatten(0, 1)    # B*S x N_class
            tgt_output = tgt_output.view(-1)  # flatten()    # B*S
            loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        # Clip gradients to max-norm 1 for training stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()

        # Per-iteration LR schedule is only advanced during pretraining.
        if not self.is_finetuning:
            self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def count_parameters(self, model):
        """Return the number of trainable (grad-requiring) scalars in ``model``."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    def gen_pseudo_labels(self, outfile=None):
        """Predict over the validation set and write pseudo-labels to
        ``outfile`` as ``path||||text||||prob`` lines.

        NOTE(review): ``outfile`` defaults to None, which would make the
        final ``open(outfile, ...)`` raise — confirm callers always pass
        a path.
        """
        pred_sents = []
        img_files = []
        probs_sents = []

        for idx, batch in enumerate(tqdm.tqdm(self.valid_gen)):
            batch = self.batch_to_device(batch)

            if self.model.seq_modeling != 'crnn':
                if self.beamsearch:
                    # NOTE(review): beam search yields no probabilities, so
                    # prob stays None and probs_sents.extend(prob) below
                    # would raise TypeError — confirm beamsearch is disabled
                    # when generating pseudo-labels.
                    translated_sentence = batch_translate_beam_search(
                        batch['img'], self.model)
                    prob = None
                else:
                    translated_sentence, prob = translate(
                        batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist())
            else:
                # CRNN path uses its own greedy CTC decoding.
                translated_sentence, prob = translate_crnn(
                    batch['img'], self.model)
                pred_sent = self.vocab.batch_decode(
                    translated_sentence.tolist(), crnn=True)

            pred_sents.extend(pred_sent)
            img_files.extend(batch['filenames'])
            probs_sents.extend(prob)
        # Sanity check: one prediction and one probability per image.
        assert len(pred_sents) == len(img_files) and len(img_files) == len(
            probs_sents)
        with open(outfile, 'w', encoding='utf-8') as f:
            for anno in zip(img_files, pred_sents, probs_sents):
                f.write('||||'.join([anno[0], anno[1],
                                     str(float(anno[2]))]) + '\n')
Ejemplo n.º 6
0
def train(args, training_features, model, tokenizer):
    """Train the model.

    Initializes wandb logging, optional apex fp16, the LR scheduler
    selected by ``args.scheduler``, wraps ``training_features`` in a
    ``Seq2seqDatasetForBert`` and runs the optimization loop with
    gradient accumulation, periodic wandb logging and checkpointing.

    Args:
        args: parsed training arguments (batch sizes, steps, paths, ...).
        training_features: pre-tokenized training examples.
        model: the seq2seq model to train (moved to ``args.device`` here).
        tokenizer: tokenizer supplying special-token ids and vocab size.
    """
    wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"),
               config=args,
               name=args.run_name)
    wandb.watch(model)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    else:
        amp = None

    # Model recovery: only the step counter is recovered; restoring full
    # model/optimizer state from disk is currently disabled.
    recover_step = utils.get_max_epoch_model(args.output_dir)
    checkpoint_state_dict = None

    model.to(args.device)
    model, optimizer = prepare_for_training(args,
                                            model,
                                            checkpoint_state_dict,
                                            amp=amp)

    # Effective per-node batch accounts for gradient accumulation and,
    # when CUDA is used, the number of GPUs on this node.
    if args.n_gpu == 0 or args.no_cuda:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.gradient_accumulation_steps
    else:
        per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps

    train_batch_size = per_node_train_batch_size * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    global_step = recover_step if recover_step else 0

    # -1 means "derive step count from the requested number of epochs".
    if args.num_training_steps == -1:
        args.num_training_steps = int(args.num_training_epochs *
                                      len(training_features) /
                                      train_batch_size)

    if args.warmup_portion:
        args.num_warmup_steps = args.warmup_portion * args.num_training_steps

    if args.scheduler == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.num_warmup_steps,
            num_training_steps=args.num_training_steps,
            last_epoch=-1)

    elif args.scheduler == "constant":
        scheduler = get_constant_schedule(optimizer, last_epoch=-1)

    elif args.scheduler == "1cycle":
        scheduler = OneCycleLR(optimizer,
                               max_lr=args.learning_rate,
                               total_steps=args.num_training_steps,
                               pct_start=args.warmup_portion,
                               anneal_strategy=args.anneal_strategy,
                               final_div_factor=1e4,
                               last_epoch=-1)

    else:
        # Raise instead of `assert False`: asserts are stripped under -O.
        raise ValueError(
            "Unsupported scheduler: {!r}".format(args.scheduler))

    if checkpoint_state_dict:
        scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])

    train_dataset = utils.Seq2seqDatasetForBert(
        features=training_features,
        max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length,
        vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id,
        sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob,
        keep_prob=args.keep_prob,
        # Offset skips instances already consumed before recovery.
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
    )

    # Log a few instances so tokenization problems are visible early.
    logger.info("Check dataset:")
    for i in range(5):
        source_ids, target_ids, pseudo_ids, num_source_tokens, num_target_tokens = train_dataset.__getitem__(
            i)
        logger.info("Instance-%d" % i)
        logger.info("Source tokens = %s" %
                    " ".join(tokenizer.convert_ids_to_tokens(source_ids)))
        logger.info("Target tokens = %s" %
                    " ".join(tokenizer.convert_ids_to_tokens(target_ids)))

    logger.info("Mode = %s" % str(model))

    # Train!
    logger.info("  ***** Running training *****  *")
    logger.info("  Num examples = %d", len(training_features))
    logger.info("  Num Epochs = %.2f",
                len(train_dataset) / len(training_features))
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Batch size per node = %d", per_node_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", args.num_training_steps)

    # Set only when a checkpoint is actually written; previously this name
    # was undefined at the final wandb.save when no save ever happened.
    save_path = None

    if args.num_training_steps <= global_step:
        logger.info(
            "Training is done. Please use a new dir or clean this dir!")
    else:
        # The training features are shuffled by the dataset itself, so a
        # sequential (non-shuffling) sampler is used here.
        train_sampler = SequentialSampler(train_dataset) \
            if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=per_node_train_batch_size //
            args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)

        train_iterator = tqdm.tqdm(train_dataloader,
                                   initial=global_step,
                                   desc="Iter (loss=X.XXX, lr=X.XXXXXXX)",
                                   disable=args.local_rank not in [-1, 0])

        model.train()
        model.zero_grad()

        tr_loss, logging_loss = 0.0, 0.0

        for step, batch in enumerate(train_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'source_ids': batch[0],
                'target_ids': batch[1],
                'pseudo_ids': batch[2],
                'num_source_tokens': batch[3],
                'num_target_tokens': batch[4]
            }
            loss = model(**inputs)
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training

            train_iterator.set_description(
                'Iter (loss=%5.3f) lr=%9.7f' %
                (loss.item(), scheduler.get_last_lr()[0]))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            logging_loss += loss.item()
            # Only step the optimizer once per accumulation window.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    wandb.log(
                        {
                            'lr': scheduler.get_last_lr()[0],
                            'loss': logging_loss / args.logging_steps
                        },
                        step=global_step)

                    logger.info(" Step [%d ~ %d]: %.2f",
                                global_step - args.logging_steps, global_step,
                                logging_loss)
                    logging_loss = 0.0

                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        (global_step % args.save_steps == 0 or global_step == args.num_training_steps):

                    save_path = os.path.join(args.output_dir,
                                             "ckpt-%d" % global_step)
                    os.makedirs(save_path, exist_ok=True)
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(save_path)

                    logger.info("Saving model checkpoint %d into %s",
                                global_step, save_path)

    # Upload only when at least one checkpoint was written; otherwise the
    # old code raised NameError on an undefined save_path.
    if save_path is not None:
        wandb.save(f'{save_path}/*')
Ejemplo n.º 7
0
class Trainer():
    """Trainer for a GRU/attention Seq2Seq spelling-correction model.

    Builds vocab, synthetic-noise data generators, model, optimizer
    (AdamW + OneCycleLR) and label-smoothed loss, then drives the
    iteration-based training loop with periodic validation, logging and
    checkpointing.
    """

    def __init__(self, alphabets_, list_ngram):
        # Vocabulary over the given alphabet; noise synthesizer produces
        # corrupted inputs for the correction task.
        self.vocab = Vocab(alphabets_)
        self.synthesizer = SynthesizeData(vocab_path="")
        self.list_ngrams_train, self.list_ngrams_valid = self.train_test_split(
            list_ngram, test_size=0.1)
        print("Loaded data!!!")
        print("Total training samples: ", len(self.list_ngrams_train))
        print("Total valid samples: ", len(self.list_ngrams_valid))

        # Input and output share the same vocabulary.
        INPUT_DIM = self.vocab.__len__()
        OUTPUT_DIM = self.vocab.__len__()

        # Global configuration constants (module-level).
        self.device = DEVICE
        self.num_iters = NUM_ITERS
        self.beamsearch = BEAM_SEARCH

        self.batch_size = BATCH_SIZE
        self.print_every = PRINT_PER_ITER
        self.valid_every = VALID_PER_ITER

        self.checkpoint = CHECKPOINT
        self.export_weights = EXPORT
        self.metrics = MAX_SAMPLE_VALID
        logger = LOG

        # NOTE(review): if LOG is falsy, self.logger is never set, but
        # train() calls self.logger.log unconditionally — confirm LOG is
        # always configured.
        if logger:
            self.logger = Logger(logger)

        self.iter = 0

        self.model = Seq2Seq(input_dim=INPUT_DIM,
                             output_dim=OUTPUT_DIM,
                             encoder_embbeded=ENC_EMB_DIM,
                             decoder_embedded=DEC_EMB_DIM,
                             encoder_hidden=ENC_HID_DIM,
                             decoder_hidden=DEC_HID_DIM,
                             encoder_dropout=ENC_DROPOUT,
                             decoder_dropout=DEC_DROPOUT)

        self.optimizer = AdamW(self.model.parameters(),
                               betas=(0.9, 0.98),
                               eps=1e-09)
        self.scheduler = OneCycleLR(self.optimizer,
                                    total_steps=self.num_iters,
                                    pct_start=PCT_START,
                                    max_lr=MAX_LR)

        self.criterion = LabelSmoothingLoss(len(self.vocab),
                                            padding_idx=self.vocab.pad,
                                            smoothing=0.1)

        self.train_gen = self.data_gen(self.list_ngrams_train,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=True)
        self.valid_gen = self.data_gen(self.list_ngrams_valid,
                                       self.synthesizer,
                                       self.vocab,
                                       is_train=False)

        self.train_losses = []

        # to device
        self.model.to(self.device)
        self.criterion.to(self.device)

    def train_test_split(self, list_phrases, test_size=0.1):
        """Split ``list_phrases`` sequentially into train/valid parts
        (no shuffling), with ``test_size`` fraction held out.
        """
        list_phrases = list_phrases
        train_idx = int(len(list_phrases) * (1 - test_size))
        list_phrases_train = list_phrases[:train_idx]
        list_phrases_valid = list_phrases[train_idx:]
        return list_phrases_train, list_phrases_valid

    def data_gen(self, list_ngrams_np, synthesizer, vocab, is_train=True):
        """Wrap the ngram list in an AutoCorrectDataset and return a
        DataLoader (shuffled only for training).
        """
        dataset = AutoCorrectDataset(list_ngrams_np,
                                     transform_noise=synthesizer,
                                     vocab=vocab,
                                     maxlen=MAXLEN)

        shuffle = True if is_train else False
        gen = DataLoader(dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=shuffle,
                         drop_last=False)

        return gen

    def step(self, batch):
        """Run one optimization step on ``batch`` and return the loss value."""
        self.model.train()

        batch = self.batch_to_device(batch)
        src, tgt = batch['src'], batch['tgt']
        src, tgt = src.transpose(1, 0), tgt.transpose(
            1, 0)  # batch x src_len -> src_len x batch

        outputs = self.model(
            src, tgt)  # src : src_len x B, outpus : B x tgt_len x vocab

        #        loss = self.criterion(rearrange(outputs, 'b t v -> (b t) v'), rearrange(tgt_output, 'b o -> (b o)'))
        outputs = outputs.view(-1, outputs.size(2))  # flatten(0, 1)

        tgt_output = tgt.transpose(0, 1).reshape(
            -1)  # flatten()   # tgt: tgt_len xB , need convert to B x tgt_len

        loss = self.criterion(outputs, tgt_output)

        self.optimizer.zero_grad()

        loss.backward()

        # Clip gradients to max-norm 1 for training stability.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.optimizer.step()
        self.scheduler.step()

        loss_item = loss.item()

        return loss_item

    def train(self):
        """Main training loop: iterate ``num_iters`` batches, printing
        losses every ``print_every`` iterations and validating (plus
        checkpointing / exporting best weights) every ``valid_every``.
        """
        print("Begin training from iter: ", self.iter)
        total_loss = 0

        total_loader_time = 0
        total_gpu_time = 0
        best_acc = -1

        data_iter = iter(self.train_gen)
        for i in range(self.num_iters):
            self.iter += 1

            start = time.time()

            # Restart the loader when an epoch is exhausted.
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(self.train_gen)
                batch = next(data_iter)

            total_loader_time += time.time() - start

            start = time.time()
            loss = self.step(batch)
            total_gpu_time += time.time() - start

            total_loss += loss
            self.train_losses.append((self.iter, loss))

            if self.iter % self.print_every == 0:
                info = 'iter: {:06d} - train loss: {:.3f} - lr: {:.2e} - load time: {:.2f} - gpu time: {:.2f}'.format(
                    self.iter, total_loss / self.print_every,
                    self.optimizer.param_groups[0]['lr'], total_loader_time,
                    total_gpu_time)

                total_loss = 0
                total_loader_time = 0
                total_gpu_time = 0
                print(info)
                self.logger.log(info)

            if self.iter % self.valid_every == 0:
                val_loss, preds, actuals, inp_sents = self.validate()
                acc_full_seq, acc_per_char, cer = self.precision(self.metrics)

                info = 'iter: {:06d} - valid loss: {:.3f} - acc full seq: {:.4f} - acc per char: {:.4f} - CER: {:.4f} '.format(
                    self.iter, val_loss, acc_full_seq, acc_per_char, cer)
                print(info)
                print("--- Sentence predict ---")
                for pred, inp, label in zip(preds, inp_sents, actuals):
                    infor_predict = 'Pred: {} - Inp: {} - Label: {}'.format(
                        pred, inp, label)
                    print(infor_predict)
                    self.logger.log(infor_predict)
                self.logger.log(info)

                # Export weights only on improvement; checkpoint every time.
                if acc_full_seq > best_acc:
                    self.save_weights(self.export_weights)
                    best_acc = acc_full_seq
                self.save_checkpoint(self.checkpoint)

    def validate(self):
        """Compute mean validation loss (teacher forcing off) and return it
        together with a few sample predictions.

        Returns:
            ``(mean_loss, preds[:3], actuals[:3], inp_sents[:3])``.
        """
        self.model.eval()

        total_loss = []
        # Cap validation work by the sample budget.
        max_step = self.metrics / self.batch_size
        with torch.no_grad():
            for step, batch in enumerate(self.valid_gen):
                batch = self.batch_to_device(batch)
                src, tgt = batch['src'], batch['tgt']
                src, tgt = src.transpose(1, 0), tgt.transpose(1, 0)

                outputs = self.model(src, tgt, 0)  # turn off teaching force

                outputs = outputs.flatten(0, 1)
                tgt_output = tgt.flatten()
                loss = self.criterion(outputs, tgt_output)

                total_loss.append(loss.item())

                # NOTE(review): predict(5) re-runs inference over valid_gen
                # on every validation batch — looks expensive, and preds is
                # unbound if valid_gen is empty; confirm intended.
                preds, actuals, inp_sents, probs = self.predict(5)

                del outputs
                del loss
                if step > max_step:
                    break

        total_loss = np.mean(total_loss)
        self.model.train()

        return total_loss, preds[:3], actuals[:3], inp_sents[:3]

    def predict(self, sample=None):
        """Decode the validation set and return predicted, actual and input
        sentences (plus the last batch's probabilities).

        Args:
            sample: optional cap on the number of decoded sentences.
        """
        pred_sents = []
        actual_sents = []
        inp_sents = []

        for batch in self.valid_gen:
            batch = self.batch_to_device(batch)

            if self.beamsearch:
                # NOTE(review): beam search yields no probabilities; prob
                # stays None in the returned tuple.
                translated_sentence = batch_translate_beam_search(
                    batch['src'], self.model)
                prob = None
            else:
                translated_sentence, prob = translate(batch['src'], self.model)

            pred_sent = self.vocab.batch_decode(translated_sentence.tolist())
            actual_sent = self.vocab.batch_decode(batch['tgt'].tolist())
            inp_sent = self.vocab.batch_decode(batch['src'].tolist())

            pred_sents.extend(pred_sent)
            actual_sents.extend(actual_sent)
            inp_sents.extend(inp_sent)

            if sample is not None and len(pred_sents) > sample:
                break

        # NOTE: prob holds only the final batch's probabilities.
        return pred_sents, actual_sents, inp_sents, prob

    def precision(self, sample=None):
        """Return (full-sequence accuracy, per-char accuracy, CER) over up
        to ``sample`` validation sentences.
        """
        pred_sents, actual_sents, _, _ = self.predict(sample=sample)

        acc_full_seq = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='full_sequence')
        acc_per_char = compute_accuracy(actual_sents,
                                        pred_sents,
                                        mode='per_char')
        cer = compute_accuracy(actual_sents, pred_sents, mode='CER')

        return acc_full_seq, acc_per_char, cer

    def visualize_prediction(self,
                             sample=16,
                             errorcase=False,
                             fontname='serif',
                             fontsize=16):
        """Collect (and optionally filter to error cases) prediction data
        for visualization.

        NOTE(review): appears unfinished — it prepares the data and a
        fontdict but never plots anything. Also, predict() here returns
        input sentences, not image files.
        """
        pred_sents, actual_sents, img_files, probs = self.predict(sample)

        if errorcase:
            wrongs = []
            for i in range(len(img_files)):
                if pred_sents[i] != actual_sents[i]:
                    wrongs.append(i)

            pred_sents = [pred_sents[i] for i in wrongs]
            actual_sents = [actual_sents[i] for i in wrongs]
            img_files = [img_files[i] for i in wrongs]
            probs = [probs[i] for i in wrongs]

        img_files = img_files[:sample]

        fontdict = {'family': fontname, 'size': fontsize}

    def visualize_dataset(self, sample=16, fontname='serif'):
        """Iterate over up to ``sample`` training items.

        NOTE(review): appears unfinished — nothing is displayed or saved,
        and the batches from data_gen carry 'src'/'tgt' keys, not
        'img'/'tgt_input' as read here; confirm before use.
        """
        n = 0
        for batch in self.train_gen:
            for i in range(self.batch_size):
                img = batch['img'][i].numpy().transpose(1, 2, 0)
                sent = self.vocab.decode(batch['tgt_input'].T[i].tolist())

                n += 1
                if n >= sample:
                    return

    def load_checkpoint(self, filename):
        """Restore model, optimizer, scheduler and iteration state from a
        checkpoint written by :meth:`save_checkpoint`.
        """
        checkpoint = torch.load(filename)

        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.model.load_state_dict(checkpoint['state_dict'])
        self.iter = checkpoint['iter']

        self.train_losses = checkpoint['train_losses']

    def save_checkpoint(self, filename):
        """Write full training state (model, optimizer, scheduler, losses,
        iteration counter) to ``filename``.
        """
        state = {
            'iter': self.iter,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'scheduler': self.scheduler.state_dict()
        }

        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(state, filename)

    def load_weights(self, filename):
        """Load a bare state_dict, skipping (and reporting) entries that
        are missing or shape-incompatible with the current model.
        """
        state_dict = torch.load(filename,
                                map_location=torch.device(self.device))

        for name, param in self.model.named_parameters():
            if name not in state_dict:
                print('{} not found'.format(name))
            elif state_dict[name].shape != param.shape:
                print('{} missmatching shape, required {} but found {}'.format(
                    name, param.shape, state_dict[name].shape))
                del state_dict[name]

        self.model.load_state_dict(state_dict, strict=False)

    def save_weights(self, filename):
        """Save only the model's state_dict to ``filename``."""
        path, _ = os.path.split(filename)
        os.makedirs(path, exist_ok=True)

        torch.save(self.model.state_dict(), filename)

    def batch_to_device(self, batch):
        """Move the 'src' and 'tgt' tensors of ``batch`` onto the training
        device (non-blocking).
        """
        src = batch['src'].to(self.device, non_blocking=True)
        tgt = batch['tgt'].to(self.device, non_blocking=True)

        batch = {'src': src, 'tgt': tgt}

        return batch
Ejemplo n.º 8
0
def main():
    """Fine-tune an MLP classifier head on top of a frozen pretrained backbone.

    Visible flow: parse args; build a pretrained timm model (classifier
    stripped) plus a small MLP head; optionally reorganise a raw caltech101
    download into train/val/test splits; build timm data loaders with
    optional mixup/cutmix; train with Adam + OneCycleLR, checkpointing at
    the start of every epoch and appending per-epoch metrics to a CSV file.

    NOTE(review): this chunk ends at ``epoch += 1``; anything after the
    training loop (final save, summary, etc.) is outside the visible source.
    """
    setup_default_logging()
    args, args_text = _parse_args()

    # Single-process run: distributed training is explicitly disabled.
    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank

    # NOTE(review): the device actually used below is `device`; the
    # `args.device = 'cuda:0'` above is never reconciled with it.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Log the hyper-parameters that identify this run.
    _logger.info('====================\n\n'
                 'Actfun: {}\n'
                 'LR: {}\n'
                 'Epochs: {}\n'
                 'p: {}\n'
                 'k: {}\n'
                 'g: {}\n'
                 'Extra channel multiplier: {}\n'
                 'Weight Init: {}\n'
                 '\n===================='.format(args.actfun, args.lr,
                                                 args.epochs, args.p, args.k,
                                                 args.g,
                                                 args.extra_channel_mult,
                                                 args.weight_init))

    # ================================================================================= Loading models
    # Pretrained backbone. Note it is always built with actfun='swish'
    # regardless of args.actfun — only the MLP head below uses args.actfun.
    pre_model = create_model(
        args.model,
        pretrained=True,
        actfun='swish',
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        p=args.p,
        k=args.k,
        g=args.g,
        extra_channel_mult=args.extra_channel_mult,
        weight_init_name=args.weight_init,
        partial_ho_actfun=args.partial_ho_actfun)
    # Strip the final child (the classifier) so the backbone acts as a
    # feature extractor feeding the MLP head.
    pre_model_layers = list(pre_model.children())
    pre_model = torch.nn.Sequential(*pre_model_layers[:-1])
    pre_model.to(device)

    # MLP head. input_dim=1280 is hard-coded to the backbone's feature
    # width — presumably an EfficientNet-style model; TODO confirm it holds
    # for every args.model value.
    model = MLP.MLP(actfun=args.actfun,
                    input_dim=1280,
                    output_dim=args.num_classes,
                    k=args.k,
                    p=args.p,
                    g=args.g,
                    num_params=1_000_000,
                    permute_type='shuffle')
    model.to(device)

    # ================================================================================= Loading dataset
    util.seed_all(args.seed)
    # One-time reorganisation of the raw caltech101 download into
    # train/val/test folders; runs only when the target folder is absent.
    if args.data == 'caltech101' and not os.path.exists('caltech101'):
        dir_root = r'101_ObjectCategories'
        dir_new = r'caltech101'
        dir_new_train = os.path.join(dir_new, 'train')
        dir_new_val = os.path.join(dir_new, 'val')
        dir_new_test = os.path.join(dir_new, 'test')
        if not os.path.exists(dir_new):
            os.mkdir(dir_new)
            os.mkdir(dir_new_train)
            os.mkdir(dir_new_val)
            os.mkdir(dir_new_test)

        for dir2 in os.listdir(dir_root):
            # Skip the caltech101 background/clutter category.
            if dir2 != 'BACKGROUND_Google':
                curr_path = os.path.join(dir_root, dir2)
                new_path_train = os.path.join(dir_new_train, dir2)
                new_path_val = os.path.join(dir_new_val, dir2)
                new_path_test = os.path.join(dir_new_test, dir2)
                if not os.path.exists(new_path_train):
                    os.mkdir(new_path_train)
                if not os.path.exists(new_path_val):
                    os.mkdir(new_path_val)
                if not os.path.exists(new_path_test):
                    os.mkdir(new_path_test)

                # 80/10/10 split per class, taken in os.listdir order.
                # NOTE(review): listdir order is not shuffled (and not
                # guaranteed sorted), so the split is order-dependent.
                train_upper = int(0.8 * len(os.listdir(curr_path)))
                val_upper = int(0.9 * len(os.listdir(curr_path)))
                curr_files_all = os.listdir(curr_path)
                curr_files_train = curr_files_all[:train_upper]
                curr_files_val = curr_files_all[train_upper:val_upper]
                curr_files_test = curr_files_all[val_upper:]

                for file in curr_files_train:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_train, file))
                for file in curr_files_val:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_val, file))
                for file in curr_files_test:
                    copyfile(os.path.join(curr_path, file),
                             os.path.join(new_path_test, file))
    # Presumably lets the just-copied files settle before the dataset scan
    # below — TODO confirm this sleep is actually needed.
    time.sleep(5)

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error(
            'Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    # Accept either 'val' or 'validation' as the eval folder name.
    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error(
                'Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(mixup_alpha=args.mixup,
                          cutmix_alpha=args.cutmix,
                          cutmix_minmax=args.cutmix_minmax,
                          prob=args.mixup_prob,
                          switch_prob=args.mixup_switch_prob,
                          mode=args.mixup_mode,
                          label_smoothing=args.smoothing,
                          num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # ================================================================================= Optimizer / scheduler
    # NOTE(review): mixup_fn is never applied in the training loop below,
    # so non-prefetcher mixup has no effect here — verify intent.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
    # OneCycleLR is stepped once per batch below, matching steps_per_epoch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=args.lr,
        epochs=args.epochs,
        steps_per_epoch=int(math.floor(len(dataset_train) / args.batch_size)),
        cycle_momentum=False)

    # ================================================================================= Save file / checkpoints
    fieldnames = [
        'dataset', 'seed', 'epoch', 'time', 'actfun', 'model', 'batch_size',
        'alpha_primes', 'alphas', 'num_params', 'k', 'p', 'g', 'perm_method',
        'gen_gap', 'epoch_train_loss', 'epoch_train_acc',
        'epoch_aug_train_loss', 'epoch_aug_train_acc', 'epoch_val_loss',
        'epoch_val_acc', 'curr_lr', 'found_lr', 'epochs'
    ]
    # Output/checkpoint names are keyed on date, actfun, dataset and seed.
    filename = 'out_{}_{}_{}_{}'.format(datetime.date.today(), args.actfun,
                                        args.data, args.seed)
    outfile_path = os.path.join(args.output, filename) + '.csv'
    checkpoint_path = os.path.join(args.check_path, filename) + '.pth'
    # Write the CSV header only when starting a fresh output file.
    if not os.path.exists(outfile_path):
        with open(outfile_path, mode='w') as out_file:
            writer = csv.DictWriter(out_file,
                                    fieldnames=fieldnames,
                                    lineterminator='\n')
            writer.writeheader()

    # Resume from an existing checkpoint when one matches this run's name.
    epoch = 1
    checkpoint = torch.load(checkpoint_path) if os.path.exists(
        checkpoint_path) else None
    if checkpoint is not None:
        pre_model.load_state_dict(checkpoint['pre_model_state_dict'])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        pre_model.to(device)
        model.to(device)
        print("*** LOADED CHECKPOINT ***"
              "\n{}"
              "\nSeed: {}"
              "\nEpoch: {}"
              "\nActfun: {}"
              "\np: {}"
              "\nk: {}"
              "\ng: {}"
              "\nperm_method: {}".format(checkpoint_path,
                                         checkpoint['curr_seed'],
                                         checkpoint['epoch'],
                                         checkpoint['actfun'], checkpoint['p'],
                                         checkpoint['k'], checkpoint['g'],
                                         checkpoint['perm_method']))

    # Optional NVIDIA apex mixed precision (opt_level O2).
    args.mix_pre_apex = False
    if args.control_amp == 'apex':
        args.mix_pre_apex = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # ================================================================================= Training
    while epoch <= args.epochs:

        # Checkpoint at the *start* of each epoch, so a resume re-runs the
        # epoch that was in progress.
        if args.check_path != '':
            torch.save(
                {
                    'pre_model_state_dict': pre_model.state_dict(),
                    'model_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_seed': args.seed,
                    'epoch': epoch,
                    'actfun': args.actfun,
                    'p': args.p,
                    'k': args.k,
                    'g': args.g,
                    'perm_method': 'shuffle'
                }, checkpoint_path)

        # Re-seed per epoch so augmentation randomness is reproducible even
        # across resumes.
        util.seed_all((args.seed * args.epochs) + epoch)
        start_time = time.time()
        args.mix_pre = False
        if args.control_amp == 'native':
            args.mix_pre = True
            scaler = torch.cuda.amp.GradScaler()

        # ---- Training
        # Three code paths: native AMP, apex AMP, or plain fp32. In every
        # path the backbone runs under no_grad (frozen features) and only
        # the MLP head is optimized.
        model.train()
        total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
        for batch_idx, (x, targetx) in enumerate(loader_train):
            x, targetx = x.to(device), targetx.to(device)
            optimizer.zero_grad()
            if args.mix_pre:
                with torch.cuda.amp.autocast():
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                # NOTE(review): accumulating the loss *tensor* (not
                # .item()) retains each batch's autograd graph for the
                # whole epoch — a memory concern worth confirming.
                total_train_loss += train_loss
                n += 1
                scaler.scale(train_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            elif args.mix_pre_apex:
                with torch.no_grad():
                    x = pre_model(x)
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                with amp.scale_loss(train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                with torch.no_grad():
                    x = pre_model(x)
                output = model(x)
                train_loss = criterion(output, targetx)
                total_train_loss += train_loss
                n += 1
                train_loss.backward()
                optimizer.step()
            # OneCycleLR advances once per batch (see steps_per_epoch above).
            scheduler.step()
            _, prediction = torch.max(output.data, 1)
            num_correct += torch.sum(prediction == targetx.data)
            num_total += len(prediction)
        epoch_aug_train_loss = total_train_loss / n
        epoch_aug_train_acc = num_correct * 1.0 / num_total

        # For 'combinact' activations, record the learned mixture weights.
        alpha_primes = []
        alphas = []
        if model.actfun == 'combinact':
            for i, layer_alpha_primes in enumerate(model.all_alpha_primes):
                curr_alpha_primes = torch.mean(layer_alpha_primes, dim=0)
                curr_alphas = F.softmax(curr_alpha_primes, dim=0).data.tolist()
                curr_alpha_primes = curr_alpha_primes.tolist()
                alpha_primes.append(curr_alpha_primes)
                alphas.append(curr_alphas)

        # ---- Validation pass (no gradients).
        model.eval()
        with torch.no_grad():
            total_val_loss, n, num_correct, num_total = 0, 0, 0, 0
            for batch_idx, (y, targety) in enumerate(loader_eval):
                y, targety = y.to(device), targety.to(device)
                with torch.no_grad():
                    y = pre_model(y)
                output = model(y)
                val_loss = criterion(output, targety)
                total_val_loss += val_loss
                n += 1
                _, prediction = torch.max(output.data, 1)
                num_correct += torch.sum(prediction == targety.data)
                num_total += len(prediction)
            epoch_val_loss = total_val_loss / n
            epoch_val_acc = num_correct * 1.0 / num_total
        # Report the learning rate of the (single) param group.
        lr_curr = 0
        for param_group in optimizer.param_groups:
            lr_curr = param_group['lr']
        print(
            "    Epoch {}: LR {:1.5f} ||| aug_train_acc {:1.4f} | val_acc {:1.4f} ||| "
            "aug_train_loss {:1.4f} | val_loss {:1.4f} ||| time = {:1.4f}".
            format(epoch, lr_curr, epoch_aug_train_acc, epoch_val_acc,
                   epoch_aug_train_loss, epoch_val_loss,
                   (time.time() - start_time)),
            flush=True)

        # Final-epoch extra evaluation; until then these stay 0 and the CSV
        # 'gen_gap' column is effectively epoch_val_loss.
        epoch_train_loss = 0
        epoch_train_acc = 0
        if epoch == args.epochs:
            with torch.no_grad():
                # Re-score the training set in eval mode.
                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loader_train):
                    x, targetx = x.to(device), targetx.to(device)
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                epoch_aug_train_loss = total_train_loss / n
                epoch_aug_train_acc = num_correct * 1.0 / num_total

                # NOTE(review): this loop iterates loader_eval yet its
                # results are stored as epoch_train_* — looks swapped with
                # the loop above; confirm intended semantics.
                total_train_loss, n, num_correct, num_total = 0, 0, 0, 0
                for batch_idx, (x, targetx) in enumerate(loader_eval):
                    x, targetx = x.to(device), targetx.to(device)
                    with torch.no_grad():
                        x = pre_model(x)
                    output = model(x)
                    train_loss = criterion(output, targetx)
                    total_train_loss += train_loss
                    n += 1
                    _, prediction = torch.max(output.data, 1)
                    num_correct += torch.sum(prediction == targetx.data)
                    num_total += len(prediction)
                # NOTE(review): divides the *stale* total_val_loss left over
                # from the validation block instead of the total_train_loss
                # accumulated just above — almost certainly a bug.
                epoch_train_loss = total_val_loss / n
                epoch_train_acc = num_correct * 1.0 / num_total

        # Outputting data to CSV at end of epoch
        with open(outfile_path, mode='a') as out_file:
            writer = csv.DictWriter(out_file,
                                    fieldnames=fieldnames,
                                    lineterminator='\n')
            writer.writerow({
                'dataset': args.data,
                'seed': args.seed,
                'epoch': epoch,
                'time': (time.time() - start_time),
                'actfun': model.actfun,
                'model': args.model,
                'batch_size': args.batch_size,
                'alpha_primes': alpha_primes,
                'alphas': alphas,
                'num_params': util.get_model_params(model),
                'k': args.k,
                'p': args.p,
                'g': args.g,
                'perm_method': 'shuffle',
                'gen_gap': float(epoch_val_loss - epoch_train_loss),
                'epoch_train_loss': float(epoch_train_loss),
                'epoch_train_acc': float(epoch_train_acc),
                'epoch_aug_train_loss': float(epoch_aug_train_loss),
                'epoch_aug_train_acc': float(epoch_aug_train_acc),
                'epoch_val_loss': float(epoch_val_loss),
                'epoch_val_acc': float(epoch_val_acc),
                'curr_lr': lr_curr,
                'found_lr': args.lr,
                'epochs': args.epochs
            })

        epoch += 1