Exemplo n.º 1
0
def init_optim(
    model: type[BaseModelClass],
    optim: type[Optimizer] | Literal["SGD", "Adam", "AdamW"],
    learning_rate: float,
    weight_decay: float,
    momentum: float,
    device: type[torch.device] | Literal["cuda", "cpu"],
    milestones: Iterable = (),
    gamma: float = 0.3,
    resume: str = None,
    **kwargs,
) -> tuple[Optimizer, _LRScheduler]:
    """Initialize Optimizer and Scheduler.

    Args:
        model (type[BaseModelClass]): Model to be optimized.
        optim (type[Optimizer] | "SGD" | "Adam" | "AdamW"): Which optimizer to use
        learning_rate (float): Learning rate for optimization
        weight_decay (float): Weight decay for optimizer
        momentum (float): Momentum for optimizer
        device (type[torch.device] | "cuda" | "cpu"): Device the model will run on
        milestones (Iterable, optional): When to decay learning rate. Defaults to ().
        gamma (float, optional): Multiplier for learning rate decay. Defaults to 0.3.
        resume (str, optional): Path to model checkpoint to resume. Defaults to None.


    Returns:
        tuple[Optimizer, _LRScheduler]: Optimizer and scheduler for given model
    """
    # Select Optimiser
    if optim == "SGD":
        optimizer = SGD(
            model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            momentum=momentum,
        )
    elif optim == "Adam":
        optimizer = Adam(model.parameters(),
                         lr=learning_rate,
                         weight_decay=weight_decay)
    elif optim == "AdamW":
        optimizer = AdamW(model.parameters(),
                          lr=learning_rate,
                          weight_decay=weight_decay)
    else:
        raise NameError("Only SGD, Adam or AdamW are allowed as --optim")

    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

    if resume:
        # TODO work out how to ensure that we are using the same optimizer
        # when resuming such that the state dictionaries do not clash.
        # TODO breaking the function apart means we load the checkpoint twice.
        checkpoint = torch.load(resume, map_location=device)
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    return optimizer, scheduler
Exemplo n.º 2
0
def make_optimizer_and_schedule(args, model, checkpoint, params):
    """
    *Internal Function* (called directly from train_model)

    Creates an optimizer and a schedule for a given model, restoring from a
    checkpoint if it is non-null.

    Args:
        args (object) : an arguments object, see
            :meth:`~robustness.train.train_model` for details. Reads ``lr``,
            ``momentum``, ``weight_decay``, ``step_lr``, ``custom_schedule``
            and, for the cyclic schedule, ``epochs``.
        model (AttackerModel) : the model to create the optimizer for
        checkpoint (dict) : a loaded checkpoint saved by this library and loaded
            with `ch.load`
        params (list|None) : a list of parameters that should be updatable, all
            other params will not update. If ``None``, update all params

    Returns:
        An optimizer (ch.nn.optim.Optimizer) and a scheduler
            (ch.nn.optim.lr_schedulers module), or ``None`` as the scheduler
            when no schedule is configured.
    """
    # Make optimizer
    param_list = model.parameters() if params is None else params
    optimizer = SGD(param_list,
                    args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    # Make schedule. LambdaLR multiplies the *initial* lr by whatever the
    # lambda returns, so each lr_func below must return a multiplier, not an
    # absolute learning rate.
    schedule = None
    if args.step_lr:
        schedule = lr_scheduler.StepLR(optimizer, step_size=args.step_lr)
    elif args.custom_schedule == 'cyclic':
        eps = args.epochs

        def lr_func(t):
            # Triangular multiplier: 0 -> 1 at 2/5 of training -> 0 at the end.
            return np.interp([t], [0, eps * 2 // 5, eps], [0, 1, 0])[0]

        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
    elif args.custom_schedule:
        cs = args.custom_schedule
        # NOTE(review): eval() executes arbitrary code taken from the
        # command line/config. ast.literal_eval would be enough for a plain
        # [(epoch, lr), ...] list if no expressions are needed.
        periods = eval(cs) if isinstance(cs, str) else cs

        def lr_func(ep):
            # periods holds (milestone_epoch, absolute_lr) pairs; assumed to be
            # sorted by increasing milestone. The absolute lr is converted into
            # a multiplier of args.lr.
            for (milestone, lr) in reversed(periods):
                if ep > milestone:
                    return lr / args.lr
            return 1.0  # before the first milestone: keep the base lr

        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)

    # Fast-forward the optimizer and the scheduler if resuming
    if checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        if schedule is not None:
            try:
                schedule.load_state_dict(checkpoint['schedule'])
            except Exception:
                steps_to_take = checkpoint['epoch']
                print('Could not load schedule (was probably LambdaLR).'
                      f' Stepping {steps_to_take} times instead...')
                for _ in range(steps_to_take):
                    schedule.step()

    return optimizer, schedule
Exemplo n.º 3
0
def main():
    """Train (or only evaluate) ChannelDistillResNet1834 using settings in ``Config``.

    Requires at least one CUDA device.

    If ``Config.evaluate`` is set, the checkpoint it names is loaded and
    validated once, and the function returns. Otherwise, training runs from
    epoch 1 to ``Config.epochs``. It resumes from ``Config.resume`` when that
    path exists. Each epoch writes ``latest.pth``, which is copied to
    ``best.pth`` whenever validation top-1 accuracy improves.
    """
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    # Seed the CPU and all CUDA RNGs.
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    cudnn.benchmark = True
    cudnn.enabled = True

    logger = get_logger(__name__, Config.log)

    Config.gpus = torch.cuda.device_count()
    logger.info("use {} gpus".format(Config.gpus))
    config = {
        key: value
        for key, value in Config.__dict__.items() if not key.startswith("__")
    }
    logger.info(f"args: {config}")

    start_time = time.time()

    # dataset and dataloader
    logger.info("start loading data")

    # Same channel statistics for both splits.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolder(Config.train_dataset_path, train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = ImageFolder(Config.val_dataset_path, val_transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    logger.info("finish loading data")

    # network
    net = ChannelDistillResNet1834(Config.num_classes, Config.dataset_type)
    net = nn.DataParallel(net).cuda()

    # loss and optimizer: one criterion per Config.loss_list entry. "kd" losses
    # are constructed with the entry's "T" value (presumably a temperature).
    criterion = []
    for loss_item in Config.loss_list:
        loss_cls = losses.__dict__[loss_item["loss_name"]]
        if "kd" in loss_item["loss_type"]:
            criterion.append(loss_cls(loss_item["T"]).cuda())
        else:
            criterion.append(loss_cls().cuda())

    optimizer = SGD(net.parameters(),
                    lr=Config.lr,
                    momentum=0.9,
                    weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

    # only evaluate
    if Config.evaluate:
        # load best model
        if not os.path.isfile(Config.evaluate):
            raise Exception(
                f"{Config.evaluate} is not a file, please check it again")
        logger.info("start evaluating")
        logger.info(f"start resuming model from {Config.evaluate}")
        checkpoint = torch.load(Config.evaluate,
                                map_location=torch.device("cpu"))
        net.load_state_dict(checkpoint["model_state_dict"])
        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )
        return

    start_epoch = 1
    best_acc = 0.
    # resume training
    if Config.resume and os.path.exists(Config.resume):
        logger.info(f"start resuming model from {Config.resume}")
        checkpoint = torch.load(Config.resume,
                                map_location=torch.device("cpu"))
        start_epoch += checkpoint["epoch"]
        # Keep the best accuracy so best.pth is not overwritten by a worse
        # model right after resuming. Older checkpoints lack this key.
        best_acc = checkpoint.get("best_acc", 0.)
        net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        logger.info(
            f"finish resuming model from {Config.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:.3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc']}%")

    os.makedirs(Config.checkpoints, exist_ok=True)
    latest_path = os.path.join(Config.checkpoints, "latest.pth")
    best_path = os.path.join(Config.checkpoints, "best.pth")

    logger.info("start training")
    for epoch in range(start_epoch, Config.epochs + 1):
        prec1, prec5, loss = train(train_loader, net, criterion, optimizer,
                                   scheduler, epoch, logger)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_acc
        if is_best:
            best_acc = prec1
        torch.save(
            {
                "epoch": epoch,
                "acc": prec1,
                "best_acc": best_acc,
                "loss": loss,
                # get_last_lr() reports the lr actually in use; get_lr() called
                # outside step() re-applies gamma at milestone epochs.
                "lr": scheduler.get_last_lr()[0],
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }, latest_path)
        if is_best:
            shutil.copyfile(latest_path, best_path)

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, best acc: {best_acc:.2f}%, total training time: {training_time:.2f} hours"
    )
Exemplo n.º 4
0
class Trainer(object):
    """Train and validate a multi-label image classifier with checkpointing.

    Args:
        args: Option namespace. Fields read: ``scale_size``, ``crop_size``,
            ``train_path``, ``val_path``, ``label_path``, ``batch_size``,
            ``num_workers``, ``model``, ``num_classes``, ``optimizer``
            ('Adam' or 'SGD'), ``lr``, ``loss`` ('BCElogitloss',
            'tencentloss' or 'focalloss'), ``topk``, ``threshold``,
            ``log_dir``, ``resume``, ``ckpt_latest_path``, ``ckpt_best_path``
            and ``max_epoch``.

    Raises:
        ValueError: If ``args.optimizer`` or ``args.loss`` is not supported.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        train_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            # Multi-scale augmentation: random crop at one of several sizes,
            # then resize back to a fixed crop_size.
            transforms.RandomChoice([
                transforms.RandomCrop(640),
                transforms.RandomCrop(576),
                transforms.RandomCrop(512),
                transforms.RandomCrop(384),
                transforms.RandomCrop(320)
            ]),
            transforms.Resize((args.crop_size, args.crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        train_dataset = MLDataset(args.train_path, args.label_path,
                                  train_transform)
        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
        val_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        val_dataset = MLDataset(args.val_path, args.label_path, val_transform)
        self.val_loader = DataLoader(dataset=val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

        self.model = model_factory[args.model](args, args.num_classes)
        self.model.cuda()

        # Only parameters that still require gradients are optimized.
        trainable_parameters = filter(lambda param: param.requires_grad,
                                      self.model.parameters())
        if args.optimizer == 'Adam':
            self.optimizer = Adam(trainable_parameters, lr=args.lr)
        elif args.optimizer == 'SGD':
            self.optimizer = SGD(trainable_parameters, lr=args.lr)
        else:
            # Fail loudly instead of with an AttributeError on self.optimizer.
            raise ValueError(
                f"unsupported optimizer {args.optimizer!r}; expected 'Adam' or 'SGD'")

        # Validation mAP drives the lr decay (mode='max').
        self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                           mode='max',
                                                           patience=2,
                                                           verbose=True)
        if args.loss == 'BCElogitloss':
            self.criterion = nn.BCEWithLogitsLoss()
        elif args.loss == 'tencentloss':
            self.criterion = TencentLoss(args.num_classes)
        elif args.loss == 'focalloss':
            self.criterion = FocalLoss()
        else:
            raise ValueError(
                f"unsupported loss {args.loss!r}; expected 'BCElogitloss', "
                "'tencentloss' or 'focalloss'")
        self.early_stopping = EarlyStopping(patience=5)

        self.voc12_mAP = VOC12mAP(args.num_classes)
        self.average_loss = AverageLoss()
        self.average_topk_meter = TopkAverageMeter(args.num_classes,
                                                   topk=args.topk)
        self.average_threshold_meter = ThresholdAverageMeter(
            args.num_classes, threshold=args.threshold)

        self.args = args
        self.global_step = 0
        self.writer = SummaryWriter(log_dir=args.log_dir)

    def run(self):
        """Train from scratch or resume, checkpointing after every epoch.

        Each epoch trains, validates, steps the plateau scheduler on mAP and
        updates early stopping. It then writes the latest checkpoint: model,
        optimizer and scheduler state, global step, best score, and the next
        epoch to run. Best weights are saved separately whenever early
        stopping reports an improvement.
        """
        s_epoch = 0
        if self.args.resume:
            checkpoint = torch.load(self.args.ckpt_latest_path)
            s_epoch = checkpoint['epoch']
            self.global_step = checkpoint['global_step']
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
            # Older checkpoints were written without the scheduler state.
            if 'lr_scheduler_state_dict' in checkpoint:
                self.lr_scheduler.load_state_dict(
                    checkpoint['lr_scheduler_state_dict'])
            self.early_stopping.best_score = checkpoint['best_score']
            print('loading checkpoint success (epoch {})'.format(s_epoch))

        for epoch in range(s_epoch, self.args.max_epoch):
            self.train(epoch)
            mAP = self.validation(epoch)
            self.lr_scheduler.step(mAP)
            is_save, is_terminate = self.early_stopping(mAP)

            # Save after validation so the checkpoint already contains this
            # epoch's scheduler step and best score. Resuming from
            # 'epoch': epoch + 1 therefore skips nothing.
            save_dict = {
                'epoch': epoch + 1,
                'global_step': self.global_step,
                'model_state_dict': self.model.state_dict(),
                'optim_state_dict': self.optimizer.state_dict(),
                'lr_scheduler_state_dict': self.lr_scheduler.state_dict(),
                'best_score': self.early_stopping.best_score
            }
            torch.save(save_dict, self.args.ckpt_latest_path)

            if is_terminate:
                break
            if is_save:
                torch.save(self.model.state_dict(), self.args.ckpt_best_path)

    def train(self, epoch):
        """Run one training epoch, logging the loss every 400 global steps."""
        self.model.train()
        if self.args.model == 'ssgrl':
            # Keep the ResNet-101 backbone in eval mode except layer4. This
            # only affects BatchNorm/Dropout behaviour, not which weights
            # receive gradients.
            self.model.resnet_101.eval()
            self.model.resnet_101.layer4.train()
        for batch in self.train_loader:
            x, y = batch[0].cuda(), batch[1].cuda()
            pred_y = self.model(x)
            loss = self.criterion(pred_y, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.global_step % 400 == 0:
                self.writer.add_scalar('Loss/train', loss, self.global_step)
                print('TRAIN [epoch {}] loss: {:.4f}'.format(epoch, loss))

            self.global_step += 1

    def validation(self, epoch):
        """Evaluate on the validation set, log loss/mAP, and return the mAP."""
        self.model.eval()
        self.voc12_mAP.reset()
        self.average_loss.reset()
        self.average_topk_meter.reset()
        self.average_threshold_meter.reset()
        with torch.no_grad():
            for batch in self.val_loader:
                x, y = batch[0].cuda(), batch[1].cuda()
                pred_y = self.model(x)
                loss = self.criterion(pred_y, y)

                y = y.cpu().numpy()
                pred_y = pred_y.cpu().numpy()
                loss = loss.cpu().numpy()
                self.voc12_mAP.update(pred_y, y)
                self.average_loss.update(loss, x.size(0))
                self.average_topk_meter.update(pred_y, y)
                self.average_threshold_meter.update(pred_y, y)

        _, mAP = self.voc12_mAP.compute()
        mLoss = self.average_loss.compute()
        self.average_topk_meter.compute()
        self.average_threshold_meter.compute()
        self.writer.add_scalar('Loss/val', mLoss, self.global_step)
        self.writer.add_scalar('mAP/val', mAP, self.global_step)

        print("Validation [epoch {}] mAP: {:.4f} loss: {:.4f}".format(
            epoch, mAP, mLoss))
        return mAP
Exemplo n.º 5
0
class LightHeadRCNN_Learner(Module):
    def __init__(self, training=True):
        """Build the Light-Head R-CNN modules and, optionally, the training state.

        Args:
            training (bool): When True, also create the COCO train/val
                datasets, the anchor/proposal target creators, the SGD
                optimizer (with 3x lr for the first 8 head parameters), the
                warm-up settings, the TensorBoard writer, and the periods
                derived from the training-set length.
        """
        super(LightHeadRCNN_Learner, self).__init__()
        self.conf = Config()
        self.class_2_color = get_class_colors(self.conf)

        self.extractor = ResNet101Extractor(self.conf.pretrained_model_path).to(self.conf.device)
        self.rpn = RegionProposalNetwork().to(self.conf.device)
        # Normalisation of the ROI box-regression targets. Both must be flat
        # 4-tuples; a stray trailing comma used to wrap the mean in an extra
        # tuple.
        self.loc_normalize_mean = (0., 0., 0., 0.)
        self.loc_normalize_std = (0.1, 0.1, 0.2, 0.2)
        self.head = LightHeadRCNNResNet101_Head(self.conf.class_num + 1,
                                                self.conf.roi_size,
                                                roi_align=self.conf.use_roi_align).to(self.conf.device)
        self.detections = namedtuple('detections', ['roi_cls_locs', 'roi_scores', 'rois'])

        if training:
            self.train_dataset = coco_dataset(self.conf, mode='train')
            self.train_length = len(self.train_dataset)
            self.val_dataset = coco_dataset(self.conf, mode='val')
            self.val_length = len(self.val_dataset)
            self.anchor_target_creator = AnchorTargetCreator()
            self.proposal_target_creator = ProposalTargetCreator(loc_normalize_mean=self.loc_normalize_mean,
                                                                 loc_normalize_std=self.loc_normalize_std)
            self.step = 0
            head_params = [*self.head.parameters()]
            self.optimizer = SGD([
                {'params': get_trainables(self.extractor.parameters())},
                {'params': self.rpn.parameters()},
                {'params': head_params[:8], 'lr': self.conf.lr * 3},
                {'params': head_params[8:]},
            ], lr=self.conf.lr, momentum=self.conf.momentum, weight_decay=self.conf.weight_decay)
            # Per-group learning rates that warm-up and step decay scale from.
            self.base_lrs = [params['lr'] for params in self.optimizer.param_groups]
            self.warm_up_duration = 5000
            self.warm_up_rate = 1 / 5
            self.train_outputs = namedtuple('train_outputs',
                                            ['loss_total',
                                             'rpn_loc_loss',
                                             'rpn_cls_loss',
                                             'ohem_roi_loc_loss',
                                             'ohem_roi_cls_loss',
                                             'total_roi_loc_loss',
                                             'total_roi_cls_loss'])
            self.writer = SummaryWriter(self.conf.log_path)
            # Periods (in samples) derived from the training-set length.
            self.board_loss_every = self.train_length // self.conf.board_loss_interval
            self.evaluate_every = self.train_length // self.conf.eval_interval
            self.eva_on_coco_every = self.train_length // self.conf.eval_coco_interval
            self.board_pred_image_every = self.train_length // self.conf.board_pred_image_interval
            self.save_every = self.train_length // self.conf.save_interval
        
    def set_training(self):
        self.train()
        self.extractor.set_bn_eval()
        
    def lr_warmup(self):
        assert self.step <= self.warm_up_duration, 'stop warm up after {} steps'.format(self.warm_up_duration)
        rate = self.warm_up_rate + (1 - self.warm_up_rate) * self.step / self.warm_up_duration
        for i, params in enumerate(self.optimizer.param_groups):
            params['lr'] = self.base_lrs[i] * rate
           
    def lr_schedule(self, epoch):
        if epoch < 13:
            return
        elif epoch < 16:
            rate = 0.1
        else:
            rate = 0.01
        for i, params in enumerate(self.optimizer.param_groups):
            params['lr'] = self.base_lrs[i] * rate
        print(self.optimizer)
    
    def forward(self, img_tensor, scale, bboxes=None, labels=None, force_eval=False):
        """Run the detector and return losses or raw detections.

        Losses are computed in training mode or when ``force_eval`` is set.
        Otherwise raw detections are returned.

        Args:
            img_tensor: Image batch tensor. It is indexed as (N, C, H, W) here
                (shape[2] is H, shape[3] is W). Only ``rpn_locs[0]`` /
                ``rpn_scores[0]`` enter the losses, so a batch size of 1 seems
                to be assumed — TODO confirm.
            scale: Resize scale, forwarded to the RPN.
            bboxes: Ground-truth boxes. Required when computing losses; an
                empty ``bboxes`` short-circuits to the RPN classification loss
                only.
            labels: Ground-truth labels matching ``bboxes``.
            force_eval (bool): Compute losses even when the module is in eval
                mode (presumably for validation losses).

        Returns:
            When computing losses, a ``self.train_outputs`` namedtuple:
            ``loss_total`` stays a tensor for backprop, the RPN/OHEM components
            are Python floats (``.item()``), and ``total_roi_*`` are returned
            as produced by ``OHEM_loss``. In the empty-``bboxes`` case the
            non-total components are the int 0. Otherwise, a
            ``self.detections`` namedtuple ``(roi_cls_locs, roi_scores, rois)``.

        NOTE(review): ``self.train_outputs`` and the target creators are only
        created in ``__init__`` when ``training=True``.
        """
        img_tensor = img_tensor.to(self.conf.device)
        img_size = (img_tensor.shape[2], img_tensor.shape[3]) # H,W
        rpn_feature, roi_feature = self.extractor(img_tensor)
        rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(rpn_feature, img_size, scale)
        if self.training or force_eval:
            # RPN targets computed from the ground-truth boxes and the anchors.
            gt_rpn_loc, gt_rpn_labels = self.anchor_target_creator(bboxes, anchor, img_size)
            gt_rpn_labels = torch.tensor(gt_rpn_labels, dtype=torch.long).to(self.conf.device)
            if len(bboxes) == 0:
                # No ground truth: only the RPN objectness loss is computed and
                # every other component is reported as 0.
                rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_rpn_labels, ignore_index = -1)
                return self.train_outputs(rpn_cls_loss, 0, 0, 0, 0, 0, 0)
            # Sampled proposals with their regression and class targets.
            sample_roi, gt_roi_locs, gt_roi_labels = self.proposal_target_creator(rois, bboxes, labels)
            roi_cls_locs, roi_scores = self.head(roi_feature, sample_roi)
#             roi_cls_locs, roi_scores, pool, h, rois = self.head(roi_feature, sample_roi)
            
            gt_rpn_loc = torch.tensor(gt_rpn_loc, dtype=torch.float).to(self.conf.device)
            gt_roi_locs = torch.tensor(gt_roi_locs, dtype=torch.float).to(self.conf.device)
            gt_roi_labels = torch.tensor(gt_roi_labels, dtype=torch.long).to(self.conf.device)
            
            rpn_loc_loss = fast_rcnn_loc_loss(rpn_locs[0], gt_rpn_loc, gt_rpn_labels, sigma=self.conf.rpn_sigma)
            
            # Anchors labelled -1 are excluded from the classification loss.
            rpn_cls_loss = F.cross_entropy(rpn_scores[0], gt_rpn_labels, ignore_index = -1)
            
            ohem_roi_loc_loss, \
            ohem_roi_cls_loss, \
            total_roi_loc_loss, \
            total_roi_cls_loss = OHEM_loss(roi_cls_locs, 
                                           roi_scores, 
                                           gt_roi_locs, 
                                           gt_roi_labels, 
                                           self.conf.n_ohem_sample, 
                                           self.conf.roi_sigma)
            
            # Only the OHEM ROI losses are optimised. The total_roi_* losses
            # are returned but not added to loss_total.
            loss_total = rpn_loc_loss + rpn_cls_loss + ohem_roi_loc_loss + ohem_roi_cls_loss
            
#             if loss_total.item() > 1000.:
#                 print('ohem_roi_loc_loss : {}, ohem_roi_cls_loss : {}'.format(ohem_roi_loc_loss, ohem_roi_cls_loss))
#                 torch.save(pool, 'pool_debug.pth')
#                 torch.save(h, 'h_debug.pth')
#                 np.save('rois_debug', rois)
#                 torch.save(roi_cls_locs, 'roi_cls_locs_debug.pth')
#                 torch.save(roi_scores, 'roi_scores_debug.pth')
#                 torch.save(gt_roi_locs, 'gt_roi_locs_debug.pth')
#                 torch.save(gt_roi_labels, 'gt_roi_labels_debug.pth')
#                 pdb.set_trace()
            
            return self.train_outputs(loss_total, 
                                      rpn_loc_loss.item(), 
                                      rpn_cls_loss.item(), 
                                      ohem_roi_loc_loss.item(), 
                                      ohem_roi_cls_loss.item(),
                                      total_roi_loc_loss,
                                      total_roi_cls_loss)
        
        else:
            # Inference: score every RPN proposal, without sampling.
            roi_cls_locs, roi_scores = self.head(roi_feature, rois)
            return self.detections(roi_cls_locs, roi_scores, rois)
        
    def eval_predict(self, img, preset = 'evaluate', use_softnms = False):
        """Predict on a single image array and wrap each result in a one-element list.

        Args:
            img: Image array indexed (C, H, W) (it is transposed to (H, W, C)
                for ``Image.fromarray``), or a list whose first element is
                such an array.
            preset (str): Preset forwarded to ``predict_on_img``.
            use_softnms (bool): Use soft-NMS instead of hard NMS.

        Returns:
            ``([bboxes], [labels], [scores])``. Boxes come from
            ``predict_on_img(..., original_size=True)`` and are passed through
            ``y1x1y2x2_2_x1y1x2y2`` (presumably reordering (y1, x1, y2, x2)
            into (x1, y1, x2, y2), as its name says).
        """
        if isinstance(img, list):
            img = img[0]
        pil_img = Image.fromarray(img.transpose(1, 2, 0).astype('uint8'))
        bboxes, labels, scores = self.predict_on_img(pil_img, preset, use_softnms, original_size=True)
        bboxes = y1x1y2x2_2_x1y1x2y2(bboxes)
        return [bboxes], [labels], [scores]
        
    def predict_on_img(self, img, preset = 'evaluate', use_softnms=False, return_img = False, with_scores = False, original_size = False):
        '''Run inference on one PIL image and post-process the detections.

        Steps:
        - Switch the module to eval mode (it is not switched back).
        - Apply ``preset`` via ``use_preset``.
        - Decode one box per (ROI, class) and clip the boxes to the network
          input size.
        - Suppress them with ``_suppress``.
        - Keep the ``self.max_n_predict`` highest-scoring detections.

        Args:
            img: PIL Image. Its array is transposed to (C, H, W), so an
                H x W x C image is expected.
            preset (str): Name of a ``use_preset`` preset.
            use_softnms (bool): Use soft-NMS instead of hard NMS.
            return_img (bool): Return a PIL image with the detections drawn
                instead of the raw arrays.
            with_scores (bool): When drawing, also draw the scores.
            original_size (bool): Map the boxes back with ``adjust_bbox`` and
                resize the returned image to the original image size.

        Returns:
            ``(bboxes, labels, scores)`` numpy arrays. These are ``[], [], []``
            when nothing passes the score threshold. When ``return_img`` is
            True, a PIL image is returned instead.
        '''
        self.eval()
        self.use_preset(preset)
        with torch.no_grad():
            orig_size = img.size # W,H
            img = np.asarray(img).transpose(2,0,1)
            img, scale = prepare_img(self.conf, img, -1)
            img = torch.tensor(img).unsqueeze(0)
            img_size = (img.shape[2], img.shape[3]) # H,W
            detections = self.forward(img, scale)
            n_sample = len(detections.roi_cls_locs)
            n_class = self.conf.class_num + 1
            # One row of 4 regression offsets per (ROI, class) pair.
            roi_cls_locs = detections.roi_cls_locs.reshape((n_sample, -1, 4)).reshape([-1,4])
            # Undo the normalisation applied to the regression targets.
            loc_std = torch.tensor(self.loc_normalize_std, device=self.conf.device)
            loc_mean = torch.tensor(self.loc_normalize_mean, device=self.conf.device)
            roi_cls_locs = roi_cls_locs * loc_std + loc_mean
            # ndarray.repeat(n, 0) repeats every ROI row n_class times
            # consecutively, lining ROIs up with the (ROI, class) rows above.
            rois = torch.tensor(detections.rois.repeat(n_class, 0), dtype=torch.float).to(self.conf.device)
            raw_cls_bboxes = loc2bbox(rois, roi_cls_locs)
            # Boxes are (y1, x1, y2, x2); eval_predict converts from that order.
            # Even columns are y and are clipped to H; odd columns are x and
            # are clipped to W.
            raw_cls_bboxes[:, 0::2] = raw_cls_bboxes[:, 0::2].clamp(0, img_size[0])
            raw_cls_bboxes[:, 1::2] = raw_cls_bboxes[:, 1::2].clamp(0, img_size[1])
            raw_cls_bboxes = raw_cls_bboxes.reshape([n_sample, n_class, 4])
            raw_prob = F.softmax(detections.roi_scores, dim=1)
            bboxes, labels, scores = self._suppress(raw_cls_bboxes, raw_prob, use_softnms)
            if len(bboxes) == len(labels) == len(scores) == 0:
                if not return_img:
                    return [], [], []
                # Nothing to draw; still honour original_size like the
                # drawing path below.
                empty_img = to_img(self.conf, img[0])
                if original_size:
                    empty_img = empty_img.resize(orig_size)
                return empty_img
            # Highest score first, then cap the number of detections.
            _, indices = scores.sort(descending=True)
            bboxes = bboxes[indices][:self.max_n_predict]
            # Move the index to the labels' device in case they live elsewhere.
            labels = labels[indices.to(labels.device)][:self.max_n_predict]
            scores = scores[indices][:self.max_n_predict]
        # now, implement drawing
        bboxes = bboxes.cpu().numpy()
        labels = labels.cpu().numpy()
        scores = scores.cpu().numpy()
        if original_size:
            bboxes = adjust_bbox(scale, bboxes, detect=True)
        if not return_img:
            return bboxes, labels, scores
        scores_ = scores if with_scores else []
        predicted_img = to_img(self.conf, img[0])
        if original_size:
            predicted_img = predicted_img.resize(orig_size)
        if len(bboxes) != 0 and len(labels) != 0:
            predicted_img = draw_bbox_class(self.conf,
                                            predicted_img,
                                            labels,
                                            bboxes,
                                            self.conf.correct_id_2_class,
                                            self.class_2_color,
                                            scores = scores_)
        return predicted_img
    
    def _suppress(self, raw_cls_bboxes, raw_prob, use_softnms):
        bbox = []
        label = []
        prob = []
        for l in range(1, self.conf.class_num + 1):
            cls_bbox_l = raw_cls_bboxes[:, l, :]
            prob_l = raw_prob[:, l]
            mask = prob_l > self.score_thresh
            if not mask.any():
                continue
            cls_bbox_l = cls_bbox_l[mask]
            prob_l = prob_l[mask]
            if use_softnms:
                keep, _  = soft_nms(torch.cat((cls_bbox_l, prob_l.unsqueeze(-1)), dim=1).cpu().numpy(),
                                    Nt = self.conf.softnms_Nt,
                                    method = self.conf.softnms_method,
                                    sigma = self.conf.softnms_sigma,
                                    min_score = self.conf.softnms_min_score)
                keep = keep.tolist()
            else:
#                 prob_l, order = torch.sort(prob_l, descending=True)
#                 cls_bbox_l = cls_bbox_l[order]
                keep = nms(torch.cat((cls_bbox_l, prob_l.unsqueeze(-1)), dim=1), self.nms_thresh).tolist()
            bbox.append(cls_bbox_l[keep])
            # The labels are in [0, 79].
            label.append((l - 1) * torch.ones((len(keep),), dtype = torch.long))
            prob.append(prob_l[keep])
        if len(bbox) == 0:
            print("looks like there is no prediction have a prob larger than thresh")
            return [], [], []
        bbox = torch.cat(bbox)
        label = torch.cat(label)
        prob = torch.cat(prob)
        return bbox, label, prob
    
    def board_scalars(self, 
                      key, 
                      loss_total, 
                      rpn_loc_loss, 
                      rpn_cls_loss, 
                      ohem_roi_loc_loss, 
                      ohem_roi_cls_loss, 
                      total_roi_loc_loss, 
                      total_roi_cls_loss):
        self.writer.add_scalar('{}_loss_total'.format(key), loss_total, self.step)
        self.writer.add_scalar('{}_rpn_loc_loss'.format(key), rpn_loc_loss, self.step)
        self.writer.add_scalar('{}_rpn_cls_loss'.format(key), rpn_cls_loss, self.step)
        self.writer.add_scalar('{}_ohem_roi_loc_loss'.format(key), ohem_roi_loc_loss, self.step)
        self.writer.add_scalar('{}_ohem_roi_cls_loss'.format(key), ohem_roi_cls_loss, self.step)
        self.writer.add_scalar('{}_total_roi_loc_loss'.format(key), total_roi_loc_loss, self.step)
        self.writer.add_scalar('{}_total_roi_cls_loss'.format(key), total_roi_cls_loss, self.step)
    
    def use_preset(self, preset):
        """Use the given preset during prediction.

        This method sets :obj:`self.nms_thresh`, :obj:`self.score_thresh`
        and :obj:`self.max_n_predict`. The first two are the threshold used
        for non maximum suppression and the threshold used to discard low
        confidence proposals in :meth:`predict`; ``max_n_predict`` is
        presumably a cap on the number of predictions returned (confirm
        against :meth:`predict`).

        If the attributes need to be changed to something other than the
        values provided in the presets, modify the public attributes
        directly.

        Args:
            preset ({'visualize', 'evaluate', 'debug'}): A string to
                determine the preset to use.

        Raises:
            ValueError: If ``preset`` is not one of the supported names.
        """
        # (nms_thresh, score_thresh, max_n_predict) per preset.
        # Every preset uses 0.5 for NMS instead of the more common 0.3; the
        # original author's note (apparently quoting a paper) says this
        # "improves 0.6 points of mmAP by improving the recall rate
        # especially for the crowd cases".
        presets = {
            'visualize': (0.5, 0.25, 40),
            'evaluate': (0.5, 0.0, 100),
            'debug': (0.5, 0.0, 10),
        }
        if not isinstance(preset, str) or preset not in presets:
            raise ValueError(
                "preset must be 'visualize', 'evaluate' or 'debug', "
                "got {!r}".format(preset))
        self.nms_thresh, self.score_thresh, self.max_n_predict = presets[preset]
    
    def fit(self, epochs=20, resume=False, from_save_folder=False):
        """Run the training loop with periodic logging, evaluation and checkpointing.

        The current epoch is derived from the global step counter as
        ``self.step // self.train_length`` and training continues while it
        is ``<= epochs``, so ``epochs`` is the last epoch index (inclusive).

        Periodic actions (all skipped at step 0):
          * every ``board_loss_every`` steps: averaged running training losses
            are written via :meth:`board_scalars`;
          * every ``evaluate_every`` steps: validation losses from
            :meth:`evaluate` are written;
          * every ``eva_on_coco_every`` steps: COCO evaluation is run and
            mAP@0.5 / mmAP are written (``-1`` if the evaluation fails);
          * every ``board_pred_image_every`` steps: predictions on the first
            20 validation images are written as images;
          * every ``save_every`` steps: a checkpoint is written to the
            ``model`` folder.
        A checkpoint is also written to the ``save`` folder after each epoch.

        Args:
            epochs (int): Last epoch index to train (inclusive).
            resume (bool): Restore state with :meth:`resume_training_load`
                before training.
            from_save_folder (bool): Forwarded to
                :meth:`resume_training_load`.
        """
        if resume:
            self.resume_training_load(from_save_folder)
        self.set_training()
        running_loss = 0.
        running_rpn_loc_loss = 0.
        running_rpn_cls_loss = 0.
        running_ohem_roi_loc_loss = 0.
        running_ohem_roi_cls_loss = 0.
        running_total_roi_loc_loss = 0.
        running_total_roi_cls_loss = 0.
        map05 = None
        val_loss = None

        epoch = self.step // self.train_length
        while epoch <= epochs:
            print('start the training of epoch : {}'.format(epoch))
            self.lr_schedule(epoch)
            # Samples are visited in index order (no shuffling).
            for index in tqdm(range(self.train_length), total = self.train_length):
                try:
                    inputs = self.train_dataset[index]
                except Exception:
                    # Skip samples that fail to load rather than aborting the run.
                    # NOTE: self.step is not advanced for skipped samples.
                    print('loading index {} from train dataset failed'.format(index))
                    continue
                self.optimizer.zero_grad()
                train_outputs = self.forward(torch.tensor(inputs.img).unsqueeze(0),
                                             inputs.scale,
                                             inputs.bboxes,
                                             inputs.labels)
                train_outputs.loss_total.backward()
                # Learning-rate warm-up during the first steps of epoch 0.
                if epoch == 0 and self.step <= self.warm_up_duration:
                    self.lr_warmup()
                self.optimizer.step()
                torch.cuda.empty_cache()

                running_loss += train_outputs.loss_total.item()
                running_rpn_loc_loss += train_outputs.rpn_loc_loss
                running_rpn_cls_loss += train_outputs.rpn_cls_loss
                running_ohem_roi_loc_loss += train_outputs.ohem_roi_loc_loss
                running_ohem_roi_cls_loss += train_outputs.ohem_roi_cls_loss
                running_total_roi_loc_loss += train_outputs.total_roi_loc_loss
                running_total_roi_cls_loss += train_outputs.total_roi_cls_loss

                if self.step != 0:
                    if self.step % self.board_loss_every == 0:
                        self.board_scalars('train',
                                           running_loss / self.board_loss_every,
                                           running_rpn_loc_loss / self.board_loss_every,
                                           running_rpn_cls_loss / self.board_loss_every,
                                           running_ohem_roi_loc_loss / self.board_loss_every,
                                           running_ohem_roi_cls_loss / self.board_loss_every,
                                           running_total_roi_loc_loss / self.board_loss_every,
                                           running_total_roi_cls_loss / self.board_loss_every)
                        running_loss = 0.
                        running_rpn_loc_loss = 0.
                        running_rpn_cls_loss = 0.
                        running_ohem_roi_loc_loss = 0.
                        running_ohem_roi_cls_loss = 0.
                        running_total_roi_loc_loss = 0.
                        running_total_roi_cls_loss = 0.

                    if self.step % self.evaluate_every == 0:
                        val_loss, val_rpn_loc_loss, \
                        val_rpn_cls_loss, \
                        ohem_val_roi_loc_loss, \
                        ohem_val_roi_cls_loss, \
                        total_val_roi_loc_loss, \
                        total_val_roi_cls_loss = self.evaluate(num = self.conf.eva_num_during_training)
                        # evaluate() switches to eval mode; switch back.
                        self.set_training()
                        self.board_scalars('val',
                                           val_loss,
                                           val_rpn_loc_loss,
                                           val_rpn_cls_loss,
                                           ohem_val_roi_loc_loss,
                                           ohem_val_roi_cls_loss,
                                           total_val_roi_loc_loss,
                                           total_val_roi_cls_loss)

                    if self.step % self.eva_on_coco_every == 0:
                        try:
                            cocoEval = self.eva_on_coco(limit = self.conf.coco_eva_num_during_training)
                            self.set_training()
                            map05 = cocoEval[1]
                            mmap = cocoEval[0]
                        except Exception:
                            # Best effort: a failed COCO evaluation is logged as -1.
                            print('eval on coco failed')
                            map05 = -1
                            mmap = -1
                        self.writer.add_scalar('0.5IoU MAP', map05, self.step)
                        self.writer.add_scalar('0.5::0.9 - MMAP', mmap, self.step)

                    if self.step % self.board_pred_image_every == 0:
                        for i in range(20):
                            img, _, _, _ , _= self.val_dataset.orig_dataset[i]
                            img = Image.fromarray(img.astype('uint8').transpose(1,2,0))
                            predicted_img = self.predict_on_img(img, preset='visualize', return_img=True, with_scores=True, original_size=True)
                            self.writer.add_image('pred_image_{}'.format(i), trans.ToTensor()(predicted_img), global_step=self.step)
                            self.set_training()

                    if self.step % self.save_every == 0:
                        try:
                            self.save_state(val_loss, map05)
                        except Exception:
                            # Best effort: keep training if the checkpoint cannot be written.
                            print('save state failed')

                self.step += 1
            epoch = self.step // self.train_length
            try:
                self.save_state(val_loss, map05, to_save_folder=True)
            except Exception:
                print('save state failed')
    
    def eva_on_coco(self, limit=1000, preset='evaluate', use_softnms=False):
        """Evaluate the model on the validation set with ``eva_coco``.

        The model is put in eval mode. ``eva_coco`` receives the underlying
        validation dataset, a prediction callback that forwards ``preset`` and
        ``use_softnms`` to :meth:`eval_predict`, ``limit`` and ``preset``.

        Returns:
            The value returned by ``eva_coco``. :meth:`fit` reads index ``0``
            as mmAP and index ``1`` as mAP@0.5.
        """
        self.eval()

        def _predict(x):
            return self.eval_predict(x, preset, use_softnms)

        dataset = self.val_dataset.orig_dataset
        return eva_coco(dataset, _predict, limit, preset)
    
    def evaluate(self, num=None):
        """Compute average validation losses without gradient tracking.

        Puts the model in eval mode and runs :meth:`forward` with
        ``force_eval=True`` on the first ``num`` validation samples (all of
        them when ``num`` is ``None``; ``num`` is capped at
        ``self.val_length``). Samples with no bounding boxes are skipped.

        Args:
            num (int, optional): Number of validation samples to visit.

        Returns:
            tuple: Seven averages, in order: total loss, RPN loc loss, RPN
            cls loss, OHEM ROI loc loss, OHEM ROI cls loss, total ROI loc
            loss, total ROI cls loss. Averages are over the samples actually
            evaluated; they are NaN when no sample was evaluated.
        """
        self.eval()
        total_num = self.val_length if num is None else min(num, self.val_length)
        # Component losses accumulated next to the scalar total, in return order.
        loss_names = ('rpn_loc_loss',
                      'rpn_cls_loss',
                      'ohem_roi_loc_loss',
                      'ohem_roi_cls_loss',
                      'total_roi_loc_loss',
                      'total_roi_cls_loss')
        running_loss = 0.
        running = [0.] * len(loss_names)
        evaluated = 0
        with torch.no_grad():
            for index in tqdm(range(total_num)):
                inputs = self.val_dataset[index]
                bboxes = inputs.bboxes
                # len() works for lists and numpy arrays alike; comparing an
                # array with [] does not.
                if bboxes is not None and len(bboxes) == 0:
                    continue
                val_outputs = self.forward(torch.tensor(inputs.img).unsqueeze(0),
                                           inputs.scale,
                                           inputs.bboxes,
                                           inputs.labels,
                                           force_eval = True)
                running_loss += val_outputs.loss_total.item()
                for i, name in enumerate(loss_names):
                    running[i] += getattr(val_outputs, name)
                evaluated += 1
        # Average over contributing samples only, so skipped samples do not
        # drag the means towards zero (and num=0 does not divide by zero).
        denom = evaluated if evaluated else float('nan')
        return (running_loss / denom, *(value / denom for value in running))
    
    def save_state(self, val_loss, map05, to_save_folder=False, model_only=False):
        """Write the model state dict, and unless ``model_only`` the optimizer's.

        Files go to ``work_space/'save'`` when ``to_save_folder`` is true and
        to ``work_space/'model'`` otherwise. They are named
        ``model_<time>_val_loss:<val_loss>_map05:<map05>_step:<step>.pth``
        and ``optimizer_...`` with the same suffix; the step after the last
        ``:`` is what :meth:`resume_training_load` parses back.
        """
        folder = self.conf.work_space / ('save' if to_save_folder else 'model')
        stamp = get_time()
        name_args = (stamp, val_loss, map05, self.step)
        torch.save(self.state_dict(),
                   folder / 'model_{}_val_loss:{}_map05:{}_step:{}.pth'.format(*name_args))
        if model_only:
            return
        torch.save(self.optimizer.state_dict(),
                   folder / 'optimizer_{}_val_loss:{}_map05:{}_step:{}.pth'.format(*name_args))
    
    def load_state(self, fixed_str, from_save_folder=False, model_only=False):
        """Load weights written by :meth:`save_state`.

        Args:
            fixed_str (str): Filename suffix following the ``model_`` /
                ``optimizer_`` prefix.
            from_save_folder (bool): Read from ``work_space/'save'`` instead
                of ``work_space/'model'``.
            model_only (bool): Skip restoring the optimizer state.
        """
        folder = self.conf.work_space / ('save' if from_save_folder else 'model')
        model_state = torch.load(folder / 'model_{}'.format(fixed_str))
        self.load_state_dict(model_state)
        print('load model_{}'.format(fixed_str))
        if model_only:
            return
        optimizer_state = torch.load(folder / 'optimizer_{}'.format(fixed_str))
        self.optimizer.load_state_dict(optimizer_state)
        print('load optimizer_{}'.format(fixed_str))
    
    def resume_training_load(self, from_save_folder=False):
        """Load the newest matching model/optimizer checkpoint pair.

        Files in ``work_space/'save'`` (``from_save_folder``) or
        ``work_space/'model'`` are sorted newest first by modification time.
        The first two adjacent files named ``model_<suffix>`` and
        ``optimizer_<suffix>`` (in either order) are loaded with
        :meth:`load_state`, and ``self.step`` is set to the step parsed from
        ``<suffix>`` (text after the last ``:`` up to the first ``.``) plus
        one. If no pair is found, a message is printed and ``self.step`` is
        left unchanged.
        """
        folder = 'save' if from_save_folder else 'model'
        save_path = self.conf.work_space / folder
        sorted_files = sorted(save_path.iterdir(), key=os.path.getmtime, reverse=True)
        # save_state writes the model and optimizer files back to back, so a
        # matching pair is expected to be adjacent in modification-time order.
        for newer, older in zip(sorted_files, sorted_files[1:]):
            name = newer.name
            if name.startswith('model_'):
                fix_str = name[len('model_'):]
                partner = 'optimizer_' + fix_str
            elif name.startswith('optimizer_'):
                fix_str = name[len('optimizer_'):]
                partner = 'model_' + fix_str
            else:
                continue
            if older.name != partner:
                continue
            try:
                step = int(fix_str.split(':')[-1].split('.')[0])
            except ValueError:
                # Name does not carry a parseable step; keep looking.
                continue
            # Only commit the step once a usable pair has been found.
            self.step = step + 1
            self.load_state(fix_str, from_save_folder)
            return
        print('no available files founded')
        return
Exemplo n.º 6
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point: build, optionally resume, train and evaluate.

    Args:
        gpu: Index of the GPU used by this process (stored in ``args.gpu``).
        ngpus_per_node: GPUs per node; in distributed mode used to compute
            the global rank and to split ``args.batch_size`` per process.
        args: Parsed options; mutated in place (``gpu``, and depending on the
            path ``rank``, ``batch_size``, ``resume``, ``start_epoch``).

    The log file opened here is closed on every exit path, including the
    early return after ``args.evaluate``.
    """
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    log_file = open(
        os.path.join(
            args.save_path,
            'log_seed{}{}.txt'.format(args.manualSeed,
                                      '_eval' if args.evaluate else '')), 'w')
    # print_log receives a (file, gpu) pair.
    log = (log_file, args.gpu)
    try:
        net = models.__dict__[args.arch](pretrained=True)
        disable_dropout(net)
        net = to_bayesian(net, args.psi_init_range)
        net.apply(unfreeze)

        print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
                  log)
        print_log("PyTorch  version : {}".format(torch.__version__), log)
        print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
                  log)
        print_log(
            "Number of parameters: {}".format(
                sum(p.numel() for p in net.parameters())), log)
        print_log(str(args), log)

        if args.distributed:
            if args.multiprocessing_distributed:
                # args.rank presumably holds the node rank on entry; turn it
                # into a global process rank.
                args.rank = args.rank * ngpus_per_node + gpu
            dist.init_process_group(backend=args.dist_backend,
                                    init_method=args.dist_url + ":" +
                                    args.dist_port,
                                    world_size=args.world_size,
                                    rank=args.rank)
            torch.cuda.set_device(args.gpu)
            net.cuda(args.gpu)
            # Each process handles an equal share of the requested batch size.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            net = torch.nn.parallel.DistributedDataParallel(net,
                                                            device_ids=[args.gpu])
        else:
            torch.cuda.set_device(args.gpu)
            net = net.cuda(args.gpu)

        criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

        # Parameters whose name contains 'psi' are optimized by PsiSGD, all
        # others by SGD; both use Nesterov momentum when momentum > 0.
        mus, psis = [], []
        for name, param in net.named_parameters():
            if 'psi' in name:
                psis.append(param)
            else:
                mus.append(param)
        mu_optimizer = SGD(mus,
                           args.learning_rate,
                           args.momentum,
                           weight_decay=args.decay,
                           nesterov=(args.momentum > 0.0))

        psi_optimizer = PsiSGD(psis,
                               args.learning_rate,
                               args.momentum,
                               weight_decay=args.decay,
                               nesterov=(args.momentum > 0.0))

        recorder = RecorderMeter(args.epochs)
        if args.resume:
            if args.resume == 'auto':
                args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
            if os.path.isfile(args.resume):
                print_log("=> loading checkpoint '{}'".format(args.resume), log)
                checkpoint = torch.load(args.resume,
                                        map_location='cuda:{}'.format(args.gpu))
                recorder = checkpoint['recorder']
                recorder.refresh(args.epochs)
                args.start_epoch = checkpoint['epoch']
                # Outside distributed mode the model is not wrapped, so the
                # 'module.' key prefix (presumably from a DDP-wrapped save)
                # is stripped.
                net.load_state_dict(
                    checkpoint['state_dict'] if args.distributed else {
                        k.replace('module.', ''): v
                        for k, v in checkpoint['state_dict'].items()
                    })
                mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
                psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
                best_acc = recorder.max_accuracy(False)
                print_log(
                    "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                        args.resume, best_acc, checkpoint['epoch']), log)
            else:
                print_log("=> no checkpoint found at '{}'".format(args.resume),
                          log)
        else:
            print_log("=> do not use any checkpoint for the model", log)

        cudnn.benchmark = True

        train_loader, ood_train_loader, test_loader, adv_loader, \
            fake_loader, adv_loader2 = load_dataset_ft(args)
        psi_optimizer.num_data = len(train_loader.dataset)

        if args.evaluate:
            evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                     criterion, args, log, 20, 100)
            return

        start_time = time.time()
        epoch_time = AverageMeter()
        train_los = -1

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
                ood_train_loader.sampler.set_epoch(epoch)
            cur_lr, cur_slr = adjust_learning_rate(mu_optimizer, psi_optimizer,
                                                   epoch, args)

            need_hour, need_mins, need_secs = convert_secs2time(
                epoch_time.avg * (args.epochs - epoch))
            need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
                need_hour, need_mins, need_secs)

            print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                        time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                        + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

            train_acc, train_los = train(train_loader, ood_train_loader, net,
                                         criterion, mu_optimizer, psi_optimizer,
                                         epoch, args, log)
            # No validation pass here: accuracy/loss are recorded as 0.
            val_acc, val_los = 0, 0
            recorder.update(epoch, train_los, train_acc, val_acc, val_los)

            is_best = False
            if val_acc > best_acc:
                is_best = True
                best_acc = val_acc

            if args.gpu == 0:
                # NOTE(review): is_best is computed above but False is passed
                # here; with val_acc fixed at 0 it is effectively unused.
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': net.state_dict(),
                        'recorder': recorder,
                        'mu_optimizer': mu_optimizer.state_dict(),
                        'psi_optimizer': psi_optimizer.state_dict(),
                    }, False, args.save_path, 'checkpoint.pth.tar')

            epoch_time.update(time.time() - start_time)
            start_time = time.time()
            recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
                 args, log, 20, 100)
    finally:
        log_file.close()
Exemplo n.º 7
0
def run(opt):
    """Train and validate a model with pytorch-ignite according to ``opt``.

    The function proceeds as follows:
      * Configure logging and a TensorBoard ``SummaryWriter``.
      * Build the model and an SGD or Adam optimizer, with a
        ``ReduceLROnPlateau`` scheduler in ``'max'`` mode that is stepped
        with the validation ``'cmc'`` metric after every epoch.
      * Optionally restore model/optimizer state from ``opt.checkpoint``.
      * Run ``opt.n_epochs`` epochs with loss/LR logging, periodic
        checkpointing and per-epoch validation.

    The writer is closed even if training raises.

    Raises:
        ValueError: If ``opt.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    if opt.log_file is not None:
        logging.basicConfig(filename=opt.log_file, level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)
    # ``logger`` is the root logger's ``info`` method, used as a plain function.
    logger = logging.getLogger().info
    writer = SummaryWriter(log_dir=opt.log_dir)
    try:
        model_timer, data_timer = Timer(average=True), Timer(average=True)

        # Training variables
        logger('Loading models')
        model, parameters, mean, std = generate_model(opt)

        # Learning configurations
        if opt.optimizer == 'sgd':
            optimizer = SGD(parameters,
                            lr=opt.lr,
                            momentum=opt.momentum,
                            weight_decay=opt.weight_decay,
                            nesterov=opt.nesterov)
        elif opt.optimizer == 'adam':
            optimizer = Adam(parameters, lr=opt.lr, betas=opt.betas)
        else:
            raise ValueError("Not supported optimizer: {}".format(opt.optimizer))
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   'max',
                                                   patience=opt.lr_patience)

        # Loading checkpoint
        if opt.checkpoint:
            logger('loading checkpoint {}'.format(opt.checkpoint))
            checkpoint = torch.load(opt.checkpoint)

            # NOTE(review): begin_epoch is not passed to trainer.run below, so
            # epoch numbering (and save_<epoch>.pth names) restarts at 1 —
            # confirm whether anything else reads it.
            opt.begin_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

        logger('Loading dataset')  # =================================
        train_loader, val_loader, _ = get_data_market(opt, mean, std)

        device = 'cuda'
        trainer = create_supervised_trainer(
            model,
            optimizer,
            lambda pred, target: loss_market(pred, target, loss_fns=training_loss),
            device=device)
        evaluator = create_supervised_evaluator(
            model,
            metrics={'cosine_metric': CosineMetric(cmc_metric, testing_loss)},
            device=device)

        # Training timer handlers: model_timer measures time inside an
        # iteration, data_timer the time between iterations (data loading).
        model_timer.attach(trainer,
                           start=Events.EPOCH_STARTED,
                           resume=Events.ITERATION_STARTED,
                           pause=Events.ITERATION_COMPLETED,
                           step=Events.ITERATION_COMPLETED)
        data_timer.attach(trainer,
                          start=Events.EPOCH_STARTED,
                          resume=Events.ITERATION_COMPLETED,
                          pause=Events.ITERATION_STARTED,
                          step=Events.ITERATION_STARTED)

        # Training log/plot handlers
        @trainer.on(Events.ITERATION_COMPLETED)
        def log_training_loss(engine):
            # 1-based iteration index within the current epoch.
            iteration_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
            if iteration_in_epoch % opt.log_interval == 0:
                logger(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.2f} Model Process: {:.3f}s/batch "
                    "Data Preparation: {:.3f}s/batch".format(
                        engine.state.epoch,
                        iteration_in_epoch, len(train_loader), engine.state.output,
                        model_timer.value(), data_timer.value()))
                writer.add_scalar("training/loss", engine.state.output,
                                  engine.state.iteration)

        # Log/Plot Learning rate
        @trainer.on(Events.EPOCH_STARTED)
        def log_learning_rate(engine):
            lr = optimizer.param_groups[-1]['lr']
            logger('Epoch[{}] Starts with lr={}'.format(engine.state.epoch, lr))
            writer.add_scalar("learning_rate", lr, engine.state.epoch)

        # Checkpointing
        @trainer.on(Events.EPOCH_COMPLETED)
        def save_checkpoint(engine):
            if engine.state.epoch % opt.save_interval == 0:
                save_file_path = os.path.join(
                    opt.result_path, 'save_{}.pth'.format(engine.state.epoch))
                states = {
                    'epoch': engine.state.epoch,
                    'arch': opt.model,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(states, save_file_path)

        # val_evaluator event handlers
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics["cosine_metric"]
            logger("Validation Results - Epoch: {}".format(engine.state.epoch))
            for m, val in metrics.items():
                logger('{}: {:.4f}'.format(m, val))

            for m, val in metrics.items():
                if m in ['total_loss', 'cmc']:
                    prefix = 'validation_summary/{}'
                else:
                    prefix = 'validation/{}'
                writer.add_scalar(prefix.format(m), val, engine.state.epoch)

            # Update Learning Rate (mode 'max': reduce when cmc stops improving)
            scheduler.step(metrics['cmc'])

        # kick everything off
        logger('Start training')
        trainer.run(train_loader, max_epochs=opt.n_epochs)
    finally:
        writer.close()
Exemplo n.º 8
0
def main(args: argparse.Namespace):
    """Pretrain on the source domain if needed, then run regression domain adaptation.

    The steps are:
      1. Optional seeding; source/target train and test loaders are built.
      2. A ``RegDAPoseResNet`` is built with a main head (``head``) and an
         adversarial head (``head_adv``), plus three SGD optimizers with
         ``LambdaLR`` schedules.
      3. Without ``args.resume``:
         * if ``args.pretrain`` is ``None``, a ``PoseResNet`` is pretrained
           on the source domain (best source accuracy is saved);
         * the pretrained weights are then loaded, dropping keys the model
           does not have.
         With ``args.resume``, the model, optimizers, schedulers and epoch
         are restored from the checkpoint.
      4. When ``args.phase == 'test'``, the model is only evaluated.
         Otherwise it trains from the start epoch, saving a checkpoint every
         epoch and copying the best-target-accuracy one to the ``'best'``
         path.

    Args:
        args: Parsed command-line options.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN autotuning picks algorithms non-deterministically, which would
    # defeat the deterministic setting above; only enable it when unseeded.
    cudnn.benchmark = args.seed is None

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone,
                            upsampling,
                            256,
                            num_keypoints,
                            num_head_layers=args.num_head_layers,
                            finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator(num_keypoints,
                                                  args.heatmap_size,
                                                  args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator,
                                               JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    # The ``lr`` values below are base multipliers: LambdaLR multiplies them by
    # lr_decay_function(step), which already includes args.lr.
    optimizer_f = SGD([
        {
            'params': backbone.parameters(),
            'lr': 0.1
        },
        {
            'params': upsampling.parameters(),
            'lr': 0.1
        },
    ],
                      lr=0.1,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h = SGD(model.head.parameters(),
                      lr=1.,
                      momentum=args.momentum,
                      weight_decay=args.wd,
                      nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(),
                          lr=1.,
                          momentum=args.momentum,
                          weight_decay=args.wd,
                          nesterov=True)
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x))**(
        -args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256,
                                          num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr),
                            momentum=args.momentum,
                            weight_decay=args.wd,
                            nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                # get_last_lr() reports the LR in effect for this epoch;
                # get_lr() is only meant to be called inside step().
                print(lr_scheduler.get_last_lr())

                pretrain(train_source_iter, pretrained_model, criterion,
                         optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model,
                                          criterion, None, args)

                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save({'model': pretrained_model.state_dict()},
                               args.pretrain)
                print("Source: {} best: {}".format(source_val_acc['all'],
                                                   best_acc))
                # Step the epoch schedule after this epoch's optimizer updates;
                # stepping first would skip the initial LR and decay one epoch early.
                lr_scheduler.step()

        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain,
                                     map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that don't appear in model dict
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
            heatmaps: unused
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'],
                                                       target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    # NOTE(review): best_acc is not restored when resuming, so the first
    # resumed epoch always overwrites the 'best' checkpoint.
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_last_lr(), lr_scheduler_h.get_last_lr(),
              lr_scheduler_h_adv.get_last_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv,
              lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
Exemplo n.º 9
0
optimizer = SGD(
    model.parameters(),
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)
logger.info("Model:\n%s\nOptimizer:\n%s" % (model, optimizer))

# Beta(alpha, alpha) sampler, built only when mix-up augmentation is enabled.
if args.mix_up:
    beta_distribution = Beta(
        torch.tensor([args.alpha]),
        torch.tensor([args.alpha]),
    )

# Restore step, best accuracy, model and optimizer state from a checkpoint file.
if args.resume is not None and isfile(args.resume):
    logger.info("=> loading checkpoint '{}'".format(args.resume))
    checkpoint = torch.load(args.resume)
    args.start_step = checkpoint['step']
    best_acc = checkpoint['best_acc']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.info("=> loaded checkpoint '{}' (step {})".format(args.resume, checkpoint['step']))
elif args.resume is not None:
    logger.info("=> no checkpoint found at '{}'".format(args.resume))

def compute_lr(step):
    """Return the learning rate to use at ``step``.

    During the first ``args.warmup`` steps the rate ramps linearly from 0 to
    ``args.lr``. Afterwards it starts at ``args.lr`` and is multiplied by
    ``args.gamma`` once for every milestone that ``step`` has strictly passed.
    """
    if step < args.warmup:
        return args.lr * step / args.warmup

    rate = args.lr
    for boundary in args.milestones:
        if boundary < step:
            rate *= args.gamma
    return rate

def main():
    """Train a 4000-class ResNet18, optionally resuming from a checkpoint.

    Configuration is read from the module-level ``param`` dict. When
    ``param['resume']`` is truthy, model and optimizer state are restored from
    ``<checkPointPath>/epoch<resume_from>``. Training then runs for
    ``param['nepochs']`` epochs starting at ``param['resume_from'] + 1``. After
    each epoch a checkpoint is written to ``<checkPointPath>/epoch<epoch>``, the
    LR is multiplied by 0.95 (StepLR), and validation / verification are run.
    """
    def checkpoint_path(epoch):
        # One naming scheme shared by the resume and save paths.
        return param['checkPointPath'] + '/epoch' + str(epoch)

    model = ResNet18(4000)
    model.to(DEVICE)
    optimizer = SGD(model.parameters(),
                    lr=0.15,
                    momentum=0.9,
                    weight_decay=5e-5)

    if param['resume']:
        print("loading from checkpoint {}".format(param['resume_from']))
        # map_location lets a checkpoint saved on another device load on DEVICE.
        checkpoint = torch.load(checkpoint_path(param['resume_from']),
                                map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Ensure every tensor in the optimizer state (e.g. momentum buffers)
        # lives on DEVICE alongside the parameters.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(DEVICE)
        print("finish loading")

    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    criterion = nn.CrossEntropyLoss()
    criterion.to(DEVICE)

    batch_size = 10 if DEVICE == 'cuda' else 1
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor()])

    train_dataset = datasets.ImageFolder(
        root='../classification_data/train_data/', transform=data_transform)
    val_dataset = datasets.ImageFolder(root='../classification_data/val_data',
                                       transform=data_transform)
    verfication_dev_dataset = MyVerificationDataset(
        '../verification_pairs_val.txt', data_transform)
    verfication_test_dataset = MyVerificationDataset(
        '../verification_pairs_test.txt', data_transform)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    dev_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    verification_dev_loader = DataLoader(verfication_dev_dataset,
                                         batch_size=batch_size,
                                         shuffle=False)
    # NOTE(review): the test loader is built but not used in this function.
    verification_test_loader = DataLoader(verfication_test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

    # NOTE(review): start_epoch is derived from resume_from even when not
    # resuming — confirm resume_from is 0 (or intended) in that case.
    start_epoch = param['resume_from'] + 1
    torch.cuda.empty_cache()
    # Baseline evaluation before any further training.
    validation(model, dev_loader)
    verification_dev(model, verification_dev_loader)
    print("start training")
    for epoch in range(start_epoch, start_epoch + param['nepochs']):
        train(model, train_loader, criterion, optimizer, epoch)
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, checkpoint_path(epoch))
        scheduler.step()
        validation(model, dev_loader)
        auc = verification_dev(model, verification_dev_loader)
        print("auc is: ", auc)
Exemplo n.º 11
0
def main():
    """Train a classifier under a PGD or DDN attacker, resuming from outdir if asked.

    Builds the datasets/loaders, the model (optionally from a pretrained or
    Xception checkpoint), the attacker, SGD optimizer and StepLR scheduler. If
    ``args.resume`` is set and ``<outdir>/checkpoint.pth.tar`` exists, training
    continues from the saved epoch; otherwise a fresh tab-separated log file is
    started. Each epoch's metrics are appended to ``<outdir>/log.txt`` and the
    checkpoint is overwritten.
    """
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    os.makedirs(args.outdir, exist_ok=True)

    # Copies files to the outdir to store complete script with each experiment
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train', args.data_root, None)
    test_dataset = get_dataset(args.dataset, 'test', None, args.test_root)
    pin_memory = (args.dataset == "imagenet")
    labelled_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch,
                      num_workers=args.workers, pin_memory=pin_memory, drop_last= True)
    if args.use_unlabelled:
        pseudo_labelled_loader = DataLoader(TiTop50KDataset(), shuffle=True, batch_size=args.batch,
                                  num_workers=args.workers, pin_memory=pin_memory)
        train_loader = MultiDatasetsDataLoader([labelled_loader, pseudo_labelled_loader])
    else:
        train_loader = MultiDatasetsDataLoader([labelled_loader])

    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch,
                             num_workers=args.workers, pin_memory=pin_memory)

    if args.pretrained_model == 'xception_first_time':
        model = get_architecture("xception", args.dataset)
        checkpoint = torch.load(args.load_checkpoint)
        model[1].load_state_dict(checkpoint, strict=False)
    elif args.pretrained_model == 'xception':
        checkpoint = torch.load(args.load_checkpoint,
                                    map_location=lambda storage, loc: storage)
        model = get_architecture(checkpoint["arch"], args.dataset)
        model.load_state_dict(checkpoint['state_dict'])

    elif args.pretrained_model != '':
        assert args.arch == 'cifar_resnet110', 'Unsupported architecture for pretraining'
        checkpoint = torch.load(args.pretrained_model)
        model = get_architecture(checkpoint["arch"], args.dataset)
        model.load_state_dict(checkpoint['state_dict'])
        model[1].fc = nn.Linear(64, get_num_classes('cifar10')).cuda()
    else:
        model = get_architecture(args.arch, args.dataset)

    if args.attack == 'PGD':
        print('Attacker is PGD')
        attacker = PGD_L2(steps=args.num_steps, device='cuda', max_norm=args.epsilon)
    elif args.attack == 'DDN':
        print('Attacker is DDN')
        attacker = DDN(steps=args.num_steps, device='cuda', max_norm=args.epsilon, 
                    init_norm=args.init_norm_DDN, gamma=args.gamma_DDN)
    else:
        raise Exception('Unknown attack')

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')
    print(len(train_dataset.classes))

    # Adversarial training logs one extra column: accuracy on clean inputs.
    log_header = "epoch\ttime\tlr\ttrainloss\ttestloss\ttrainacc\ttestacc"
    if args.adv_training:
        log_header += "\ttestaccNor"

    # Load latest checkpoint if exists (to handle philly failures) 
    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    resumed = False
    if args.resume:
        if os.path.isfile(model_path):
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
            starting_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                         .format(model_path, checkpoint['epoch']))
            resumed = True
        else:
            print("=> no checkpoint found at '{}'".format(model_path))
    if not resumed:
        # Start a fresh log only when we are not appending to a resumed run.
        init_logfile(logfilename, log_header)

    for epoch in range(starting_epoch, args.epochs):
        # Legacy epoch-indexed stepping; sets the LR for this epoch in closed form.
        scheduler.step(epoch)
        # Linearly ramp the attack radius up to args.epsilon over args.warmup epochs.
        warm_norm = np.min([args.epsilon, (epoch + 1) * args.epsilon/args.warmup])
        attacker.max_norm = warm_norm
        attacker.init_norm = warm_norm

        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, args.noise_sd, attacker)
        test_loss, test_acc, test_acc_normal = test(test_loader, model, criterion, args.noise_sd, attacker)
        after = time.time()

        # Log the LR that was actually applied. scheduler.get_lr() is not a getter
        # in PyTorch >= 1.4: for StepLR it returns lr * gamma on epochs that are
        # multiples of step_size, so the logged value was wrong on those epochs.
        current_lr = optimizer.param_groups[0]['lr']
        metrics = [current_lr, train_loss, test_loss, train_acc, test_acc]
        if args.adv_training:
            metrics.append(test_acc_normal)
        log(logfilename, "{}\t{:.2f}\t".format(epoch, after - before)
            + "\t".join("{:.3}".format(value) for value in metrics))

        torch.save({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, model_path)
Exemplo n.º 12
0
def train_continue(directory, version, path, exp_name, name, model, lr, epochs,
                   momentum, batch_size, resize, margin, logdir):
    """Resume training a siamese network with a contrastive loss from a checkpoint.

    Restores model weights, optimizer state and the stored loss/accuracy/step
    counters from the checkpoint at ``path``. It then trains for ``epochs``
    more epochs on the money pair dataset, alternating a 'train' and a 'valid'
    pass per epoch. Scalars and an embedding are logged to TensorBoard under
    ``logdir/exp_name``, and ``net_save`` writes ``'<exp_name>.pth'`` after
    every epoch.

    Args:
        directory, version, name: Not used by this function.
        path: Checkpoint file to resume from.
        exp_name: Experiment name (TensorBoard subdir and save-file stem).
        model: Network to load the weights into and train.
        lr, momentum: SGD hyper-parameters (the loaded optimizer state then
            overrides them with the checkpointed values).
        epochs: Number of additional epochs to run.
        batch_size: Loader batch size.
        resize: Passed to ``DataSetPairCreate``.
        margin: Contrastive-loss margin, also used for the prediction rule.
        logdir: TensorBoard root directory.

    Returns:
        tuple: (model, formatted ``timer.stop()`` value, train losses, valid
        losses, train global steps, valid global steps, train accuracies, valid
        accuracies, train labels, train predictions, valid labels, valid
        predictions).
    """
    print("Continue model")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    siamese_reload = model
    siamese_reload.to(device)
    # map_location lets a checkpoint written on a GPU be resumed on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(path, map_location=device)

    siamese_reload.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']
    lossTrain = checkpoint['lossTrain']
    lossValid = checkpoint['lossValid']

    print('lossTrain', lossTrain)
    print('lossValid', lossValid)
    global_step_train = checkpoint['global_step_train']
    global_step_val = checkpoint['global_step_valid']

    accTrain = checkpoint['accTrain']
    accValid = checkpoint['accValid']
    print('accTrain', accTrain)
    print('accValid', accValid)

    # The values used to be passed as extra print() arguments, so the %s
    # placeholders were printed literally; format the message explicitly.
    print("Epoca %s , lossTrain %s , lossValid %s , accTrain %s , accValid %s , "
          "global_step_train %s , global_step_val %s"
          % (epoch, lossTrain, lossValid, accTrain, accValid,
             global_step_train, global_step_val))

    print(siamese_reload.load_state_dict(checkpoint['model_state_dict']))

    # A dictionary with all the model's parameters is available via state_dict().
    state_dict = siamese_reload.state_dict()
    print(state_dict.keys())

    # Print model's state_dict
    print("Model's state_dict:")
    for param_tensor, value in state_dict.items():
        print(param_tensor, "\t", value.size())

    controlFileCSV()
    dataSetPair = DataSetPairCreate(resize)
    dataSetPair.controlNormalize()

    pair_train = dataSetPair.pair_money_train
    pair_validation = dataSetPair.pair_money_val

    pair_money_train_loader = DataLoader(pair_train,
                                         batch_size=batch_size,
                                         num_workers=0,
                                         shuffle=True)
    pair_money_val_loader = DataLoader(pair_validation,
                                       batch_size=batch_size,
                                       num_workers=0)

    criterion = ContrastiveLoss(margin)
    optimizer = SGD(siamese_reload.parameters(), lr, momentum=momentum)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # Per-step histories returned to the caller.
    array_loss_train = []
    array_loss_valid = []
    array_sample_train = []
    array_sample_valid = []
    array_acc_valid = []
    array_acc_train = []
    prediction_train = []
    labels_train = []
    prediction_val = []
    labels_val = []

    loss_meter = AverageValueMeter()
    acc_meter = AverageValueMeter()
    writer = SummaryWriter(join(logdir, exp_name))

    # The loss must also go on the device because it holds a parameter (m).
    criterion.to(device)
    # Loaders for the two phases run each epoch.
    loader = {'train': pair_money_train_loader, 'valid': pair_money_val_loader}
    timer = Timer()
    global_step = global_step_val

    for e in range(epochs):
        print("Epoca ", e)
        # Alternate between the two modes: train and valid.
        for mode in ['train', 'valid']:
            # NOTE(review): loss_meter / acc_meter are never reset, so the logged
            # values are running averages over every batch seen so far (train
            # and valid mixed) — confirm this is intended.
            if mode == 'train':
                siamese_reload.train()
            else:
                siamese_reload.eval()
            # Gradients are enabled only during training.
            with torch.set_grad_enabled(mode == 'train'):

                for batch in loader[mode]:
                    # batch = img1, img2, label12, label1, label2
                    I_i, I_j, l_ij, _, _ = [b.to(device) for b in batch]
                    # Siamese network: run the same embedding net on both inputs.
                    phi_i = siamese_reload(I_i)  # img 1
                    phi_j = siamese_reload(I_j)  # img 2

                    l = criterion(phi_i, phi_j, l_ij)

                    d = F.pairwise_distance(phi_i.to('cpu'), phi_j.to('cpu'))
                    labs = l_ij.to('cpu')
                    # max(margin - d, 0): zero means the pair is far apart.
                    tensor = torch.clamp(margin - d, min=0)

                    # 0 = similar pair, 1 = dissimilar pair.
                    # NOTE(review): the threshold 2 is hard-coded and independent
                    # of margin — confirm.
                    predictions = prediction_train if mode == 'train' else prediction_val
                    for el in tensor:
                        predictions.append(0 if el <= 2 else 1)

                    # global_step counts the samples seen so far.
                    n = I_i.shape[0]  # number of elements in the batch
                    global_step += n

                    if mode == 'train':
                        labels_train.extend(list(labs.numpy()))
                        print("Lunghezza predette TRAIN ",
                              len(prediction_train))
                        print("Lunghezza vere TRAIN ", len(labels_train))
                        acc = accuracy_score(np.array(labels_train),
                                             np.array(prediction_train))
                        acc_meter.add(acc, n)

                    else:
                        labels_val.extend(list(labs.numpy()))
                        print("Lunghezza predette VALID ", len(prediction_val))
                        print("Lunghezza vere VALID ", len(labels_val))
                        acc = accuracy_score(np.array(labels_val),
                                             np.array(prediction_val))
                        acc_meter.add(acc, n)

                    if mode == 'train':
                        l.backward()
                        optimizer.step()
                        optimizer.zero_grad()

                    loss_meter.add(l.item(), n)

                    if mode == 'train':
                        writer.add_scalar('loss/train',
                                          loss_meter.value(),
                                          global_step=global_step)
                        writer.add_scalar('accuracy/train',
                                          acc_meter.value(),
                                          global_step=global_step)

                        lossTrain = loss_meter.value()
                        global_step_train = global_step
                        array_loss_train.append(lossTrain)
                        array_acc_train.append(acc_meter.value())
                        array_sample_train.append(global_step_train)
                        print("TRAIN- Epoca", e)
                        print("GLOBAL STEP TRAIN", global_step_train)
                        print("LOSS TRAIN", lossTrain)
                        print("ACC TRAIN", acc_meter.value())

                    else:
                        lossValid = loss_meter.value()
                        global_step_val = global_step
                        array_loss_valid.append(lossValid)
                        array_acc_valid.append(acc_meter.value())
                        array_sample_valid.append(global_step_val)
                        print("VALID- Epoca", e)
                        print("GLOBAL STEP VALID", global_step_val)
                        print("LOSS VALID", lossValid)
                        print("ACC VALID", acc_meter.value())

            writer.add_scalar('loss/' + mode,
                              loss_meter.value(),
                              global_step=global_step)
            writer.add_scalar('accuracy/' + mode,
                              acc_meter.value(),
                              global_step=global_step)

        # Log the embedding of the last batch processed in this epoch (the last
        # valid batch) for qualitative monitoring in TensorBoard.
        writer.add_embedding(phi_i,
                             batch[3],
                             I_i,
                             global_step=global_step,
                             tag=exp_name + '_embedding')
        # Keep only the latest model, overwriting the previous save.
        net_save(epochs, siamese_reload, optimizer, lossTrain, lossValid,
                 array_acc_train[-1], array_acc_valid[-1], global_step_train,
                 global_step_val, '%s.pth' % exp_name)
    f = '{:.7f}'.format(timer.stop())
    # Flush pending TensorBoard events and release the event file.
    writer.close()

    return siamese_reload, f, array_loss_train, array_loss_valid, array_sample_train, array_sample_valid, array_acc_train, array_acc_valid, labels_train, prediction_train, labels_val, prediction_val
Exemplo n.º 13
0
def train(args):
    """Train a skip-gram negative-sampling model and save its embeddings.

    Reads ``idx2word.dat``, ``word2idx.dat`` and ``wc.dat`` (pickles) from
    ``args.data_dir``. Optionally continues from ``<save_dir>/<name>.pt`` and
    ``<name>.optim.pt`` when ``args.conti`` is set. Trains for ``args.epoch``
    epochs with mini-batches of ``args.mb``, then writes the input embedding
    matrix to ``<data_dir>/idx2vec.dat`` and the model/optimizer state back to
    ``save_dir``.
    """
    t.set_num_threads(8)
    print("Number of threads: {}".format(t.get_num_threads()))

    def _load_pickle(filename):
        # NOTE: pickle can execute arbitrary code; only load trusted data files.
        with open(os.path.join(args.data_dir, filename), 'rb') as fh:
            return pickle.load(fh)

    idx2word = _load_pickle('idx2word.dat')
    word2idx = _load_pickle('word2idx.dat')  # currently unused below
    wc = _load_pickle('wc.dat')
    wf = np.array([wc[word] for word in idx2word])
    wf = wf / wf.sum()
    # frequency subsampling
    # NOTE(review): ws and weights are computed but not used by this function.
    ws = 1 - np.sqrt(args.ss_t / wf)
    ws = np.clip(ws, 0, 1)
    vocab_size = len(idx2word)
    weights = wf if args.weights else None
    os.makedirs(args.save_dir, exist_ok=True)

    model = SkipGramNeg(vocab_size=vocab_size, emb_dim=args.e_dim)
    modelpath = os.path.join(args.save_dir, '{}.pt'.format(args.name))
    if os.path.isfile(modelpath) and args.conti:
        # The model is still on the CPU here; load_state_dict copies the values.
        model.load_state_dict(t.load(modelpath, map_location='cpu'))
    if args.cuda:
        model = model.cuda()
    # Send every batch to wherever the model actually lives.
    device = next(model.parameters()).device

    optim = SGD(model.parameters(), lr=0.01)
    optimpath = os.path.join(args.save_dir, '{}.optim.pt'.format(args.name))
    if os.path.isfile(optimpath) and args.conti:
        # Optimizer.load_state_dict casts its state onto the parameters' device.
        optim.load_state_dict(t.load(optimpath, map_location='cpu'))

    data_utils = DataUtils(wc, args.n_negs)
    data_utils.initTableNegatives()
    dataset = PermutedSubsampledCorpus(
        os.path.join(args.data_dir, 'train.dat'), data_utils)
    # shuffle=True re-shuffles on every pass, so one loader serves all epochs.
    dataloader = DataLoader(dataset, batch_size=args.mb, shuffle=True)

    for epoch in range(1, args.epoch + 1):
        pbar = tqdm(dataloader)
        pbar.set_description("[Epoch {}]".format(epoch))
        for iword, owords, batch_neg in pbar:
            iword = iword.to(device)
            owords = owords.to(device)
            batch_neg = batch_neg.to(device)

            loss = model(iword, owords, batch_neg)
            optim.zero_grad()
            loss.backward()
            optim.step()
            pbar.set_postfix(loss=loss.item())
            print("Loss: {}".format(loss))

    idx2vec = model.input_emb.weight.data.cpu().numpy()
    with open(os.path.join(args.data_dir, 'idx2vec.dat'), 'wb') as fh:
        pickle.dump(idx2vec, fh)
    t.save(model.state_dict(), modelpath)
    t.save(optim.state_dict(), optimpath)
Exemplo n.º 14
0
def main():
    """Train / evaluate a torchvision detection model on one CV fold.

    Parses CLI arguments and builds fold-specific train/valid datasets (or the
    test set in ``--submission`` mode). It then creates the model, SGD
    optimizer and an optional cosine or multi-step LR scheduler, and optionally
    resumes from ``--resume``. With ``--test-only`` / ``--submission`` it only
    evaluates and writes CSVs; otherwise it trains, checkpoints to
    ``<output-dir>/checkpoint.p`` each epoch, and keeps the best-F1 weights in
    ``model_best.p``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    arg = parser.add_argument

    arg('--model', default='fasterrcnn_resnet50_fpn',
        help='model')  # resnet50/152?
    arg('--device', default='cuda', help='device')  # cuda for gpu
    arg('--batch-size', default=16, type=int)  # batchsize
    arg('--workers',
        default=4,
        type=int,
        help='number of data loading workers')  # workers
    arg('--lr', default=0.01, type=float,
        help='initial learning rate')  # learning rate
    arg('--momentum', default=0.9, type=float,
        help='momentum')  # optimizer momentum
    arg('--wd',
        '--weight-decay',
        default=1e-4,
        type=float,
        help='weight decay (default: 1e-4)',
        dest='weight_decay')  # optimizer weight decay
    arg('--epochs', default=45, type=int,
        help='number of total epochs to run')  # epochs
    arg('--lr-steps',
        default=[35],
        nargs='+',
        type=int,
        help='decrease lr every step-size epochs')  # learning rate scheduler
    arg('--lr-gamma',
        default=0.1,
        type=float,
        help='decrease lr by a factor of lr-gamma')  # lr scheduler rate
    arg('--cosine',
        type=int,
        default=0,
        help='cosine lr schedule (disabled step lr schedule)'
        )  # cosine lr scheduler
    arg('--print-freq', default=100, type=int,
        help='print frequency')  # print freq
    arg('--output-dir',
        help='path where to save')  # output directory after training
    arg('--resume',
        help='resume from checkpoint')  # resume training from checkpoint
    arg('--test-only', help='Only test the model',
        action='store_true')  # testing only without submission
    arg('--submission', help='Create test predictions',
        action='store_true')  # submission
    arg('--pretrained',
        type=int,
        default=0,
        help='Use pre-trained models from the modelzoo'
        )  # pretrained models from modelzoo
    arg('--score-threshold', type=float,
        default=0.5)  # score threshold for detection
    arg('--nms-threshold', type=float,
        default=0.25)  # non max suppression threshold for detection
    arg('--repeat-train-step', type=int, default=2)  # repeat train

    # fold parameters
    arg('--fold', type=int, default=0)  # which fold to use
    arg('--n-folds', type=int, default=5)  # number of folds

    args = parser.parse_args()
    if args.test_only and args.submission:
        parser.error('Please select either test or submission')

    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)

    # utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # Loading dataset
    print('...Loading Data Now...')

    df_train, df_valid = load_train_valid_df(args.fold,
                                             args.n_folds)  # from data_utils
    root = TRAIN_ROOT  # from data_utils
    if args.submission:
        df_valid = pd.read_csv(DATA_ROOT /
                               'sample_submission.csv')  # from data_utils
        df_valid['labels'] = ''
        root = TEST_ROOT
    dataset_train = Dataset(df_train,
                            augmentation(train=True),
                            root,
                            skip_empty=False)
    dataset_test = Dataset(df_valid,
                           augmentation(train=False),
                           root,
                           skip_empty=False)

    # Pytorch data loaders
    print('...Creating The Data Loaders Now...')
    train_sampler = torch.utils.data.RandomSampler(dataset_train)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    train_batch = torch.utils.data.BatchSampler(train_sampler,
                                                args.batch_size,
                                                drop_last=True)
    data_loader_train = torch.utils.data.DataLoader(dataset_train,
                                                    batch_sampler=train_batch,
                                                    num_workers=args.workers,
                                                    collate_fn=collate_fn)
    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=1,
                                                   sampler=test_sampler,
                                                   num_workers=args.workers,
                                                   collate_fn=collate_fn)

    # Create The Model
    print('...Creating Model Now...')
    model = build_model(args.model, args.pretrained, args.nms_threshold)
    model.to(device)

    # Only optimise parameters that are not frozen.
    params = [para for para in model.parameters() if para.requires_grad]
    optimizer = SGD(params,
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    lr_scheduler = None
    if args.cosine:
        lr_scheduler = CosineAnnealingLR(optimizer, args.epochs)
    elif args.lr_steps:
        lr_scheduler = MultiStepLR(optimizer,
                                   milestones=args.lr_steps,
                                   gamma=args.lr_gamma)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        if 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Older checkpoints stored the bound method (or None) instead of a
            # state dict; only restore when a real state dict is present.
            scheduler_state = checkpoint.get('lr_scheduler')
            if lr_scheduler is not None and isinstance(scheduler_state, dict):
                lr_scheduler.load_state_dict(scheduler_state)
        else:
            model.load_state_dict(checkpoint)
        print('Loaded from checkpoint {}'.format(args.resume))

    def save_eval_results(results):
        scores, clf_gt = results
        if output_dir:
            pd.DataFrame(scores).to_csv(output_dir / 'eval.csv', index=None)
            pd.DataFrame(clf_gt).to_csv(output_dir / 'clf_gt.csv', index=None)

    if args.test_only or args.submission:
        _, eval_results = evaluate(model,
                                   data_loader_test,
                                   device=device,
                                   output_dir=output_dir,
                                   threshold=args.score_threshold)
        if args.test_only:
            save_eval_results(eval_results)
        elif output_dir:
            pd.DataFrame(eval_results[1]).to_csv(output_dir /
                                                 'test_predictions.csv',
                                                 index=None)
        return

    # Start Training
    print('...Training Session Begin...')
    best_f1 = 0
    start = time.time()
    for epoch in range(args.epochs):
        #         train_sampler.set_epoch(epoch)
        for _ in range(args.repeat_train_step):
            train_metrics = train_one_epoch(model, optimizer,
                                            data_loader_train, device, epoch,
                                            args.print_freq)
        if lr_scheduler:
            lr_scheduler.step()
        if output_dir:
            # json_log_plots.write_event(output_dir, step=epoch, **train_metrics)
            save_on_master(
                {
                    'model':
                    model.state_dict(),
                    'optimizer':
                    optimizer.state_dict(),
                    # Call state_dict(): saving the bound method made the
                    # scheduler state unrecoverable on --resume.
                    'lr_scheduler':
                    (lr_scheduler.state_dict() if lr_scheduler else None),
                    'args':
                    args
                }, output_dir / 'checkpoint.p')
        # evaluation for every epoch
        eval_metrics, eval_results = evaluate(model,
                                              data_loader_test,
                                              device=device,
                                              output_dir=None,
                                              threshold=args.score_threshold)
        save_eval_results(eval_results)
        if output_dir:
            # json_log_plots.write_event(output_dir, step=epoch, **eval_metrics)
            if eval_metrics['f1'] > best_f1:
                best_f1 = eval_metrics['f1']
                print('Updated best model with f1 of {}'.format(best_f1))
                save_on_master(model.state_dict(), output_dir / 'model_best.p')

    total_time = time.time() - start
    final = str(datetime.timedelta(seconds=int(total_time)))
    print('Trained for {} seconds'.format(final))
Exemplo n.º 15
0
class CoNLLNERTrainer(BaseTrainer):
    """Trains a ``BiRecurrentConvCRF`` NER tagger on CoNLL-style sentences.

    Sentences handed to :meth:`consume` are converted to id sequences and
    cached; the optimization itself runs in :meth:`epoch_finish_action`,
    once per pass over the data.
    """

    def __init__(self):
        super().__init__()

        self.model = None

        self.word_alphabet = None
        self.char_alphabet = None
        self.ner_alphabet = None

        self.config_model = None
        self.config_data = None
        self.normalize_func = None

        self.device = None
        self.optim, self.trained_epochs = None, None

        self.resource: Optional[Resources] = None

        # (word_ids, char_id_seqs, ner_ids) tuples collected by `consume`;
        # cleared at the end of every epoch.
        self.train_instances_cache = []

        # Just for recording: longest (already truncated) word seen so far.
        self.max_char_length = 0

        # Best validation result seen so far; updated in
        # `post_validation_action`.
        self.__past_dev_result = None

    def initialize(self, resource: Resources, configs: HParams):
        """Build alphabets, model and optimizer from ``resource``/``configs``.

        Args:
            resource: Provides ``word_alphabet``, ``char_alphabet``,
                ``ner_alphabet`` and ``word_embedding_table``; the created
                model is stored back into it under the ``model`` key.
            configs: Must expose ``config_model`` and ``config_data``.
        """
        self.resource = resource

        self.word_alphabet = resource.get("word_alphabet")
        self.char_alphabet = resource.get("char_alphabet")
        self.ner_alphabet = resource.get("ner_alphabet")

        word_embedding_table = resource.get('word_embedding_table')

        self.config_model = configs.config_model
        self.config_data = configs.config_data

        self.normalize_func = utils.normalize_digit_word

        self.device = torch.device("cuda") if torch.cuda.is_available() \
            else torch.device("cpu")

        utils.set_random_seed(self.config_model.random_seed)

        self.model = BiRecurrentConvCRF(
            word_embedding_table, self.char_alphabet.size(),
            self.ner_alphabet.size(), self.config_model).to(device=self.device)

        self.optim = SGD(self.model.parameters(),
                         lr=self.config_model.learning_rate,
                         momentum=self.config_model.momentum,
                         nesterov=True)

        self.trained_epochs = 0

        self.resource.update(model=self.model)

    def data_request(self):
        """Request sentence contexts with each token's ``ner`` field."""
        request_string = {
            "context_type": Sentence,
            "request": {
                Token: ["ner"],
                Sentence: [],  # span by default
            }
        }
        return request_string

    def consume(self, instance):
        """Convert one sentence instance to id sequences and cache it.

        Each word's characters are mapped through ``char_alphabet`` and
        truncated to ``config_data.max_char_length``; the word itself is
        normalized with ``normalize_func`` and mapped through
        ``word_alphabet``; NER tags are mapped through ``ner_alphabet``.
        """
        tokens = instance["Token"]
        max_chars = self.config_data.max_char_length

        word_ids = []
        char_id_seqs = []
        for word in tokens["text"]:
            char_ids = [self.char_alphabet.get_index(char) for char in word]
            char_id_seqs.append(char_ids[:max_chars])

            word = self.normalize_func(word)
            word_ids.append(self.word_alphabet.get_index(word))

        ner_ids = [self.ner_alphabet.get_index(ner) for ner in tokens["ner"]]

        # `default=0` keeps an empty sentence from raising ValueError.
        max_len = max((len(char_seq) for char_seq in char_id_seqs), default=0)
        self.max_char_length = max(self.max_char_length, max_len)

        self.train_instances_cache.append((word_ids, char_id_seqs, ner_ids))

    def pack_finish_action(self, pack_count):
        pass

    def epoch_finish_action(self, epoch):
        """Train on the instances cached during this epoch.

        At the end of each dataset iteration the cached instances are
        batched and used for one training pass. The learning rate is
        decayed every ``config_model.decay_interval`` epochs. Evaluation is
        then requested, the cache is cleared, and training is stopped once
        ``epoch`` reaches ``config_data.num_epochs``.

        Args:
            epoch: Index of the epoch that just finished.
        """
        counter = len(self.train_instances_cache)
        logger.info(f"Total number of ner_data: {counter}")

        lengths = \
            sum([len(instance[0]) for instance in self.train_instances_cache])

        # An empty cache would otherwise raise ZeroDivisionError here.
        avg_length = lengths / counter if counter else 0.0
        logger.info(f"Average sentence length: {avg_length:0.3f}")

        train_err = 0.0
        train_total = 0.0

        start_time = time.time()
        self.model.train()

        # Each time we will clear and reload the train_instances_cache
        instances = self.train_instances_cache
        random.shuffle(self.train_instances_cache)
        data_iterator = torchtext.data.iterator.pool(
            instances,
            self.config_data.batch_size_tokens,
            # Instances are (word_ids, char_id_seqs, ner_ids) tuples, so the
            # sentence length is len(word_ids); tuples have no `.length()`.
            key=lambda x: len(x[0]),
            batch_size_fn=batch_size_fn,
            random_shuffler=torchtext.data.iterator.RandomShuffler())

        step = 0

        for batch in data_iterator:
            step += 1
            batch_data = self.get_batch_tensor(batch, device=self.device)
            word, char, labels, masks, _ = batch_data

            self.optim.zero_grad()
            loss = self.model(word, char, labels, mask=masks)
            loss.backward()
            self.optim.step()

            num_inst = word.size(0)
            train_err += loss.item() * num_inst
            train_total += num_inst

            # update log
            if step % 200 == 0:
                logger.info(f"Train: {step}, "
                            f"loss: {(train_err / train_total):0.3f}")

        # No batches at all (empty cache) would otherwise divide by zero.
        mean_loss = train_err / train_total if train_total else 0.0
        logger.info(f"Epoch: {epoch}, steps: {step}, "
                    f"loss: {mean_loss:0.3f}, "
                    f"time: {(time.time() - start_time):0.3f}s")

        self.trained_epochs = epoch

        if epoch % self.config_model.decay_interval == 0:
            # lr = lr0 / (1 + trained_epochs * decay_rate)
            lr = self.config_model.learning_rate / \
                 (1.0 + self.trained_epochs * self.config_model.decay_rate)
            for param_group in self.optim.param_groups:
                param_group["lr"] = lr
            logger.info(f"Update learning rate to {lr:0.3f}")

        self.request_eval()
        self.train_instances_cache.clear()

        if epoch >= self.config_data.num_epochs:
            self.request_stop_train()

    @torch.no_grad()
    def get_loss(self, instances: Iterator) -> float:
        """Return the sum of per-batch losses divided by the instance count.

        Instances are batched in chunks of ``config_data.test_batch_size``.
        Returns 0.0 when ``instances`` is empty instead of dividing by zero.
        """
        losses = 0
        val_data = list(instances)
        if not val_data:
            return 0.0
        for i in tqdm(range(0, len(val_data),
                            self.config_data.test_batch_size)):
            b_data = val_data[i:i + self.config_data.test_batch_size]
            batch = self.get_batch_tensor(b_data, device=self.device)

            word, char, labels, masks, unused_lengths = batch
            loss = self.model(word, char, labels, mask=masks)
            losses += loss.item()

        mean_loss = losses / len(val_data)
        return mean_loss

    def post_validation_action(self, eval_result):
        """Checkpoint on a new best validation F1 and log the best scores.

        ``eval_result`` is read as a dict with an ``"epoch"`` entry, an
        ``"eval"`` dict (accuracy/precision/recall/f1) and optionally a
        ``"test"`` dict with the same keys.
        """
        if self.__past_dev_result is None or \
                (eval_result["eval"]["f1"] >
                 self.__past_dev_result["eval"]["f1"]):
            self.__past_dev_result = eval_result
            logger.info("Validation f1 increased, saving model")
            self.save_model_checkpoint()

        best_epoch = self.__past_dev_result["epoch"]
        acc, prec, rec, f1 = (self.__past_dev_result["eval"]["accuracy"],
                              self.__past_dev_result["eval"]["precision"],
                              self.__past_dev_result["eval"]["recall"],
                              self.__past_dev_result["eval"]["f1"])
        logger.info(f"Best val acc: {acc: 0.3f}, precision: {prec:0.3f}, "
                    f"recall: {rec:0.3f}, F1: {f1:0.3f}, epoch={best_epoch}")

        if "test" in self.__past_dev_result:
            acc, prec, rec, f1 = (self.__past_dev_result["test"]["accuracy"],
                                  self.__past_dev_result["test"]["precision"],
                                  self.__past_dev_result["test"]["recall"],
                                  self.__past_dev_result["test"]["f1"])
            logger.info(
                f"Best test acc: {acc: 0.3f}, precision: {prec: 0.3f}, "
                f"recall: {rec: 0.3f}, F1: {f1: 0.3f}, "
                f"epoch={best_epoch}")

    @staticmethod
    def _pickle_to_file(obj, path):
        """Pickle ``obj`` to ``path``, closing the file afterwards."""
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    @staticmethod
    def _pickle_state_dict_to_file(module, path):
        """Pickle ``module.state_dict()`` to ``path``, closing the file."""
        with open(path, "wb") as f:
            pickle.dump(module.state_dict(), f)

    def finish(self, resources: Resources):  # pylint: disable=unused-argument
        """Pickle the resources and write a final model checkpoint.

        The ``model`` resource is stored as its ``state_dict``; every other
        key is pickled as-is.
        """
        if self.resource:
            keys_to_serializers = {}
            for key in resources.keys():
                if key == "model":
                    keys_to_serializers[key] = \
                        self._pickle_state_dict_to_file
                else:
                    keys_to_serializers[key] = self._pickle_to_file

            self.resource.save(keys_to_serializers)

        self.save_model_checkpoint()

    def save_model_checkpoint(self):
        """Save model and optimizer state to ``config_model.model_path``."""
        states = {
            "model": self.model.state_dict(),
            "optimizer": self.optim.state_dict(),
        }
        torch.save(states, self.config_model.model_path)

    def load_model_checkpoint(self):
        """Restore model and optimizer state from ``config_model.model_path``.

        ``map_location`` places the stored tensors on ``self.device``, so a
        checkpoint written on GPU can also be restored on a CPU-only host.
        """
        ckpt = torch.load(self.config_model.model_path,
                          map_location=self.device)
        logger.info("restoring model from %s", self.config_model.model_path)
        self.model.load_state_dict(ckpt["model"])
        self.optim.load_state_dict(ckpt["optimizer"])

    def get_batch_tensor(
            self, data: List[Tuple[List[int], List[List[int]], List[int]]],
            device: Optional[torch.device] = None) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                  torch.Tensor]:
        """Get the tensors to be fed into the model.

        Args:
            data: A list of tuple (word_ids, char_id_sequences, ner_ids)
            device: The device for the tensors.

        Returns:
            A tuple where

            - ``words``: A tensor of shape `[batch_size, batch_length]`
              representing the word ids in the batch
            - ``chars``: A tensor of shape
              `[batch_size, batch_length, char_length]` representing the char
              ids for each word in the batch
            - ``ners``: A tensor of shape `[batch_size, batch_length]`
              representing the ner ids for each word in the batch
            - ``masks``: A tensor of shape `[batch_size, batch_length]`
              representing the indices to be masked in the batch. 1 indicates
              no masking.
            - ``lengths``: A tensor of shape `[batch_size]` representing the
              length of each sentences in the batch
        """
        batch_size = len(data)
        # `default=0` tolerates empty sentences instead of raising ValueError.
        batch_length = max((len(d[0]) for d in data), default=0)
        char_length = max(
            (len(charseq) for d in data for charseq in d[1]), default=0)

        # Leave room for `num_char_pad` padding chars, but never exceed the
        # configured maximum word length.
        char_length = min(
            self.config_data.max_char_length,
            char_length + self.config_data.num_char_pad,
        )

        wid_inputs = np.empty([batch_size, batch_length], dtype=np.int64)
        cid_inputs = np.empty([batch_size, batch_length, char_length],
                              dtype=np.int64)
        nid_inputs = np.empty([batch_size, batch_length], dtype=np.int64)

        masks = np.zeros([batch_size, batch_length], dtype=np.float32)

        lengths = np.empty(batch_size, dtype=np.int64)

        for i, inst in enumerate(data):
            wids, cid_seqs, nids = inst

            inst_size = len(wids)
            lengths[i] = inst_size
            # word ids
            wid_inputs[i, :inst_size] = wids
            wid_inputs[i, inst_size:] = self.word_alphabet.pad_id
            for c, cids in enumerate(cid_seqs):
                cid_inputs[i, c, :len(cids)] = cids
                cid_inputs[i, c, len(cids):] = self.char_alphabet.pad_id
            cid_inputs[i, inst_size:, :] = self.char_alphabet.pad_id
            # ner ids
            nid_inputs[i, :inst_size] = nids
            nid_inputs[i, inst_size:] = self.ner_alphabet.pad_id
            # masks
            masks[i, :inst_size] = 1.0

        words = torch.from_numpy(wid_inputs).to(device)
        chars = torch.from_numpy(cid_inputs).to(device)
        ners = torch.from_numpy(nid_inputs).to(device)
        masks = torch.from_numpy(masks).to(device)
        lengths = torch.from_numpy(lengths).to(device)

        return words, chars, ners, masks, lengths
Exemplo n.º 16
0
    def fit(self, dataset, mode='fit', **kwargs):
        """Train ``self.model`` on ``dataset`` and return ``self``.

        Args:
            dataset: Exposes ``train_dataset``. In modes containing
                ``'refit'``, ``test_dataset`` is also used when ``mode ==
                'refit_test'``. Otherwise ``val_dataset`` is used, or
                ``train_for_val_dataset`` with ``train_sampler`` /
                ``val_sampler`` when ``subset_sampler_used`` is set.
            mode: Modes containing ``'refit'`` train without early stopping;
                any other mode validates every epoch and applies early
                stopping on the validation loss.
            **kwargs: ``profile_iter`` or ``profile_epoch`` (mutually
                exclusive) switch to a profiling run that only executes
                training steps and returns early.

        Returns:
            ``self``, with ``optimizer_``, ``scheduler`` and ``epoch_num``
            updated.

        Raises:
            ValueError: If ``self.optimizer`` is neither ``'SGD'`` nor
                ``'Adam'``.
        """
        from sklearn.metrics import accuracy_score

        assert self.model is not None

        def _padding_mask(batch_x):
            # 1.0 where the token id is non-zero, 0.0 for padding (id 0);
            # one tensor op instead of a Python loop over every token.
            return (batch_x != 0).float()

        params = self.model.parameters()
        val_loader = None
        if 'refit' in mode:
            train_loader = DataLoader(dataset=dataset.train_dataset,
                                      batch_size=self.batch_size,
                                      shuffle=True,
                                      num_workers=NUM_WORKERS)
            if mode == 'refit_test':
                val_loader = DataLoader(dataset=dataset.test_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=NUM_WORKERS)
        else:
            if not dataset.subset_sampler_used:
                train_loader = DataLoader(dataset=dataset.train_dataset,
                                          batch_size=self.batch_size,
                                          shuffle=True,
                                          num_workers=NUM_WORKERS)
                val_loader = DataLoader(dataset=dataset.val_dataset,
                                        batch_size=self.batch_size,
                                        shuffle=False,
                                        num_workers=NUM_WORKERS)
            else:
                train_loader = DataLoader(dataset=dataset.train_dataset,
                                          batch_size=self.batch_size,
                                          sampler=dataset.train_sampler,
                                          num_workers=NUM_WORKERS)
                val_loader = DataLoader(dataset=dataset.train_for_val_dataset,
                                        batch_size=self.batch_size,
                                        sampler=dataset.val_sampler,
                                        num_workers=NUM_WORKERS)

        if self.optimizer == 'SGD':
            optimizer = SGD(params=params,
                            lr=self.sgd_learning_rate,
                            momentum=self.sgd_momentum)
        elif self.optimizer == 'Adam':
            optimizer = Adam(params=params,
                             lr=self.adam_learning_rate,
                             betas=(self.beta1, 0.999))
        else:
            # Raise (not return) so callers cannot mistake the error object
            # for a fitted estimator.
            raise ValueError("Optimizer %s not supported!" % self.optimizer)

        # Milestones are expressed in epochs, so the scheduler is stepped
        # once per epoch below.
        scheduler = MultiStepLR(
            optimizer,
            milestones=[int(self.max_epoch * 0.5),
                        int(self.max_epoch * 0.75)],
            gamma=self.lr_decay)
        loss_func = nn.CrossEntropyLoss()
        early_stop = EarlyStop(patience=5, mode='min')

        if self.load_path:
            checkpoint = torch.load(self.load_path)
            self.model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            self.cur_epoch_num = checkpoint['epoch_num']
            early_stop = checkpoint['early_stop']
            if early_stop.if_early_stop:
                print("Early stop!")
                self.optimizer_ = optimizer
                self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
                self.scheduler = scheduler
                self.early_stop = early_stop
                return self

        profile_iter = kwargs.get('profile_iter', None)
        profile_epoch = kwargs.get('profile_epoch', None)
        assert not (profile_iter and profile_epoch)

        if profile_epoch or profile_iter:  # Profile mode
            self.model.train()
            if profile_epoch:
                for _ in range(int(profile_epoch)):
                    for data in train_loader:
                        batch_x, batch_y = data[0], data[1]
                        masks = _padding_mask(batch_x)
                        logits = self.model(batch_x.long().to(self.device),
                                            masks.to(self.device))
                        optimizer.zero_grad()
                        loss = loss_func(logits, batch_y.to(self.device))
                        loss.backward()
                        optimizer.step()
            else:
                num_iter = 0
                stop_flag = False
                for _ in range(int(self.epoch_num)):
                    if stop_flag:
                        break
                    for data in train_loader:
                        batch_x, batch_y = data[0], data[1]
                        masks = _padding_mask(batch_x)
                        logits = self.model(batch_x.long().to(self.device),
                                            masks.to(self.device))
                        optimizer.zero_grad()
                        loss = loss_func(logits, batch_y.to(self.device))
                        loss.backward()
                        optimizer.step()
                        num_iter += 1
                        if num_iter > profile_iter:
                            stop_flag = True
                            break
            return self

        for epoch in range(int(self.cur_epoch_num),
                           int(self.cur_epoch_num) + int(self.epoch_num)):
            self.model.train()
            epoch_avg_loss = 0
            epoch_avg_acc = 0
            val_avg_loss = 0
            val_avg_acc = 0
            num_train_samples = 0
            num_val_samples = 0
            for data in train_loader:
                batch_x, batch_y = data[0], data[1]
                num_train_samples += len(batch_x)
                masks = _padding_mask(batch_x)
                logits = self.model(batch_x.long().to(self.device),
                                    masks.to(self.device))
                loss = loss_func(logits, batch_y.to(self.device))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_avg_loss += loss.to('cpu').detach() * len(batch_x)
                prediction = np.argmax(logits.to('cpu').detach().numpy(),
                                       axis=-1)
                epoch_avg_acc += accuracy_score(
                    prediction,
                    batch_y.to('cpu').detach().numpy()) * len(batch_x)

            epoch_avg_loss /= num_train_samples
            epoch_avg_acc /= num_train_samples
            # TODO: logger
            print('Epoch %d: Train loss %.4f, train acc %.4f' %
                  (epoch, epoch_avg_loss, epoch_avg_acc))
            # Advance the LR schedule once per epoch; stepping only once
            # after all epochs would never reach the milestones.
            scheduler.step()

            if val_loader is not None:
                self.model.eval()
                with torch.no_grad():
                    for data in val_loader:
                        batch_x, batch_y = data[0], data[1]
                        masks = _padding_mask(batch_x)
                        logits = self.model(batch_x.long().to(self.device),
                                            masks.to(self.device))
                        val_loss = loss_func(logits, batch_y.to(self.device))
                        num_val_samples += len(batch_x)
                        val_avg_loss += val_loss.to('cpu').detach() * len(
                            batch_x)

                        prediction = np.argmax(
                            logits.to('cpu').detach().numpy(), axis=-1)
                        val_avg_acc += accuracy_score(
                            prediction,
                            batch_y.to('cpu').detach().numpy()) * len(batch_x)

                    val_avg_loss /= num_val_samples
                    val_avg_acc /= num_val_samples
                    print('Epoch %d: Val loss %.4f, val acc %.4f' %
                          (epoch, val_avg_loss, val_avg_acc))

                    # Early stop
                    if 'refit' not in mode:
                        early_stop.update(val_avg_loss)
                        if early_stop.if_early_stop:
                            self.early_stop_flag = True
                            print("Early stop!")
                            break

        self.optimizer_ = optimizer
        self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
        self.scheduler = scheduler

        return self
Exemplo n.º 17
0
    def fit(self, dataset: DLDataset, mode='fit', **kwargs):
        """Train ``self.model`` on ``dataset`` and return ``self``.

        The model is called as ``self.model(batch_x, batch_y)`` and must
        return ``(loss, outputs)``; loader batches are
        ``(_, batch_x, batch_y)`` triples built with each dataset's
        ``collate_fn``.

        Args:
            dataset: Exposes ``train_dataset`` plus ``val_dataset`` (normal
                fit) or ``test_dataset`` (``mode == 'refit_test'``).
            mode: Modes containing ``'refit'`` train without early stopping;
                any other mode validates every epoch and applies early
                stopping on the validation loss.
            **kwargs: ``profile_iter`` or ``profile_epoch`` (mutually
                exclusive) switch to a profiling run that only executes
                training steps and returns early.

        Returns:
            ``self``, with ``optimizer_``, ``scheduler`` and ``epoch_num``
            updated.

        Raises:
            ValueError: If ``self.optimizer`` is neither ``'SGD'`` nor
                ``'Adam'``.
        """
        assert self.model is not None

        params = self.model.parameters()

        val_loader = None
        if 'refit' in mode:
            train_loader = DataLoader(
                dataset=dataset.train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=NUM_WORKERS,
                collate_fn=dataset.train_dataset.collate_fn)
            if mode == 'refit_test':
                val_loader = DataLoader(
                    dataset=dataset.test_dataset,
                    batch_size=self.batch_size,
                    shuffle=False,
                    num_workers=NUM_WORKERS,
                    collate_fn=dataset.test_dataset.collate_fn)
        else:
            train_loader = DataLoader(
                dataset=dataset.train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=NUM_WORKERS,
                collate_fn=dataset.train_dataset.collate_fn)
            val_loader = DataLoader(dataset=dataset.val_dataset,
                                    batch_size=self.batch_size,
                                    shuffle=False,
                                    num_workers=NUM_WORKERS,
                                    collate_fn=dataset.val_dataset.collate_fn)

        if self.optimizer == 'SGD':
            optimizer = SGD(params=params,
                            lr=self.sgd_learning_rate,
                            momentum=self.sgd_momentum)
        elif self.optimizer == 'Adam':
            optimizer = Adam(params=params,
                             lr=self.adam_learning_rate,
                             betas=(self.beta1, 0.999))
        else:
            # Raise (not return) so callers cannot mistake the error object
            # for a fitted estimator.
            raise ValueError("Optimizer %s not supported!" % self.optimizer)

        scheduler = MultiStepLR(
            optimizer,
            milestones=[int(self.max_epoch * 0.5),
                        int(self.max_epoch * 0.75)],
            gamma=self.lr_decay)
        early_stop = EarlyStop(patience=5, mode='min')

        # `load_path` holds a full checkpoint dict (model/optimizer/scheduler/
        # epoch_num/early_stop), so the model weights live under 'model';
        # loading the whole dict as a state_dict would raise.
        if self.load_path:
            checkpoint = torch.load(self.load_path)
            self.model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            self.cur_epoch_num = checkpoint['epoch_num']
            early_stop = checkpoint['early_stop']
            if early_stop.if_early_stop:
                print("Early stop!")
                self.optimizer_ = optimizer
                self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
                self.scheduler = scheduler
                self.early_stop = early_stop
                return self

        profile_iter = kwargs.get('profile_iter', None)
        profile_epoch = kwargs.get('profile_epoch', None)
        assert not (profile_iter and profile_epoch)

        if profile_epoch or profile_iter:  # Profile mode
            self.model.train()
            if profile_epoch:
                for _ in range(int(profile_epoch)):
                    for _, batch_x, batch_y in train_loader:
                        loss, _ = self.model(
                            batch_x.float().to(self.device),
                            batch_y.float().to(self.device))
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
            else:
                num_iter = 0
                stop_flag = False
                for _ in range(int(self.epoch_num)):
                    if stop_flag:
                        break
                    for _, batch_x, batch_y in train_loader:
                        loss, _ = self.model(
                            batch_x.float().to(self.device),
                            batch_y.float().to(self.device))
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        num_iter += 1
                        if num_iter > profile_iter:
                            stop_flag = True
                            break
            return self

        for epoch in range(int(self.cur_epoch_num),
                           int(self.cur_epoch_num) + int(self.epoch_num)):
            self.model.train()
            epoch_avg_loss = 0
            val_avg_loss = 0
            num_train_samples = 0
            num_val_samples = 0
            for _, batch_x, batch_y in train_loader:
                loss, _ = self.model(batch_x.float().to(self.device),
                                     batch_y.float().to(self.device))
                optimizer.zero_grad()
                epoch_avg_loss += loss.to('cpu').detach() * len(batch_x)
                num_train_samples += len(batch_x)
                loss.backward()
                optimizer.step()
            epoch_avg_loss /= num_train_samples
            print('Epoch %d: Train loss %.4f' % (epoch, epoch_avg_loss))
            # One scheduler step per epoch: milestones are in epochs.
            scheduler.step()

            if val_loader is not None:
                self.model.eval()
                with torch.no_grad():
                    for _, batch_x, batch_y in val_loader:
                        loss, _ = self.model(
                            batch_x.float().to(self.device),
                            batch_y.float().to(self.device))
                        val_avg_loss += loss.to('cpu').detach() * len(batch_x)
                        num_val_samples += len(batch_x)

                    val_avg_loss /= num_val_samples
                    print('Epoch %d: Val loss %.4f' % (epoch, val_avg_loss))

                    # Early stop
                    if 'refit' not in mode:
                        early_stop.update(val_avg_loss)
                        if early_stop.if_early_stop:
                            self.early_stop_flag = True
                            print("Early stop!")
                            break

        self.optimizer_ = optimizer
        self.epoch_num = int(self.epoch_num) + int(self.cur_epoch_num)
        self.scheduler = scheduler

        return self
Exemplo n.º 18
0
                                         int(0.4 * end_epoch),
                                         int(0.7 * end_epoch),
                                         int(0.8 * end_epoch),
                                         int(0.9 * end_epoch)
                                     ],
                                     gamma=0.1)
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',patience=10, verbose=True)
# 'mean' is the current name of the deprecated 'elementwise_mean' reduction:
# the classification loss is averaged over all elements.
loss_class = nn.CrossEntropyLoss(reduction='mean').to(device)
loss_reg = nn.SmoothL1Loss(reduction='sum').to(
    device)  # or MSELoss or L1Loss or SmoothL1Loss

# resume: restore model/optimizer state from `resume_path` when requested.
if resume_flag and os.path.isfile(resume_path):
    # map_location places the stored tensors on `device`, so a checkpoint
    # written on GPU can be restored on a CPU-only machine as well.
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    # scheduler.load_state_dict(checkpoint["scheduler_state"])
    # start_epoch = checkpoint["epoch"]
    print(
        "=====>",
        "Loaded checkpoint '{}' (iter {})".format(resume_path,
                                                  checkpoint["epoch"]))
else:
    print("=====>", "No checkpoint found at '{}'".format(resume_path))


# Training
def train(epoch, display=True):
    print('\nEpoch: %d' % epoch)
    model.train()
    train_class_loss = 0
Exemplo n.º 19
0
def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          is_master=True,
          world=1,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None):
    """Train the model on the given dataset.

    Optimizes with SGD (momentum 0.9, weight decay 1e-4) through ``amp``
    (``O2`` when ``mixed_precision`` else ``O0``) for ``iterations`` steps.
    The learning rate warms up linearly over ``warmup`` iterations and is
    multiplied by ``gamma`` at each of ``milestones``.

    On the master process, once the profiler's accumulated ``'train'`` total
    exceeds 60 (presumably seconds) or at the final iteration, this function:

    * logs the mean losses;
    * optionally writes them to TensorBoard (``logdir``);
    * optionally posts them to ``metrics_url``;
    * saves ``state`` (iteration, optimizer and scheduler state) via
      ``model.save``.

    When ``val_annotations`` is given, ``infer`` runs every
    ``val_iterations`` iterations and at the end.

    Args:
        model: Network to train; must expose ``stride`` and ``save``.
        state: Dict-like training state; ``'iteration'`` and ``'optimizer'``
            are read from it when present to resume training.
        path, annotations: Training images and annotations.
        val_path, val_annotations: Validation images and annotations.
        resize, max_size, jitter: Image sizing / augmentation parameters.
        batch_size: Images per batch.
        iterations: Total number of training iterations.
        val_iterations: Validation period in iterations.
        mixed_precision: Whether to train with mixed precision.
        lr, warmup, milestones, gamma: Learning-rate schedule parameters.
        is_master: Whether this process logs and saves checkpoints.
        world: Number of distributed processes.
        use_dali: Use ``DaliDataIterator`` instead of ``DataIterator``.
        verbose: Print progress information.
        metrics_url: Optional endpoint that receives metrics.
        logdir: Optional TensorBoard log directory.
    """
    # Prepare model
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=0.0001,
                    momentum=0.9)

    model, optimizer = amp.initialize(
        model,
        optimizer,
        opt_level='O2' if mixed_precision else 'O0',
        keep_batchnorm_fp32=True,
        loss_scale=128.0,
        verbosity=is_master)

    if world > 1:
        model = DistributedDataParallel(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # Linear warmup from 0.1x to 1x, then one `gamma` decay per
        # milestone already reached.
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**sum(1 for m in milestones if m <= train_iter)

    scheduler = LambdaLR(optimizer.optimizer if mixed_precision else optimizer,
                         schedule)

    # Prepare dataset
    if verbose:
        print('Preparing dataset...')
    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
        path,
        jitter,
        max_size,
        batch_size,
        stride,
        world,
        annotations,
        training=True)
    if verbose:
        print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'gpu' if world == 1 else 'gpus'))
        print('    batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer
    writer = None
    if logdir is not None:
        from tensorboardX import SummaryWriter
        if is_master and verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    # try/finally so the TensorBoard writer is flushed/closed even when
    # training raises (e.g. the divergence check below).
    try:
        while iteration < iterations:
            cls_losses, box_losses = [], []
            for i, (data, target) in enumerate(data_iterator):
                scheduler.step(iteration)

                # Forward pass
                profiler.start('fw')

                optimizer.zero_grad()
                cls_loss, box_loss = model([data, target])
                del data
                profiler.stop('fw')

                # Backward pass
                profiler.start('bw')
                with amp.scale_loss(cls_loss + box_loss,
                                    optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

                # Reduce all losses
                cls_loss = cls_loss.mean().clone()
                box_loss = box_loss.mean().clone()
                if world > 1:
                    torch.distributed.all_reduce(cls_loss)
                    torch.distributed.all_reduce(box_loss)
                    cls_loss /= world
                    box_loss /= world
                if is_master:
                    cls_losses.append(cls_loss)
                    box_losses.append(box_loss)

                if is_master and not isfinite(cls_loss + box_loss):
                    raise RuntimeError('Loss is diverging!\n{}'.format(
                        'Try lowering the learning rate.'))

                del cls_loss, box_loss
                profiler.stop('bw')

                iteration += 1
                profiler.bump('train')
                if is_master and (profiler.totals['train'] > 60
                                  or iteration == iterations):
                    # Scalar means over the losses gathered since the last
                    # report; reused for every sink below rather than
                    # re-averaging the list of loss tensors.
                    focal_loss = torch.stack(list(cls_losses)).mean().item()
                    box_loss = torch.stack(list(box_losses)).mean().item()
                    learning_rate = optimizer.param_groups[0]['lr']
                    if verbose:
                        msg = '[{:{len}}/{}]'.format(iteration,
                                                     iterations,
                                                     len=len(str(iterations)))
                        msg += ' focal loss: {:.3f}'.format(focal_loss)
                        msg += ', box loss: {:.3f}'.format(box_loss)
                        msg += ', {:.3f}s/{}-batch'.format(
                            profiler.means['train'], batch_size)
                        msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                            profiler.means['fw'], profiler.means['bw'])
                        msg += ', {:.1f} im/s'.format(batch_size /
                                                      profiler.means['train'])
                        msg += ', lr: {:.2g}'.format(learning_rate)
                        print(msg, flush=True)

                    if writer is not None:
                        writer.add_scalar('focal_loss', focal_loss, iteration)
                        writer.add_scalar('box_loss', box_loss, iteration)
                        writer.add_scalar('learning_rate', learning_rate,
                                          iteration)

                    if metrics_url:
                        post_metrics(
                            metrics_url, {
                                'focal loss': focal_loss,
                                'box loss': box_loss,
                                'im_s': batch_size / profiler.means['train'],
                                'lr': learning_rate
                            })

                    # Save model weights
                    state.update({
                        'iteration': iteration,
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                    })
                    with ignore_sigint():
                        nn_model.save(state)

                    profiler.reset()
                    del cls_losses[:], box_losses[:]

                if val_annotations and (iteration == iterations
                                        or iteration % val_iterations == 0):
                    infer(model,
                          val_path,
                          None,
                          resize,
                          max_size,
                          batch_size,
                          annotations=val_annotations,
                          mixed_precision=mixed_precision,
                          is_master=is_master,
                          world=world,
                          use_dali=use_dali,
                          verbose=False)
                    model.train()

                if iteration == iterations:
                    break
    finally:
        if writer is not None:
            writer.close()
Exemplo n.º 20
0
def main(args: argparse.Namespace):
    """Train (or evaluate) a segmentation model with domain-adversarial
    entropy alignment between a labelled source and an unlabelled target.

    Builds the source/target data loaders, the segmentation model selected
    by ``args.arch`` and a ``Discriminator``, optionally restores all of them
    (plus optimizers and schedulers) from ``args.resume``, and then either:

    * ``args.phase == 'test'``: runs a single validation pass on the target
      ``'val'`` split and prints the confusion matrix; or
    * otherwise: trains from ``args.start_epoch`` up to ``args.epochs``.
      After every epoch it validates and writes a checkpoint. The checkpoint
      with the best target mean IoU (over ``evaluate_classes``) is copied to
      the ``'best'`` checkpoint path.

    Args:
        args: parsed command-line options.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark mode is enabled even for seeded runs. It lets
    # cuDNN autotune its algorithm choice, which can weaken run-to-run
    # reproducibility; kept as-is to preserve existing behaviour.
    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
    # LambdaLR multiplies each param group's initial lr by the lambda value.
    # The segmentation schedule carries an extra ``args.lr`` factor (the
    # discriminator's does not). This presumably relies on
    # ``model.get_parameters()`` returning param groups whose 'lr' entries
    # are relative multipliers that override ``lr=args.lr``.
    # TODO confirm against the model definition.
    lr_scheduler = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))
    lr_scheduler_d = LambdaLR(
        optimizer_d, lambda x:
        (1. - float(x) / args.epochs / args.iters_per_epoch)**(args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # nn.Upsample takes (H, W); the sizes are reversed here, so they are
    # presumably stored as (W, H) -- TODO confirm against the arg parser.
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """Save the input image, the ground-truth label and the prediction.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W (class scores; argmax over dim 0 is used)
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for img, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))),
             "image"), (decode(label), "label"), (decode(pred), "pred")
        ]:
            img.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        # Close the logger on this early-exit path as well; the training
        # path below closes it once training finishes.
        logger.close()
        return

    # Positions (in ``classes``) of the classes that count towards the
    # reported mean IoU. This does not change between epochs.
    indexes = [
        train_source_dataset.classes.index(name)
        for name in train_source_dataset.evaluate_classes
    ]

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        # get_last_lr() is the supported accessor; calling get_lr() outside
        # of step() is deprecated and emits a warning in current PyTorch.
        print(lr_scheduler.get_last_lr(), lr_scheduler_d.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        _, _, iu = confmat.compute()

        # mean IoU restricted to the evaluated subset of classes
        mean_iou = iu[indexes].mean()

        # save this epoch's checkpoint; keep a copy of the best mean IoU one
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
Exemplo n.º 21
0
class Trainer(object):
    """
    Trainer encapsulates all the logic necessary for
    training the Recurrent Attention Model.

    All hyperparameters are provided by the user in the
    config file.
    """
    def __init__(self, config, data_loader):
        """
        Construct a new Trainer instance.

        Args
        ----
        - config: object containing command line arguments.
        - data_loader: when ``config.is_train`` is true, an indexable pair
          ``(train_loader, valid_loader)``; otherwise the test loader.
        """
        self.config = config

        # glimpse network params
        self.patch_size = config.patch_size
        self.glimpse_scale = config.glimpse_scale
        self.num_patches = config.num_patches
        self.loc_hidden = config.loc_hidden
        self.glimpse_hidden = config.glimpse_hidden

        # core network params
        self.num_glimpses = config.num_glimpses
        self.hidden_size = config.hidden_size

        # reinforce params
        self.std = config.std
        self.M = config.M  # number of copies averaged at validation/test

        # data params
        if config.is_train:
            self.train_loader = data_loader[0]
            self.valid_loader = data_loader[1]
            self.num_train = len(self.train_loader.sampler.indices)
            self.num_valid = len(self.valid_loader.sampler.indices)
        else:
            self.test_loader = data_loader
            self.num_test = len(self.test_loader.dataset)
        self.num_classes = 10
        self.num_channels = 1

        # training params
        self.epochs = config.epochs
        self.start_epoch = 0
        self.momentum = config.momentum
        self.lr = config.init_lr  # initial lr; see _current_lr() for live lr

        # misc params
        self.use_gpu = config.use_gpu
        self.best = config.best
        self.ckpt_dir = config.ckpt_dir
        self.logs_dir = config.logs_dir
        self.best_valid_acc = 0.
        self.counter = 0  # consecutive epochs without val-acc improvement
        self.patience = config.patience
        self.use_tensorboard = config.use_tensorboard
        self.resume = config.resume
        self.print_freq = config.print_freq
        self.plot_freq = config.plot_freq
        self.model_name = 'ram_{}_{}x{}_{}'.format(config.num_glimpses,
                                                   config.patch_size,
                                                   config.patch_size,
                                                   config.glimpse_scale)

        self.plot_dir = './plots/' + self.model_name + '/'
        os.makedirs(self.plot_dir, exist_ok=True)

        # configure tensorboard logging
        if self.use_tensorboard:
            tensorboard_dir = self.logs_dir + self.model_name
            print('[*] Saving tensorboard logs to {}'.format(tensorboard_dir))
            os.makedirs(tensorboard_dir, exist_ok=True)
            configure(tensorboard_dir)

        # build RAM model
        self.model = RecurrentAttention(
            self.patch_size,
            self.num_patches,
            self.glimpse_scale,
            self.num_channels,
            self.loc_hidden,
            self.glimpse_hidden,
            self.std,
            self.hidden_size,
            self.num_classes,
        )
        if self.use_gpu:
            self.model.cuda()

        print('[*] Number of model parameters: {:,}'.format(
            sum([p.data.nelement() for p in self.model.parameters()])))

        # initialize optimizer and scheduler
        self.optimizer = SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=self.momentum,
        )
        self.scheduler = ReduceLROnPlateau(self.optimizer,
                                           'min',
                                           patience=self.patience)

    def _current_lr(self):
        """Return the learning rate currently set on the first param group.

        ReduceLROnPlateau lowers the lr inside the optimizer, so this is the
        live value, unlike ``self.lr`` which keeps the initial one.
        """
        return self.optimizer.param_groups[0]['lr']

    def reset(self):
        """
        Initialize the hidden state of the core network
        and the location vector.

        This is called once every time a new minibatch
        `x` is introduced.

        Returns ``(h_t, l_t)``: a zero hidden state of shape
        ``(batch_size, hidden_size)`` and locations of shape
        ``(batch_size, 2)`` drawn uniformly from [-1, 1), both as float
        tensors on the GPU when ``use_gpu`` is set.
        """
        dtype = torch.cuda.FloatTensor if self.use_gpu else torch.FloatTensor

        h_t = torch.zeros(self.batch_size, self.hidden_size).type(dtype)
        l_t = torch.Tensor(self.batch_size, 2).uniform_(-1, 1).type(dtype)

        return h_t, l_t

    def train(self):
        """
        Train the model on the training set.

        A checkpoint of the model is saved after each epoch
        and if the validation accuracy is improved upon,
        a separate ckpt is created for use on the test set.
        Training stops early once validation accuracy has failed to
        improve for more than ``patience`` consecutive epochs.
        """
        # load the most recent checkpoint
        if self.resume:
            self.load_checkpoint(best=False)

        print("\n[*] Train on {} samples, validate on {} samples".format(
            self.num_train, self.num_valid))

        for epoch in range(self.start_epoch, self.epochs):

            # report the live lr: the scheduler may have reduced it
            print('\nEpoch: {}/{} - LR: {:.6f}'.format(epoch + 1, self.epochs,
                                                       self._current_lr()))

            # train for 1 epoch
            train_loss, train_acc = self.train_one_epoch(epoch)

            # evaluate on validation set
            valid_loss, valid_acc = self.validate(epoch)

            # reduce lr if validation loss plateaus
            self.scheduler.step(valid_loss)

            is_best = valid_acc > self.best_valid_acc
            msg1 = "train loss: {:.3f} - train acc: {:.3f} "
            msg2 = "- val loss: {:.3f} - val acc: {:.3f}"
            if is_best:
                msg2 += " [*]"
            msg = msg1 + msg2
            print(msg.format(train_loss, train_acc, valid_loss, valid_acc))

            # check for improvement: the counter tracks *consecutive*
            # epochs without improvement, so reset it on a new best
            if is_best:
                self.counter = 0
            else:
                self.counter += 1
            if self.counter > self.patience:
                print("[!] No improvement in a while, stopping training.")
                return
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            self.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model_state': self.model.state_dict(),
                    'optim_state': self.optimizer.state_dict(),
                    'sched_state': self.scheduler.state_dict(),
                    'best_valid_acc': self.best_valid_acc,
                }, is_best)

    def train_one_epoch(self, epoch):
        """
        Train the model for 1 epoch of the training set.

        An epoch corresponds to one full pass through the entire
        training set in successive mini-batches.

        This is used by train() and should not be called manually.

        Returns the (average loss, average accuracy) over the epoch.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # only the first batch of every plot_freq-th epoch is dumped
                plot = (epoch % self.plot_freq == 0) and (i == 0)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # keep the first 9 images (and their locations) for plotting
                imgs = [x[0:9]]

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for _ in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration also yields the class log-probabilities
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward: 1 for a correct prediction, repeated for
                # every glimpse
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # .item() replaces the pre-0.4 `tensor.data[0]` idiom, which
                # raises on 0-dim tensors in current PyTorch
                loss_val = loss.item()
                acc_val = acc.item()
                losses.update(loss_val, x.size()[0])
                accs.update(acc_val, x.size()[0])

                # compute gradients and update SGD
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # elapsed time since the start of the epoch
                toc = time.time()

                pbar.set_description(
                    ("{:.1f}s - loss: {:.3f} - acc: {:.3f}".format(
                        (toc - tic), loss_val, acc_val)))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.detach().cpu().numpy().squeeze() for g in imgs]
                    locs = [loc.detach().cpu().numpy() for loc in locs]
                    g_path = self.plot_dir + "g_{}.p".format(epoch + 1)
                    l_path = self.plot_dir + "l_{}.p".format(epoch + 1)
                    with open(g_path, "wb") as fh:
                        pickle.dump(imgs, fh)
                    with open(l_path, "wb") as fh:
                        pickle.dump(locs, fh)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value('train_loss', losses.avg, iteration)
                    log_value('train_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def validate(self, epoch):
        """
        Evaluate the model on the validation set.

        Each batch is replicated ``self.M`` times; log-probabilities,
        baselines and log_pi are averaged over the copies before computing
        the same hybrid loss used in training. No autograd graph is built.

        Returns the (average loss, average accuracy) over the set.
        """
        losses = AverageMeter()
        accs = AverageMeter()

        with torch.no_grad():
            for i, (x, y) in enumerate(self.valid_loader):
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # replicate the batch self.M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                log_pi = []
                baselines = []
                for _ in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # average over the M copies
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                baselines = baselines.contiguous().view(self.M, -1,
                                                        baselines.shape[-1])
                baselines = torch.mean(baselines, dim=0)

                log_pi = log_pi.contiguous().view(self.M, -1,
                                                  log_pi.shape[-1])
                log_pi = torch.mean(log_pi, dim=0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.mean(-log_pi * adjusted_reward)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store (weighted by the replicated batch size)
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.valid_loader) + i
                    log_value('valid_loss', losses.avg, iteration)
                    log_value('valid_acc', accs.avg, iteration)

        return losses.avg, accs.avg

    def test(self):
        """
        Test the model on the held-out test data.
        This function should only be called at the very
        end once the model has finished training.

        Loads the checkpoint selected by ``self.best``, averages predictions
        over ``self.M`` replicas of each batch and prints the accuracy.
        """
        correct = 0

        # load the best checkpoint
        self.load_checkpoint(best=self.best)

        # `Variable(..., volatile=True)` no longer disables autograd;
        # torch.no_grad() is the supported replacement
        with torch.no_grad():
            for x, y in self.test_loader:
                if self.use_gpu:
                    x, y = x.cuda(), y.cuda()

                # replicate the batch self.M times
                x = x.repeat(self.M, 1, 1, 1)

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # extract the glimpses
                for _ in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)

                # average over the M copies
                log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
                log_probas = torch.mean(log_probas, dim=0)

                pred = log_probas.max(1, keepdim=True)[1]
                # accumulate as a plain int rather than a tensor
                correct += pred.eq(y.view_as(pred)).sum().item()

        perc = (100. * correct) / (self.num_test)
        print('[*] Test Acc: {}/{} ({:.2f}%)'.format(correct, self.num_test,
                                                     perc))

    def save_checkpoint(self, state, is_best):
        """
        Save a copy of the model so that it can be loaded at a future
        date. This function is used when the model is being evaluated
        on the test data.

        If this model has reached the best validation accuracy thus
        far, a separate file with the suffix `best` is created.
        """
        filename = self.model_name + '_ckpt.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        torch.save(state, ckpt_path)

        if is_best:
            filename = self.model_name + '_model_best.pth.tar'
            shutil.copyfile(ckpt_path, os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, best=False):
        """
        Load the best copy of a model. This is useful for 2 cases:

        - Resuming training with the most recent model checkpoint.
        - Loading the best validation model to evaluate on the test data.

        Params
        ------
        - best: if set to True, loads the best model. Use this if you want
          to evaluate your model on the test data. Else, set to False in
          which case the most recent version of the checkpoint is used.
        """
        print("[*] Loading model from {}".format(self.ckpt_dir))

        filename = self.model_name + '_ckpt.pth.tar'
        if best:
            filename = self.model_name + '_model_best.pth.tar'
        ckpt_path = os.path.join(self.ckpt_dir, filename)
        # map to CPU when the GPU is not used, so GPU-saved checkpoints can
        # still be loaded on CPU-only machines
        ckpt = torch.load(ckpt_path,
                          map_location=None if self.use_gpu else 'cpu')

        # load variables from checkpoint
        self.start_epoch = ckpt['epoch']
        self.best_valid_acc = ckpt['best_valid_acc']
        self.model.load_state_dict(ckpt['model_state'])
        self.optimizer.load_state_dict(ckpt['optim_state'])
        # older checkpoints have no scheduler state; keep a fresh one then
        if 'sched_state' in ckpt:
            self.scheduler.load_state_dict(ckpt['sched_state'])

        # 'epoch' is saved as the number of completed epochs (already
        # 1-based), so it is reported as-is
        if best:
            print("[*] Loaded {} checkpoint @ epoch {} "
                  "with best valid acc of {:.3f}".format(
                      filename, ckpt['epoch'], ckpt['best_valid_acc']))
        else:
            print("[*] Loaded {} checkpoint @ epoch {}".format(
                filename, ckpt['epoch']))
Exemplo n.º 22
0
def train(train_dir, model_dir, config_path, checkpoint_path,
          n_steps, save_every, test_every, decay_every,
          n_speakers, n_utterances, seg_len):
    """Train a d-vector network with the GE2E loss.

    Args:
        train_dir: data directory passed to ``SEDataset``.
        model_dir: output directory for TensorBoard logs and checkpoints.
        config_path: config file passed to ``DVector.load_config_file``.
        checkpoint_path: checkpoint to resume from, or None to start fresh.
        n_steps: number of optimisation steps to run in this call.
        save_every: write a checkpoint every this many steps.
        test_every: log a validation loss every this many steps.
        decay_every: StepLR step size; the lr is halved every this many steps.
        n_speakers: speakers per batch (the DataLoader batch size); also
            2 * n_speakers dataset items are held out for validation.
        n_utterances: utterances per speaker, passed to ``SEDataset``.
        seg_len: segment length, passed to ``SEDataset``.
    """

    # setup
    total_steps = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load data
    dataset = SEDataset(train_dir, n_utterances, seg_len)
    train_set, valid_set = random_split(dataset, [len(dataset)-2*n_speakers,
                                                  2*n_speakers])
    train_loader = DataLoader(train_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    valid_loader = DataLoader(valid_set, batch_size=n_speakers,
                              shuffle=True, num_workers=4,
                              collate_fn=pad_batch, drop_last=True)
    train_iter = iter(train_loader)

    assert len(train_set) >= n_speakers
    assert len(valid_set) >= n_speakers
    print(f"Training starts with {len(train_set)} speakers. "
          f"(and {len(valid_set)} speakers for validation)")

    # build network and training tools. Modules are moved to `device` before
    # the optimizer is built and any checkpoint is restored, so restored
    # optimizer state lives on the same device as the parameters.
    dvector = DVector().load_config_file(config_path).to(device)
    criterion = GE2ELoss().to(device)
    optimizer = SGD(list(dvector.parameters()) +
                    list(criterion.parameters()), lr=0.01)
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)

    # load checkpoint
    if checkpoint_path is not None:
        # map_location lets a GPU-saved checkpoint resume on a CPU-only host
        ckpt = torch.load(checkpoint_path, map_location=device)
        total_steps = ckpt["total_steps"]
        dvector.load_state_dict(ckpt["state_dict"])
        criterion.load_state_dict(ckpt["criterion"])
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["scheduler"])

    # prepare for training
    writer = SummaryWriter(model_dir)
    pbar = tqdm.trange(n_steps)

    try:
        # start training
        for step in pbar:

            total_steps += 1

            # cycle through the training set indefinitely
            try:
                batch = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                batch = next(train_iter)

            embd = dvector(batch.to(device)).view(n_speakers, n_utterances,
                                                  -1)

            loss = criterion(embd)

            optimizer.zero_grad()
            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                list(dvector.parameters()) + list(criterion.parameters()),
                max_norm=3)
            # per-module gradient scaling applied after clipping
            dvector.embedding.weight.grad.data *= 0.5
            criterion.w.grad.data *= 0.01
            criterion.b.grad.data *= 0.01

            optimizer.step()
            scheduler.step()

            # log a Python float so the loss tensor (and its graph) is
            # not kept alive for logging
            loss_value = loss.item()
            pbar.set_description(
                f"global = {total_steps}, loss = {loss_value:.4f}")
            writer.add_scalar("Training loss", loss_value, total_steps)
            writer.add_scalar("Gradient norm", grad_norm, total_steps)

            if (step + 1) % test_every == 0:
                # evaluation only: no autograd graph is needed
                with torch.no_grad():
                    batch = next(iter(valid_loader))
                    embd = dvector(batch.to(device)).view(n_speakers,
                                                          n_utterances, -1)
                    valid_loss = criterion(embd).item()
                writer.add_scalar("validation loss", valid_loss, total_steps)

            if (step + 1) % save_every == 0:
                ckpt_path = os.path.join(model_dir, f"ckpt-{total_steps}.tar")
                ckpt_dict = {
                    "total_steps": total_steps,
                    "state_dict": dvector.state_dict(),
                    "criterion": criterion.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                }
                torch.save(ckpt_dict, ckpt_path)
    finally:
        # flush and release the event file even if training is interrupted
        writer.close()

    print("Training completed.")
Exemplo n.º 23
0
# TensorBoard tags for the validation statistics returned by ``infer``. The
# order matches the per-index logging (stats[0] .. stats[11]) and appears to
# follow COCOeval's 12 summary stats (AP, AP50, AP75, AP by area, then AR by
# max detections and by area).
_VALIDATION_STAT_TAGS = (
    'Validation_Precision/mAP',
    'Validation_Precision/mAP@0.50IoU',
    'Validation_Precision/mAP@0.75IoU',
    'Validation_Precision/mAP (small)',
    'Validation_Precision/mAP (medium)',
    'Validation_Precision/mAP (large)',
    'Validation_Recall/mAR (max 1 Dets)',
    'Validation_Recall/mAR (max 10 Dets)',
    'Validation_Recall/mAR (max 100 Dets)',
    'Validation_Recall/mAR (small)',
    'Validation_Recall/mAR (medium)',
    'Validation_Recall/mAR (large)',
)


def _log_validation_stats(writer, stats, iteration):
    """Write each validation statistic to TensorBoard under its own tag."""
    for tag, value in zip(_VALIDATION_STAT_TAGS, stats):
        writer.add_scalar(tag, value, iteration)


def train(model,
          state,
          path,
          annotations,
          val_path,
          val_annotations,
          resize,
          max_size,
          jitter,
          batch_size,
          iterations,
          val_iterations,
          mixed_precision,
          lr,
          warmup,
          milestones,
          gamma,
          rank=0,
          world=1,
          no_apex=False,
          use_dali=True,
          verbose=True,
          metrics_url=None,
          logdir=None,
          rotate_augment=False,
          augment_brightness=0.0,
          augment_contrast=0.0,
          augment_hue=0.0,
          augment_saturation=0.0,
          regularization_l2=0.0001,
          rotated_bbox=False,
          absolute_angle=False):
    """Train the model on the given dataset.

    Args:
        model: detection model; must expose ``stride`` and ``save(state)``.
        state (dict): training state. Optional keys ``'optimizer'``,
            ``'scheduler'`` and ``'iteration'`` are used to resume; the dict
            is updated in place and passed to ``model.save`` at checkpoints.
        path, annotations: training images and annotations.
        val_path, val_annotations: validation images and annotations;
            validation is skipped when ``val_annotations`` is falsy.
        resize: only forwarded to ``infer`` for validation.
        max_size, jitter: forwarded to the data iterator(s).
        batch_size (int): batch size for training and validation.
        iterations (int): total number of training iterations.
        val_iterations (int): validate every ``val_iterations`` iterations and
            at the final iteration; a falsy value validates only at the end.
        mixed_precision (bool): with apex, selects opt_level O2 instead of O0.
            The native (``no_apex``) path always runs the forward pass under
            ``autocast``.
        lr, warmup, milestones, gamma: SGD learning rate and schedule (linear
            warmup from 0.1x to 1x, then ``gamma`` decay per milestone passed).
        rank, world: distributed rank and world size.
        no_apex (bool): use native ``GradScaler``/``autocast`` instead of apex.
        use_dali (bool): use the DALI data iterator (not for rotated boxes).
        verbose (bool): print progress.
        metrics_url: optional endpoint passed to ``post_metrics``.
        logdir: optional TensorBoard log directory (master rank only).
        rotate_augment, augment_*: data augmentation settings.
        regularization_l2 (float): SGD weight decay.
        rotated_bbox, absolute_angle: train on rotated bounding boxes.

    Raises:
        NotImplementedError: if ``rotated_bbox`` and ``use_dali`` are both set.
        RuntimeError: if the loss becomes non-finite on the master rank.
    """

    # Prepare model. Keep a handle on the unwrapped model: it owns .save(),
    # while `model` may be wrapped by amp/DDP below.
    nn_model = model
    stride = model.stride

    model = convert_fixedbn_model(model)
    if torch.cuda.is_available():
        model = model.to(memory_format=torch.channels_last).cuda()

    # Setup optimizer and schedule
    optimizer = SGD(model.parameters(),
                    lr=lr,
                    weight_decay=regularization_l2,
                    momentum=0.9)

    is_master = rank == 0
    if not no_apex:
        loss_scale = "dynamic" if use_dali else "128.0"
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level='O2' if mixed_precision else 'O0',
            keep_batchnorm_fp32=True,
            loss_scale=loss_scale,
            verbosity=is_master)

    if world > 1:
        model = DDP(model, device_ids=[rank]) if no_apex else ADDP(model)
    model.train()

    if 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])

    def schedule(train_iter):
        # Linear warmup from 0.1x to 1x over `warmup` iterations, then one
        # factor of `gamma` for every milestone already reached.
        if warmup and train_iter <= warmup:
            return 0.9 * train_iter / warmup + 0.1
        return gamma**len([m for m in milestones if m <= train_iter])

    scheduler = LambdaLR(optimizer, schedule)
    if 'scheduler' in state:
        scheduler.load_state_dict(state['scheduler'])

    # Prepare dataset
    if verbose: print('Preparing dataset...')
    if rotated_bbox:
        if use_dali:
            raise NotImplementedError(
                "This repo does not currently support DALI for rotated bbox detections."
            )
        data_iterator = RotatedDataIterator(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation,
            absolute_angle=absolute_angle)
    else:
        data_iterator = (DaliDataIterator if use_dali else DataIterator)(
            path,
            jitter,
            max_size,
            batch_size,
            stride,
            world,
            annotations,
            training=True,
            rotate_augment=rotate_augment,
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast,
            augment_hue=augment_hue,
            augment_saturation=augment_saturation)
    if verbose: print(data_iterator)

    if verbose:
        print('    device: {} {}'.format(
            world, 'cpu' if not torch.cuda.is_available() else
            'GPU' if world == 1 else 'GPUs'))
        print('     batch: {}, precision: {}'.format(
            batch_size, 'mixed' if mixed_precision else 'full'))
        print(' BBOX type:', 'rotated' if rotated_bbox else 'axis aligned')
        print('Training model for {} iterations...'.format(iterations))

    # Create TensorBoard writer (master rank only). `writer` stays None
    # otherwise, and every logging call below is guarded on that.
    writer = None
    if is_master and logdir is not None:
        from torch.utils.tensorboard import SummaryWriter
        if verbose:
            print('Writing TensorBoard logs to: {}'.format(logdir))
        writer = SummaryWriter(log_dir=logdir)

    scaler = GradScaler()
    profiler = Profiler(['train', 'fw', 'bw'])
    iteration = state.get('iteration', 0)
    while iteration < iterations:
        cls_losses, box_losses = [], []
        for i, (data, target) in enumerate(data_iterator):
            if iteration >= iterations:
                break

            # Forward pass
            profiler.start('fw')

            optimizer.zero_grad()
            if not no_apex:
                cls_loss, box_loss = model([
                    data.contiguous(memory_format=torch.channels_last), target
                ])
            else:
                with autocast():
                    cls_loss, box_loss = model([
                        data.contiguous(memory_format=torch.channels_last),
                        target
                    ])
            del data
            profiler.stop('fw')

            # Backward pass
            profiler.start('bw')
            if not no_apex:
                with amp.scale_loss(cls_loss + box_loss,
                                    optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
            else:
                scaler.scale(cls_loss + box_loss).backward()
                scaler.step(optimizer)
                scaler.update()

            scheduler.step()

            # Reduce all losses (average across ranks when distributed)
            cls_loss, box_loss = cls_loss.mean().clone(), box_loss.mean(
            ).clone()
            if world > 1:
                torch.distributed.all_reduce(cls_loss)
                torch.distributed.all_reduce(box_loss)
                cls_loss /= world
                box_loss /= world
            if is_master:
                cls_losses.append(cls_loss)
                box_losses.append(box_loss)

            if is_master and not isfinite(cls_loss + box_loss):
                raise RuntimeError('Loss is diverging!\n{}'.format(
                    'Try lowering the learning rate.'))

            del cls_loss, box_loss
            profiler.stop('bw')

            iteration += 1
            profiler.bump('train')
            # Report and checkpoint roughly once a minute and at the end.
            if is_master and (profiler.totals['train'] > 60
                              or iteration == iterations):
                focal_loss = torch.stack(list(cls_losses)).mean().item()
                box_loss = torch.stack(list(box_losses)).mean().item()
                learning_rate = optimizer.param_groups[0]['lr']
                if verbose:
                    msg = '[{:{len}}/{}]'.format(iteration,
                                                 iterations,
                                                 len=len(str(iterations)))
                    msg += ' focal loss: {:.3f}'.format(focal_loss)
                    msg += ', box loss: {:.3f}'.format(box_loss)
                    msg += ', {:.3f}s/{}-batch'.format(profiler.means['train'],
                                                       batch_size)
                    msg += ' (fw: {:.3f}s, bw: {:.3f}s)'.format(
                        profiler.means['fw'], profiler.means['bw'])
                    msg += ', {:.1f} im/s'.format(batch_size /
                                                  profiler.means['train'])
                    msg += ', lr: {:.2g}'.format(learning_rate)
                    print(msg, flush=True)

                if writer is not None:
                    writer.add_scalar('focal_loss', focal_loss, iteration)
                    writer.add_scalar('box_loss', box_loss, iteration)
                    writer.add_scalar('learning_rate', learning_rate,
                                      iteration)
                    del box_loss, focal_loss

                if metrics_url:
                    post_metrics(
                        metrics_url, {
                            'focal loss': mean(cls_losses),
                            'box loss': mean(box_losses),
                            'im_s': batch_size / profiler.means['train'],
                            'lr': learning_rate
                        })

                # Save model weights
                state.update({
                    'iteration': iteration,
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                })
                with ignore_sigint():
                    nn_model.save(state)

                profiler.reset()
                del cls_losses[:], box_losses[:]

            # Guard the modulo so val_iterations == 0/None means "validate only
            # at the final iteration" instead of raising ZeroDivisionError.
            if val_annotations and (iteration == iterations
                                    or (val_iterations and
                                        iteration % val_iterations == 0)):
                stats = infer(model,
                              val_path,
                              None,
                              resize,
                              max_size,
                              batch_size,
                              annotations=val_annotations,
                              mixed_precision=mixed_precision,
                              is_master=is_master,
                              world=world,
                              use_dali=use_dali,
                              no_apex=no_apex,
                              is_validation=True,
                              verbose=False,
                              rotated_bbox=rotated_bbox)
                model.train()
                if writer is not None and stats is not None:
                    _log_validation_stats(writer, stats, iteration)

            if (iteration == iterations
                    and not rotated_bbox) or (iteration > iterations
                                              and rotated_bbox):
                break

    if writer is not None:
        writer.close()
Exemplo n.º 24
0
class Trainer:
    """Train and validate a saliency model on the SALICON dataset.

    Uses SGD with Nesterov momentum. The learning rate is decayed linearly
    from 0.03 to 0.0001 over the requested number of epochs.
    """

    def __init__(self,
                 model: nn.Module,
                 dataset_root: str,
                 summary_writer: SummaryWriter,
                 device: Device,
                 batch_size: int = 128,
                 cc_loss: bool = False):
        """Build data loaders, loss and optimizer.

        Args:
            model: network to train; it is moved to ``device``.
            dataset_root: prefix prepended to ``"train.pkl"`` / ``"val.pkl"``
                by plain string concatenation, so it should end with a path
                separator.
            summary_writer: TensorBoard writer used for all logging.
            device: device that the model and every batch are moved to.
            batch_size: batch size for both loaders.
            cc_loss: use ``CCLoss`` instead of ``nn.MSELoss``.
        """
        # load train/test splits of SALICON dataset
        train_dataset = Salicon(dataset_root + "train.pkl")
        test_dataset = Salicon(dataset_root + "val.pkl")

        self.train_loader = DataLoader(
            train_dataset,
            shuffle=True,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=1,
        )
        self.val_loader = DataLoader(
            test_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
            pin_memory=True,
        )
        self.model = model.to(device)
        self.device = device
        if cc_loss:
            # NOTE(review): CCLoss is stored without being called. That works
            # only if CCLoss is a plain loss function; if it is an nn.Module
            # class it must be instantiated (CCLoss()) — confirm.
            self.criterion = CCLoss
        else:
            self.criterion = nn.MSELoss()
        self.optimizer = SGD(self.model.parameters(),
                             lr=0.03,
                             momentum=0.9,
                             weight_decay=0.0005,
                             nesterov=True)
        self.summary_writer = summary_writer
        self.step = 0

    def _set_learning_rate(self, lr):
        """Set ``lr`` on every parameter group of the existing optimizer."""
        for group in self.optimizer.param_groups:
            group["lr"] = float(lr)

    def train(self,
              epochs: int,
              val_frequency: int,
              log_frequency: int = 5,
              start_epoch: int = 0):
        """Train for epochs ``start_epoch`` .. ``epochs - 1``.

        Args:
            epochs: total number of epochs the LR schedule spans.
            val_frequency: validate every ``val_frequency`` epochs.
            log_frequency: log and print every ``log_frequency`` steps.
            start_epoch: first epoch to run (for resuming).
        """
        # LR decay between 0.03 and 0.0001 (according to paper), one value
        # per epoch.
        lrs = np.linspace(0.03, 0.0001, epochs)
        for epoch in range(start_epoch, epochs):
            # Update the existing optimizer in place. Rebuilding SGD and then
            # calling load_state_dict() would restore the saved param-group
            # lr, so the decay would never take effect.
            self._set_learning_rate(lrs[epoch])
            self.model.train()
            for batch, gts in self.train_loader:
                self.optimizer.zero_grad()
                # load batch to device
                batch = batch.to(self.device)
                gts = gts.to(self.device)

                # train step
                step_start_time = time.time()
                output = self.model(batch)
                loss = self.criterion(output, gts)
                loss.backward()
                self.optimizer.step()

                # log step
                if ((self.step + 1) % log_frequency) == 0:
                    step_time = time.time() - step_start_time
                    self.log_metrics(epoch, loss, step_time)
                    self.print_metrics(epoch, loss, step_time)

                # count steps
                self.step += 1

            # log epoch
            self.summary_writer.add_scalar("epoch", epoch, self.step)

            # validate
            if ((epoch + 1) % val_frequency) == 0:
                self.validate()
                self.model.train()
            if (epoch + 1) % 10 == 0:
                save(self.model, "checkp_model.pkl")

    def print_metrics(self, epoch, loss, step_time):
        """Print epoch, in-epoch step, batch loss and step time to stdout."""
        epoch_step = self.step % len(self.train_loader)
        print(f"epoch: [{epoch}], "
              f"step: [{epoch_step}/{len(self.train_loader)}], "
              f"batch loss: {loss:.5f}, "
              f"step time: {step_time:.5f}")

    def log_metrics(self, epoch, loss, step_time):
        """Write epoch, training loss and step time to TensorBoard."""
        self.summary_writer.add_scalar("epoch", epoch, self.step)

        self.summary_writer.add_scalars("loss", {"train": float(loss.item())},
                                        self.step)
        self.summary_writer.add_scalar("time/data", step_time, self.step)

    def validate(self):
        """Compute the mean per-batch validation loss and log it as ``loss/test``."""
        total_loss = 0
        self.model.eval()

        # No need to track gradients for validation, we're not optimizing.
        with no_grad():
            for batch, gts in self.val_loader:
                batch = batch.to(self.device)
                gts = gts.to(self.device)
                output = self.model(batch)
                loss = self.criterion(output, gts)
                total_loss += loss.item()

        average_loss = total_loss / len(self.val_loader)

        self.summary_writer.add_scalars("loss", {"test": average_loss},
                                        self.step)
        print(f"validation loss: {average_loss:.5f}")
Exemplo n.º 25
0
def prologue(args):
    """Prepare an experiment: output dir, data, model, loss, optimizer, schedule.

    Creates ``<args.outdir>/<args.arch>/<args.id>/`` (drawing a random id when
    none is set) and copies the code there. Builds the train/test loaders and
    the model, optionally starting from ``args.pretrained_model``. When
    ``args.resume`` is set, restores ``checkpoint.pth.tar`` from the output
    directory if it exists.

    Note: ``device`` is not defined in this function; it is read from the
    enclosing (module) scope and returned unchanged.

    Args:
        args: parsed command-line namespace. Mutated in place: ``args.id``
            may be assigned, and ``args.outdir`` is extended with the
            ``<arch>/<id>/`` subdirectory.

    Returns:
        tuple: ``(train_loader, test_loader, criterion, model, optimizer,
        scheduler, starting_epoch, logfilename, model_path, device, writer)``.
    """
    if getattr(args, 'id', None) is None:
        args.id = np.random.randint(10000)
    args.outdir = args.outdir + f"/{args.arch}/{args.id}/"
    os.makedirs(args.outdir, exist_ok=True)

    # Copies files to the outdir to store complete script with each experiment
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    if args.pretrained_model != '':
        assert args.arch == 'cifar_resnet110', 'Unsupported architecture for pretraining'
        # Load onto CPU (like the resume path below) so a checkpoint saved on
        # GPU can be read anywhere; load_state_dict copies the tensors into
        # the model's own parameters.
        checkpoint = torch.load(args.pretrained_model,
                                map_location=lambda storage, loc: storage)
        model = get_architecture(checkpoint["arch"], args.dataset)
        model.load_state_dict(checkpoint['state_dict'])
        # Replace the classifier head with a fresh one sized for CIFAR-10.
        model[1].fc = nn.Linear(64, get_num_classes('cifar10')).to(device)
    else:
        model = get_architecture(args.arch, args.dataset)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")
    writer = SummaryWriter(args.outdir)

    criterion = CrossEntropyLoss().to(device)
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)
    starting_epoch = 0

    # Load latest checkpoint if exists (to handle philly failures)
    model_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume:
        if os.path.isfile(model_path):
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)
            starting_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Restore the LR schedule too when the checkpoint carries it;
            # otherwise StepLR restarts its step count from zero on resume.
            if 'scheduler' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                model_path, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(model_path))

    return train_loader, test_loader, criterion, model, optimizer, scheduler, \
           starting_epoch, logfilename, model_path, device, writer
Exemplo n.º 26
0
    ])

    # Construct validation dataset and loader
    val_dataset = RecognitionDataset(eval_root,
                                     eval_indices,
                                     eval_files,
                                     RecognitionDataset.VAL,
                                     transform=eval_transform)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # Start training
    start_time = time.time()
    if os.path.exists(save):
        saved_checkpoint = torch.load(save)
        start_epoch = saved_checkpoint['last_epoch'] + 1
        sgd.load_state_dict(saved_checkpoint['optimizer'])
        crit.load_state_dict(saved_checkpoint['criterion'])
    else:
        start_epoch = 1

    for epoch in range(start_epoch, 31):
        vgg16d, sgd, crit = train(vgg16d, sgd, crit, epoch, train_loader,
                                  val_loader, save, best)
    end_time = time.time()
    out(f'VGG base recognition training elapsed for {end_time - start_time} seconds'
        )

    # Construct test dataset and loader
    test_dataset = RecognitionDataset(eval_root,
                                      eval_indices,
                                      eval_files,
Exemplo n.º 27
0
def make_optimizer_and_schedule(args, model, checkpoint, params):
    """
    *Internal Function* (called directly from train_model)

    Creates an optimizer and a schedule for a given model, restoring from a
    checkpoint if it is non-null.

    Args:
        args (object) : an arguments object, see
            :meth:`~robustness.train.train_model` for details
        model (AttackerModel) : the model to create the optimizer for
        checkpoint (dict) : a loaded checkpoint saved by this library and loaded
            with `ch.load`
        params (list|None) : a list of parameters that should be updatable, all
            other params will not update. If ``None``, update all params

    Returns:
        An optimizer (ch.nn.optim.Optimizer) and a scheduler
            (ch.nn.optim.lr_schedulers module), or ``None`` for the scheduler
            when no schedule option (``custom_lr_multiplier``/``step_lr``) is
            set.
    """
    # Make optimizer
    param_list = model.parameters() if params is None else params
    optimizer = SGD(param_list, args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay)

    if args.mixed_precision:
        model.to('cuda')
        model, optimizer = amp.initialize(model, optimizer, 'O1')

    # Make schedule
    schedule = None
    if args.custom_lr_multiplier == 'cyclic':
        eps = args.epochs

        # Triangular multiplier: 0 -> 1 over the first 4/15 of training,
        # then back down to 0 at the last epoch.
        def lr_func(t):
            return np.interp([t], [0, eps * 4 // 15, eps], [0, 1, 0])[0]

        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
    elif args.custom_lr_multiplier:
        cs = args.custom_lr_multiplier
        # SECURITY: eval() executes arbitrary code; only pass trusted
        # configuration here. For a plain list of (epoch, multiplier) pairs,
        # ast.literal_eval would be sufficient.
        periods = eval(cs) if isinstance(cs, str) else cs
        if args.lr_interpolation == 'linear':
            def lr_func(t):
                return np.interp([t], *zip(*periods))[0]
        else:
            def lr_func(ep):
                # Piecewise constant: multiplier of the latest milestone reached.
                for (milestone, lr) in reversed(periods):
                    if ep >= milestone: return lr
                return 1.0
        schedule = lr_scheduler.LambdaLR(optimizer, lr_func)
    elif args.step_lr:
        schedule = lr_scheduler.StepLR(optimizer, step_size=args.step_lr, gamma=args.step_lr_gamma)

    # Fast-forward the optimizer and the scheduler if resuming
    if checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Without a configured schedule there is nothing to restore; the old
        # fallback would have called .step() on None and crashed.
        if schedule is not None:
            try:
                schedule.load_state_dict(checkpoint['schedule'])
            except Exception:
                steps_to_take = checkpoint['epoch']
                print('Could not load schedule (was probably LambdaLR).'
                      f' Stepping {steps_to_take} times instead...')
                for _ in range(steps_to_take):
                    schedule.step()

        if 'amp' in checkpoint and checkpoint['amp'] not in [None, 'N/A']:
            amp.load_state_dict(checkpoint['amp'])

        # TODO: see if there's a smarter way to do this
        # TODO: see what's up with loading fp32 weights and then MP training
        if args.mixed_precision:
            model.load_state_dict(checkpoint['model'])

    return optimizer, schedule
Exemplo n.º 28
0
class MetaFrameWork(object):
    def __init__(self, name='normal_all', train_num=1, source='GSIM',
                 target='C', network=Net, resume=True, dataset=DGMetaDataSets,
                 inner_lr=1e-3, outer_lr=5e-3, train_size=8, test_size=16, no_source_test=True, bn='torch'):
        """Record the experiment configuration, then build every component.

        Args:
            name: experiment name; also used as the save directory.
            train_num: every ``train_num``-th iteration is a meta step.
            source: source-domain specification handed to ``dataset``.
            target: target-domain specification handed to ``get_dataset``.
            network: network class used for the backbone and updated copy.
            resume: load a checkpoint at the start of ``do_train``.
            dataset: dataset class for the source domains.
            inner_lr: learning rate of the inner (meta-train) update.
            outer_lr: learning rate of the outer SGD optimizer.
            train_size: training batch size.
            test_size: evaluation batch size.
            no_source_test: skip source-domain validation in ``save_best``.
            bn: normalization-layer option forwarded to ``network``.
        """
        super().__init__()

        # Experiment bookkeeping
        self.exp_name = name
        self.save_path = Path(self.exp_name)
        self.resume = resume
        self.train_num = train_num
        self.no_source_test = no_source_test

        # Model and data configuration
        self.network, self.dataset, self.bn = network, dataset, bn
        self.source, self.target = source, target
        self.train_size, self.test_size = train_size, test_size
        self.inner_update_lr, self.outer_update_lr = inner_lr, outer_lr

        # Best-so-far tracking
        self.epoch = 1
        self.best_target_acc = 0
        self.best_target_acc_source = 0
        self.best_target_epoch = 1
        self.best_source_acc = 0
        self.best_source_acc_target = 0
        self.best_source_epoch = 0

        # Schedule
        self.total_epoch = 120
        self.save_interval = 1

        self.init()

    def init(self):
        """Build networks, loss, metric, data loaders, optimizer, scheduler, logger and timers."""
        net_kwargs = {'bn': self.bn, 'output_stride': 8}
        self.backbone = nn.DataParallel(self.network(**net_kwargs)).cuda()
        # The copy that receives the inner-loop update is built without
        # pretrained weights.
        net_kwargs['pretrained'] = False
        self.updated_net = nn.DataParallel(self.network(**net_kwargs)).cuda()
        self.ce = nn.CrossEntropyLoss(ignore_index=-1)
        self.nim = NaturalImageMeasure(nclass=19)

        workers = len(self.source) * 4
        make_train_loader = functools.partial(DataLoader, num_workers=workers, pin_memory=True,
                                              batch_size=self.train_size, shuffle=True)
        make_eval_loader = functools.partial(DataLoader, num_workers=workers, pin_memory=True,
                                             batch_size=self.test_size, shuffle=False)

        self.train_loader = make_train_loader(self.dataset(mode='train', domains=self.source, force_cache=True))
        self.source_val_loader = make_eval_loader(self.dataset(mode='val', domains=self.source, force_cache=True))

        target_dataset, folder = get_dataset(self.target)
        self.target_loader = make_eval_loader(target_dataset(root=ROOT + folder, mode='val'))
        self.target_test_loader = make_eval_loader(target_dataset(root=ROOT + 'cityscapes', mode='test'))

        self.opt_old = SGD(self.backbone.parameters(), lr=self.outer_update_lr, momentum=0.9, weight_decay=5e-4)
        self.scheduler_old = PolyLR(self.opt_old, self.total_epoch, len(self.train_loader), 0, True, power=0.9)

        self.logger = get_logger('train', self.exp_name)
        header = ('exp_name : {}, train_num = {}, source domains = {}, target_domain = {}, lr : inner = {}, outer = {},'
                  'dataset : {}, net : {}, bn : {}\n')
        self.log(header.format(self.exp_name, self.train_num, self.source, self.target,
                               self.inner_update_lr, self.outer_update_lr, self.dataset,
                               self.network, self.bn))
        self.log(self.exp_name + '\n')
        self.train_timer = Timer()
        self.test_timer = Timer()

    def train(self, epoch, it, inputs):
        """Run one plain (non-meta) supervised step over all source domains.

        Args:
            epoch: epoch index forwarded to the scheduler.
            it: iteration index within the epoch.
            inputs: ``(imgs, targets)``; imgs is batch x domains x C x H x W
                and targets is batch x domains x 1 x H x W.

        Returns:
            ``(losses, acc, lr)``: losses with ``'dg'`` fixed at 0 and the
            ``'ds'`` loss value, acc with ``'iou'``, and the first scheduler lr.
        """
        imgs, targets = inputs
        _, _, channels, height, width = imgs.size()
        # Fold the domain axis into the batch axis.
        flat_imgs = imgs.view(-1, channels, height, width)
        flat_targets = targets.view(-1, 1, height, width)

        logits = make_same_size(self.backbone(flat_imgs)[0], flat_targets)
        ds_loss = self.ce(logits, flat_targets[:, 0])
        with torch.no_grad():
            self.nim(logits, flat_targets)

        self.opt_old.zero_grad()
        ds_loss.backward()
        self.opt_old.step()
        self.scheduler_old.step(epoch, it)

        return ({'dg': 0, 'ds': ds_loss.item()},
                {'iou': self.nim.get_res()[0]},
                self.scheduler_old.get_lr(epoch, it)[0])

    def meta_train(self, epoch, it, inputs):
        """Run one meta-learning step with a random meta-train/meta-test domain split.

        The domain axis is randomly permuted and split into a non-empty
        meta-train subset and a non-empty meta-test subset. Requires at least
        two domains: ``np.random.randint(1, D)`` raises ValueError when D == 1.

        The backbone's loss on the meta-train domains (``ds_loss``) is
        back-propagated and ``get_updated_network`` builds an updated network
        using ``self.inner_update_lr``. That network's loss on the meta-test
        domains (``dg_loss``) is back-propagated as well. There is no
        ``zero_grad`` between the two ``backward`` calls, so their gradients
        accumulate before the single ``opt_old.step()``.

        Args:
            epoch: epoch index forwarded to the scheduler.
            it: iteration index within the epoch.
            inputs: ``(imgs, targets)`` (shapes in the comments below).

        Returns:
            ``(losses, acc, lr)``. ``losses`` has keys ``'dg'`` and ``'ds'``;
            ``acc`` has key ``'iou'`` (the metric is fed only the meta-test
            logits); ``lr`` is the first value of ``scheduler_old.get_lr``.
        """
        # imgs : batch x domains x C x H x W
        # targets : batch x domains x 1 x H x W

        imgs, targets = inputs
        B, D, C, H, W = imgs.size()
        # Random split point i in [1, D-1]: both subsets get at least one domain.
        split_idx = np.random.permutation(D)
        i = np.random.randint(1, D)
        train_idx = split_idx[:i]
        test_idx = split_idx[i:]
        # train_idx = split_idx[:D // 2]
        # test_idx = split_idx[D // 2:]

        # self.print(split_idx, B, D, C, H, W)'
        # Select the chosen domains and fold them into the batch axis.
        meta_train_imgs = imgs[:, train_idx].reshape(-1, C, H, W)
        meta_train_targets = targets[:, train_idx].reshape(-1, 1, H, W)
        meta_test_imgs = imgs[:, test_idx].reshape(-1, C, H, W)
        meta_test_targets = targets[:, test_idx].reshape(-1, 1, H, W)

        # Meta-Train
        tr_logits = self.backbone(meta_train_imgs)[0]
        tr_logits = make_same_size(tr_logits, meta_train_targets)
        ds_loss = self.ce(tr_logits, meta_train_targets[:, 0])

        # Update new network
        self.opt_old.zero_grad()
        # retain_graph=True keeps the meta-train graph alive, presumably
        # because get_updated_network builds weights that still depend on it
        # for the second backward below — confirm in get_updated_network.
        ds_loss.backward(retain_graph=True)
        updated_net = get_updated_network(self.backbone, self.updated_net, self.inner_update_lr).train().cuda()

        # Meta-Test
        te_logits = updated_net(meta_test_imgs)[0]
        # te_logits = test_res[0]
        te_logits = make_same_size(te_logits, meta_test_targets)
        dg_loss = self.ce(te_logits, meta_test_targets[:, 0])
        with torch.no_grad():
            self.nim(te_logits, meta_test_targets)

        # Update old network
        # Gradients from dg_loss add onto those already left by ds_loss.
        dg_loss.backward()
        self.opt_old.step()
        self.scheduler_old.step(epoch, it)
        losses = {
            'dg': dg_loss.item(),
            'ds': ds_loss.item()
        }
        acc = {
            'iou': self.nim.get_res()[0],
        }
        return losses, acc, self.scheduler_old.get_lr(epoch, it)[0]

    def do_train(self):
        """Run the full training loop from ``self.epoch`` to ``self.total_epoch``.

        If ``self.resume`` is set, ``self.load()`` is called first. Each epoch
        iterates over ``self.train_loader``. Every ``self.train_num``-th
        iteration calls ``meta_train``; all others call ``train``. After each
        epoch a ``'ckpt'`` checkpoint is saved. Every ``self.save_interval``
        epochs the target loader is validated and ``save_best`` is updated.
        Finally, the best accuracies found are logged.

        NOTE(review): if ``train_loader`` yields no batches, ``it``, ``lr``
        and ``meta`` are unbound at the end-of-epoch log call (NameError).
        """
        if self.resume:
            self.load()

        self.writer = SummaryWriter(str(self.save_path / 'tensorboard'), filename_suffix=time.strftime('_%Y-%m-%d_%H-%M-%S'))
        self.log('Start epoch : {}\n'.format(self.epoch))

        for epoch in range(self.epoch, self.total_epoch + 1):
            loss_meters, acc_meters = MeterDicts(), MeterDicts(averaged=['iou'])
            self.nim.clear_cache()
            self.backbone.train()
            self.epoch = epoch
            with self.train_timer:
                for it, (paths, imgs, target) in enumerate(self.train_loader):
                    meta = (it + 1) % self.train_num == 0
                    # Epochs are 1-based here; the step functions get epoch - 1.
                    if meta:
                        losses, acc, lr = self.meta_train(epoch - 1, it, to_cuda([imgs, target]))
                    else:
                        losses, acc, lr = self.train(epoch - 1, it, to_cuda([imgs, target]))

                    # Plain steps report 'dg' as 0, so it is skipped for them.
                    loss_meters.update_meters(losses, skips=['dg'] if not meta else [])
                    acc_meters.update_meters(acc)

                    self.print(self.get_string(epoch, it, loss_meters, acc_meters, lr, meta), end='')
                    self.tfb_log(epoch, it, loss_meters, acc_meters)
            self.print(self.train_timer.get_formatted_duration())
            self.log(self.get_string(epoch, it, loss_meters, acc_meters, lr, meta) + '\n')

            self.save('ckpt')
            if epoch % self.save_interval == 0:
                with self.test_timer:
                    city_acc = self.val(self.target_loader)
                    self.save_best(city_acc, epoch)

            # Remaining-time estimate; assumes Timer.duration holds the most
            # recent timed span — TODO confirm.
            total_duration = self.train_timer.duration + self.test_timer.duration
            self.print('Time Left : ' + self.train_timer.get_formatted_duration(total_duration * (self.total_epoch - epoch)) + '\n')

        self.log('Best city acc : \n  city : {}, origin : {}, epoch : {}\n'.format(
            self.best_target_acc, self.best_target_acc_source, self.best_target_epoch))
        self.log('Best origin acc : \n  city : {}, origin : {}, epoch : {}\n'.format(
            self.best_source_acc_target, self.best_source_acc, self.best_source_epoch))

    def save_best(self, city_acc, epoch):
        """Log validation accuracies and save checkpoints on new bests.

        Source-domain accuracy is only computed when ``no_source_test`` is
        false; otherwise it is taken as 0. A new best target accuracy saves
        ``'best_city'``, and a new best source accuracy saves ``'best_origin'``.
        """
        self.writer.add_scalar('acc/citys', city_acc, epoch)
        origin_acc = 0
        if not self.no_source_test:
            origin_acc = self.val(self.source_val_loader)
            self.writer.add_scalar('acc/origin', origin_acc, epoch)
        self.writer.flush()

        if city_acc > self.best_target_acc:
            self.best_target_acc, self.best_target_acc_source, self.best_target_epoch = \
                city_acc, origin_acc, epoch
            self.save('best_city')

        if origin_acc > self.best_source_acc:
            self.best_source_acc, self.best_source_acc_target, self.best_source_epoch = \
                origin_acc, city_acc, epoch
            self.save('best_origin')

    def val(self, dataset):
        """Evaluate the backbone in eval mode on a loader and return the first accuracy value.

        Args:
            dataset: a loader yielding ``(path, img, target)`` triples; its
                ``.dataset`` may provide ``format_class_iou`` for per-class logs.
        """
        self.backbone.eval()
        with torch.no_grad():
            self.nim.clear_cache()
            self.nim.set_max_len(len(dataset))
            for _, img, target in dataset:
                batch_img, batch_target = to_cuda(get_img_target(img, target))
                self.nim(self.backbone(batch_img)[0], batch_target)
        self.log('\nNormal validation : {}\n'.format(self.nim.get_acc()))
        inner_dataset = dataset.dataset
        if hasattr(inner_dataset, 'format_class_iou'):
            self.log(inner_dataset.format_class_iou(self.nim.get_class_acc()[0]) + '\n')
        return self.nim.get_acc()[0]

    def target_specific_val(self, loader):
        """Validate on ``loader`` with the backbone in train mode but dropout removed.

        Every batch is reshaped to ``(B, D, C, H, W)``; 4-D image batches are
        treated as ``D = 1``. Each of the ``D`` slices is run through the
        backbone in ``train()`` mode under ``torch.no_grad()``, and predictions
        are accumulated into ``self.nim``.

        ``remove_dropout()`` and ``not_track()`` are called on
        ``backbone.module`` before the loop, and ``recover_dropout()`` after it.
        NOTE(review): presumably ``not_track`` stops BatchNorm running-stat
        updates, so train-mode normalisation uses batch statistics without
        overwriting the stored ones. Confirm in the backbone module.
        NOTE(review): no counterpart to ``not_track()`` is called here, and
        ``recover_dropout()`` is skipped if an exception escapes the loop.

        Returns:
            The first element of ``self.nim.get_acc()``.
        """
        self.nim.clear_cache()
        self.nim.set_max_len(len(loader))
        # The model runs in train() mode below, so dropout is stripped first to
        # keep predictions deterministic.
        self.backbone.module.remove_dropout()
        self.backbone.module.not_track()
        for idx, (p, img, target) in enumerate(loader):
            # Normalise to a 5-D layout; assumes target holds B*D*H*W labels
            # (one channel per pixel) -- TODO confirm against the dataset.
            if len(img.size()) == 5:
                B, D, C, H, W = img.size()
            else:
                B, C, H, W = img.size()
                D = 1
            img, target = to_cuda([img.reshape(B, D, C, H, W), target.reshape(B, D, 1, H, W)])
            for d in range(img.size(1)):
                img_d, target_d, = img[:, d], target[:, d]
                # Deliberately train mode (see docstring); gradients stay off.
                self.backbone.train()
                with torch.no_grad():
                    new_logits = self.backbone(img_d)[0]
                    self.nim(new_logits, target_d)

        self.backbone.module.recover_dropout()
        self.log('\nTarget specific validation : {}\n'.format(self.nim.get_acc()))
        if hasattr(loader.dataset, 'format_class_iou'):
            self.log(loader.dataset.format_class_iou(self.nim.get_class_acc()[0]) + '\n')
        return self.nim.get_acc()[0]

    def predict_target(self, load_path='best_city', color=False, train=False, output_path='predictions'):
        """Load a checkpoint and save one prediction image per target test sample.

        Args:
            load_path: checkpoint name or path handed to ``self.load``.
            color: if True, save the raw predicted class map converted with
                ``class_map_2_color_map`` (transposed ``(1, 2, 0)`` and cast to
                uint8). Otherwise save the ids returned by
                ``dataset.dataset.predict``.
            train: if True, run the backbone in ``train()`` mode with dropout
                removed. Dropout is restored before returning, even on error.
                Otherwise the backbone runs in ``eval()`` mode.
            output_path: sub-directory of ``self.save_path`` receiving the
                images. It is created, including missing parents, if needed.
        """
        self.load(load_path)
        import skimage.io as skio
        dataset = self.target_test_loader

        output_path = Path(self.save_path / output_path)
        output_path.mkdir(parents=True, exist_ok=True)

        if train:
            self.backbone.module.remove_dropout()
            self.backbone.train()
        else:
            self.backbone.eval()

        try:
            with torch.no_grad():
                self.nim.clear_cache()
                self.nim.set_max_len(len(dataset))
                for names, img, target in tqdm(dataset):
                    img = to_cuda(img)
                    logits = self.backbone(img)[0]
                    # Upsample logits to the input resolution before the argmax.
                    logits = F.interpolate(logits, img.size()[2:], mode='bilinear', align_corners=True)
                    preds = get_prediction(logits).cpu().numpy()
                    if color:
                        trainId_preds = preds
                    else:
                        trainId_preds = dataset.dataset.predict(preds)

                    for pred, name in zip(trainId_preds, names):
                        file_name = name.split('/')[-1]
                        if color:
                            pred = class_map_2_color_map(pred).transpose(1, 2, 0).astype(np.uint8)
                        skio.imsave(str(output_path / file_name), pred)
        finally:
            if train:
                # Undo remove_dropout() so the model is left as it was found.
                self.backbone.module.recover_dropout()

    def get_string(self, epoch, it, loss_meters, acc_meters, lr, meta):
        """Build the one-line progress message printed during training.

        The message starts with a carriage return. It lists ``name : avg``
        (4 decimals) for every loss meter and then every accuracy meter,
        followed by the learning rate (6 decimals) and ``meta``.
        """
        pieces = ['\repoch {:4}, iter : {:4}, '.format(epoch, it)]
        for meters in (loss_meters, acc_meters):
            pieces.extend(name + ' : {:.4f}, '.format(meter.avg) for name, meter in meters.items())
        pieces.append('lr : {:.6f}, meta : {}'.format(lr, meta))
        return ''.join(pieces)

    def log(self, strs):
        """Pass ``strs`` unchanged to ``self.logger.info``."""
        logger = self.logger
        logger.info(strs)

    def print(self, strs, **kwargs):
        """Write ``strs`` to stdout via the built-in ``print``.

        Keyword arguments such as ``end`` are passed through unchanged.
        """
        return print(strs, **kwargs)

    def tfb_log(self, epoch, it, losses, acc):
        """Write the latest value of every loss and accuracy meter to the writer.

        The global step is ``epoch * len(self.train_loader) + it``. Losses are
        tagged ``loss/<name>`` and accuracies ``acc/<name>``; losses are
        written first.
        """
        step = epoch * len(self.train_loader) + it
        for prefix, meters in (('loss/', losses), ('acc/', acc)):
            for name, meter in meters.items():
                self.writer.add_scalar(prefix + name, meter.val, step)

    def save(self, name='ckpt'):
        """Serialize the training state to ``self.save_path / '<name>.pth'``.

        The checkpoint dict contains:
          - the backbone and ``opt_old`` state dicts;
          - the epoch to resume from (``self.epoch + 1``);
          - the best target accuracy;
          - the six best-accuracy bookkeeping values under ``info``, in the
            order ``load`` unpacks them.
        """
        best_record = [
            self.best_source_acc, self.best_source_acc_target, self.best_source_epoch,
            self.best_target_acc, self.best_target_acc_source, self.best_target_epoch,
        ]
        checkpoint = dict(
            backbone=self.backbone.state_dict(),
            opt=self.opt_old.state_dict(),
            epoch=self.epoch + 1,
            best=self.best_target_acc,
            info=best_record,
        )
        destination = self.save_path / '{}.pth'.format(name)
        self.print('Saving epoch : {}'.format(self.epoch))
        torch.save(checkpoint, destination)

    def load(self, path=None, strict=False):
        """Restore backbone, optimizer and bookkeeping state from a checkpoint.

        Args:
            path: how the checkpoint file is resolved.
                - ``None``: ``self.save_path / 'ckpt.pth'``.
                - A string or path-like whose text contains ``'pth'``: used
                  as-is.
                - Anything else: treated as a checkpoint name and resolved to
                  ``self.save_path / '<name>.pth'``.
            strict: forwarded to ``self.backbone.load_state_dict``.

        Returns:
            bool: ``True`` on success. On any failure the exception is printed
            and logged, ``self.epoch`` is reset to 1, and ``False`` is returned.
            A missing file is reported as "No ckpt found"; other errors are
            reported as load failures.
        """
        if path is None:
            path = self.save_path / 'ckpt.pth'
        elif 'pth' not in str(path):
            # A bare checkpoint name such as 'best_city'. str() also accepts
            # pathlib.Path inputs, which do not support the `in` operator.
            path = self.save_path / '{}.pth'.format(path)

        try:
            dicts = torch.load(path, map_location='cpu')
            msg = self.backbone.load_state_dict(dicts['backbone'], strict=strict)
            self.print(msg)
            if 'opt' in dicts:
                self.opt_old.load_state_dict(dicts['opt'])
            self.epoch = dicts.get('epoch', 1)
            if 'best' in dicts:
                self.best_target_acc = dicts['best']
            if 'info' in dicts:
                # Same order as written by save().
                (self.best_source_acc, self.best_source_acc_target, self.best_source_epoch,
                 self.best_target_acc, self.best_target_acc_source, self.best_target_epoch) = dicts['info']
            self.log('Loaded from {}, next epoch : {}, best_target : {}, best_epoch : {}\n'
                     .format(str(path), self.epoch, self.best_target_acc, self.best_target_epoch))
            return True
        except FileNotFoundError as e:
            self.print(e)
            self.log('No ckpt found in {}\n'.format(str(path)))
        except Exception as e:  # best-effort resume: fall back to a fresh start
            self.print(e)
            self.log('Failed to load ckpt from {} : {}\n'.format(str(path), e))
        self.epoch = 1
        return False
Exemplo n.º 29
0
    def _train(save_iter=None, save_epoch=None, sd=None):
        """Train a small conv net with ``DeterministicEngine`` and record diagnostics.

        A checkpoint ``[model.state_dict(), opt.state_dict(), trainer.state_dict()]``
        is written once to ``dirname / "test.pt"``. It is written at iteration
        ``save_iter`` or at the end of epoch ``save_epoch``; ``save_iter`` takes
        precedence when both are given.

        Diagnostics are recorded for iterations ``save_iter + r`` with
        ``1 <= r <= min(5, save_iter - 1)``: batch mean/std plus weight and
        gradient norms.

        Args:
            save_iter: iteration at which to checkpoint.
            save_epoch: epoch at which to checkpoint. It is converted to an
                iteration count as ``save_epoch * (data_size // batch_size)``.
            sd: optional path of a checkpoint written by this function. Model,
                optimizer and trainer state are restored from it before running.

        Returns:
            dict with keys:
              - ``"sd"``: list of checkpoint paths written.
              - ``"data"``: ``[iteration, batch mean, batch std]`` rows.
              - ``"grads"`` and ``"weights"``:
                ``[iteration, total norm, per-parameter norms...]`` rows.

        NOTE(review): if both ``save_iter`` and ``save_epoch`` are None, ``ev``
        is never bound and ``trainer.on(ev)`` raises ``UnboundLocalError``.
        NOTE(review): ``with_dropout``, ``device``, ``debug``, ``dirname``,
        ``data_size`` and ``batch_size`` come from the enclosing scope, which is
        not visible here.
        """
        # Per-iteration diagnostics filled by write_data_grads_weights.
        w_norms = []
        grad_norms = []
        data = []
        # Paths of checkpoints written by save_chkpt.
        chkpt = []

        # Fixed seed before building the model, presumably so initial weights
        # are identical across calls.
        manual_seed(12)
        arch = [
            nn.Conv2d(3, 10, 3),
            nn.ReLU(),
            nn.Conv2d(10, 10, 3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ]
        if with_dropout:
            # Dropout2d goes after the first ReLU; Dropout goes right after
            # Linear(10, 5).
            arch.insert(2, nn.Dropout2d())
            arch.insert(-2, nn.Dropout())

        model = nn.Sequential(*arch).to(device)
        opt = SGD(model.parameters(), lr=0.001)

        def proc_fn(e, b):
            # One SGD step on batch ``b``; the loss is the sum of the outputs.
            from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

            # RNG state snapshot taken before the step; only used in debug output.
            s = _repr_rng_state(_get_rng_states())
            model.train()
            opt.zero_grad()
            y = model(b.to(device))
            y.sum().backward()
            opt.step()
            if debug:
                print(trainer.state.iteration, trainer.state.epoch,
                      "proc_fn - b.shape", b.shape,
                      torch.norm(y).item(), s)

        trainer = DeterministicEngine(proc_fn)

        # Pick the checkpoint event. An epoch target is also converted to an
        # iteration count for log_event_filter.
        # NOTE(review): assumes exactly data_size // batch_size iterations per epoch.
        if save_iter is not None:
            ev = Events.ITERATION_COMPLETED(once=save_iter)
        elif save_epoch is not None:
            ev = Events.EPOCH_COMPLETED(once=save_epoch)
            save_iter = save_epoch * (data_size // batch_size)

        @trainer.on(ev)
        def save_chkpt(_):
            # Save model, optimizer and engine state (incl. RNG states) together.
            if debug:
                print(trainer.state.iteration, "save_chkpt")
            fp = dirname / "test.pt"
            from ignite.engine.deterministic import _repr_rng_state

            tsd = trainer.state_dict()
            if debug:
                print("->", _repr_rng_state(tsd["rng_states"]))
            torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
            chkpt.append(fp)

        def log_event_filter(_, event):
            # True for iterations save_iter + r with 1 <= r <= min(5, save_iter - 1).
            if (event // save_iter == 1) and 1 <= (event % save_iter) <= 5:
                return True
            return False

        @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
        def write_data_grads_weights(e):
            # Record batch mean/std and per-parameter weight/gradient norms.
            # The "total" entries are plain sums of the per-parameter norms.
            x = e.state.batch
            i = e.state.iteration
            data.append([i, x.mean().item(), x.std().item()])

            total = [0.0, 0.0]
            out1 = []
            out2 = []
            for p in model.parameters():
                n1 = torch.norm(p).item()
                n2 = torch.norm(p.grad).item()
                out1.append(n1)
                out2.append(n2)
                total[0] += n1
                total[1] += n2
            w_norms.append([i, total[0]] + out1)
            grad_norms.append([i, total[1]] + out2)

        if sd is not None:
            # Resume from a checkpoint written by save_chkpt:
            # [model state, optimizer state, trainer state].
            sd = torch.load(sd)
            model.load_state_dict(sd[0])
            opt.load_state_dict(sd[1])
            from ignite.engine.deterministic import _repr_rng_state

            if debug:
                print("-->", _repr_rng_state(sd[2]["rng_states"]))
            trainer.load_state_dict(sd[2])

        # Reseed right before running; presumably so the data stream starts
        # from a known state.
        manual_seed(32)
        trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
        return {
            "sd": chkpt,
            "data": data,
            "grads": grad_norms,
            "weights": w_norms
        }
Exemplo n.º 30
0
def train_then_test(**kwargs):
    """Build a DAN and run the train/test phases listed in ``kwargs["modes"]``.

    Required kwargs:
        directory, exp: files are read from ``directory + "/" + exp`` plus a
            suffix (``_state_dict.pt``, ``_opt_state_dict.pt``,
            ``_hidden_state_<mode>.pt``, ``_x_truth_<mode>.pt``,
            ``_<mode>_<key>.npy``).
        kDAN: keyword arguments for ``DAN``.
        ktrain, ktest: keyword arguments forwarded to ``train`` / ``test``.
            Their ``"batch_size"`` entry is overwritten with the batch sizes
            below.

    Optional kwargs (defaults in brackets):
        x_dim [40], h_dim [20 * x_dim], burn [1000],
        batch_sizes (or legacy batch_size) [{"train": 1024, "test": 1}],
        modes [["train", "test"]], load_weights [False],
        append_outputs [{"train": False, "test": False}],
        load_state [{"train": [False], "test": [False]}],
        seeds [{"train": [0], "test": [1]}],
        optidict [{"optimizer": "Adam", "lr": 1e-4}]. SGD also reads
        "momentum" [0] and "nesterov" [False]. Setting "load" [False]
        reloads the optimizer state on every train run after the first.

    Raises:
        ValueError: when a "train" phase runs and ``optidict["optimizer"]`` is
            neither "SGD" nor "Adam".
    """
    # Inputs ##################################################################
    #  Files
    directory = kwargs["directory"]
    exp = kwargs["exp"]
    direxp = directory + "/" + exp
    #  Dimensions
    x_dim = kwargs.get("x_dim", 40)  # state dim
    h_dim = kwargs.get("h_dim", 20 * x_dim)  # memory (hidden state) dim
    # One source of truth for the batch sizes, used both for the initial
    # states and for ktrain/ktest. "batch_sizes" wins over "batch_size".
    batch_sizes = kwargs.get(
        "batch_sizes", kwargs.get("batch_size", {"train": 1024, "test": 1}))
    burn = kwargs.get("burn", 10**3)  # passed to M to skip any transient regime
    #  Controls
    modes = kwargs.get("modes", ["train", "test"])
    load_weights = kwargs.get("load_weights", False)
    append_outputs = kwargs.get("append_outputs", {
        "train": False,
        "test": False
    })
    load_state = kwargs.get("load_state", {"train": [False], "test": [False]})
    seeds = kwargs.get("seeds", {"train": [0], "test": [1]})
    kwargs["ktest"]["batch_size"] = batch_sizes["test"]
    kwargs["ktrain"]["batch_size"] = batch_sizes["train"]

    # Outputs #################################################################
    # the outputs of test and train
    ###########################################################################

    # DAN initialization
    net = DAN(**kwargs["kDAN"])
    if load_weights:
        net.load_state_dict(torch.load(direxp + "_state_dict.pt"))

    # Optimizer initialization
    optidict = kwargs.get("optidict", {"optimizer": "Adam", "lr": 10**-4})
    if optidict["optimizer"] == "SGD":
        optimizer = SGD(net.parameters(),
                        lr=optidict["lr"],
                        momentum=optidict.get("momentum", 0),
                        nesterov=optidict.get("nesterov", False))
    elif optidict["optimizer"] == "Adam":
        optimizer = Adam(net.parameters(), lr=optidict["lr"])
    else:
        # Only the "train" phase needs an optimizer; it raises a clear error there.
        optimizer = None

    i = {"train": 0, "test": 0}  # number of completed runs per mode
    for mode in modes:
        # Seed every RNG so each phase is reproducible.
        seed = seeds[mode][i[mode]]
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        rnd.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        if append_outputs[mode]:
            # Reload previously saved outputs into this phase's outputs dict so
            # that new results extend them.
            # NOTE(review): the values become numpy arrays; confirm that
            # train/test accept that type for their "outputs" entries.
            outputs = kwargs["k" + mode]["outputs"]
            for key in list(outputs):
                outputs[key] = np.load(direxp + "_" + mode + "_" + key + ".npy")
        else:
            # Later runs of this mode append to the saved outputs.
            append_outputs[mode] = True

        if load_state[mode][i[mode]]:
            # Resume the memory and the true state from disk.
            h = torch.load(direxp + "_hidden_state_" + mode + ".pt")
            x = torch.load(direxp + "_x_truth_" + mode + ".pt")
        else:
            # Fresh zero memory; true state drawn around 3 and passed to M
            # together with `burn`.
            h = torch.zeros(batch_sizes[mode], h_dim)
            x = M(3 * torch.ones(batch_sizes[mode], x_dim)
                  + torch.randn(batch_sizes[mode], x_dim), burn)

        if mode == "train":
            if optimizer is None:
                raise ValueError(
                    "optidict['optimizer'] must be 'SGD' or 'Adam', got {!r}"
                    .format(optidict["optimizer"]))
            print("Launch " + mode)
            if optidict.get("load", False) and (i[mode] != 0):
                optimizer.load_state_dict(
                    torch.load(direxp + "_opt_state_dict.pt"))
            train(net=net,
                  hidden_state=h,
                  x_truth=x,
                  optimizer=optimizer,
                  **kwargs["k" + mode])
        elif mode == "test":
            print("Launch " + mode)
            test(net=net, hidden_state=h, x_truth=x, **kwargs["k" + mode])
        i[mode] += 1