Example #1
    print(
        '| End of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid perplexity {:8.2f}'
        .format(epoch, epoch_duration, val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # delete previous checkpoint
        if model_save_path is not None:
            os.remove(model_save_path)
        # save current state
        model_save_path = os.path.join(exp_dir, f'model_{epoch}.pt')
        torch.save(
            {
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'epoch': epoch,
                'step': steps
            }, model_save_path)
        epochs_wo_improvement = 0
    else:
        epochs_wo_improvement += 1

    stats = {
        'train_losses': train_losses,
        'valid_losses': val_losses,
        'mean_epoch_duration': np.mean(durations)
    }
    with open(os.path.join(exp_dir, 'stats.pkl'), 'wb') as f:
        pickle.dump(stats, f)
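    # Not shown above: `epochs_wo_improvement` is usually compared against a
    # patience limit to stop training early. A minimal sketch (the `patience`
    # value below is an assumption, not part of the original loop):
    #
    #     patience = 5
    #     if epochs_wo_improvement >= patience:
    #         print(f'No improvement for {patience} epochs, stopping early.')
    #         break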
Example #2
def main():

    global args
    best_prec1, best_epoch = 0.0, 0

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224

    model = getattr(models, args.arch)(args)
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)    
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model

    model = getattr(models, args.arch)(args)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'radam':
        from radam import RAdam
        optimizer = RAdam(model.parameters(), args.lr,
                          weight_decay=args.weight_decay)
    else:
        raise NotImplementedError(f"Unknown optimizer: {args.optimizer}")

    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)

        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
              '\tval_prec1\ttrain_prec5\tval_prec5']

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_prec1, train_prec5, lr = train(train_loader, model, criterion, optimizer, epoch)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_prec1, val_prec1, train_prec5, val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model

    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)

    return 
Example #3
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif opt.Prediction == 'None':
        converter = TransformerConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # e.g. kaiming init fails on 1-D batchnorm weights
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    # model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    if opt.load_from_checkpoint:
        model.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'checkpoint.pth')))
        print(f'loaded checkpoint from {opt.load_from_checkpoint}...')
    elif opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.SequenceModeling == 'Transformer':
            fe_state = OrderedDict()
            state_dict = torch.load(opt.saved_model)
            for k, v in state_dict.items():
                if k.startswith('module.FeatureExtraction'):
                    new_k = k.replace('module.FeatureExtraction.', '', 1)
                    fe_state[new_k] = v
            model.FeatureExtraction.load_state_dict(fe_state)
        else:
            if opt.FT:
                model.load_state_dict(torch.load(opt.saved_model), strict=False)
            else:
                model.load_state_dict(torch.load(opt.saved_model))
    if opt.freeze_fe:
        model.freeze(['FeatureExtraction'])
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif opt.Prediction == 'None':
        criterion = LabelSmoothingLoss(classes=converter.n_classes, padding_idx=converter.pad_idx, smoothing=0.1)
        # criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.pad_idx)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # collect only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        assert opt.adam in ['Adam', 'AdamW', 'RAdam'], 'adam optimizer must be in Adam, AdamW or RAdam'
        if opt.adam == 'Adam':
            optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        elif opt.adam == "AdamW":
            optimizer = optim.AdamW(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            optimizer = RAdam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    if opt.load_from_checkpoint and opt.load_optimizer_state:
        optimizer.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'optimizer.pth')))
        print(f'loaded optimizer state from {os.path.join(opt.load_from_checkpoint, "optimizer.pth")}')

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            pass

    if opt.load_from_checkpoint:
        with open(os.path.join(opt.load_from_checkpoint, 'iter.json'), mode='r', encoding='utf8') as f:
            start_iter = json.load(f)
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # i = start_iter

    bar = tqdm(range(start_iter, opt.num_iter))
    # while(True):
    for i in bar:
        bar.set_description(f'Iter {i}: train_loss = {loss_avg.val():.5f}')
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
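            # nn.CTCLoss expects log-probs of shape (T, N, C), so move the
            # time dimension first before computing the loss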
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)

        elif opt.Prediction == 'None':
            tgt_input = text['tgt_input']
            tgt_output = text['tgt_output']
            tgt_padding_mask = text['tgt_padding_mask']
            preds = model(image, tgt_input.transpose(0, 1), tgt_key_padding_mask=tgt_padding_mask,)
            cost = criterion(preds.view(-1, preds.shape[-1]), tgt_output.contiguous().view(-1))
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')

                # checkpoint
                os.makedirs(f'./checkpoints/{opt.experiment_name}/', exist_ok=True)

                torch.save(model.state_dict(), f'./checkpoints/{opt.experiment_name}/checkpoint.pth')
                torch.save(optimizer.state_dict(), f'./checkpoints/{opt.experiment_name}/optimizer.pth')
                with open(f'./checkpoints/{opt.experiment_name}/iter.json', mode='w', encoding='utf8') as f:
                    json.dump(i + 1, f)

                with open(f'./checkpoints/{opt.experiment_name}/checkpoint.log', mode='a', encoding='utf8') as f:
                    f.write(f'Saved checkpoint with iter={i}\n')
                    f.write(f'\tCheckpoint at: ./checkpoints/{opt.experiment_name}/checkpoint.pth\n')
                    f.write(f'\tOptimizer at: ./checkpoints/{opt.experiment_name}/optimizer.pth\n')

                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save the model every 100k iterations
        if (i + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # if i == opt.num_iter:
        #     print('end the training')
        #     sys.exit()
        # i += 1
        # if i == 1: break
    print('end training')
Example #4
class Trainer(object):
    '''This class takes care of training and validation of our model'''
    def __init__(self, model):
        self.fold = args.fold
        self.total_folds = 5
        self.num_workers = 6
        self.batch_size = {
            "train": args.batch_size,
            "val": args.batch_size
        }  # 4
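        # accumulate gradients so the effective batch size stays at 32
        # regardless of the per-step batch size (see iterate() below)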
        self.accumulation_steps = 32 // self.batch_size['train']
        self.lr = args.learning_rate
        self.num_epochs = args.epochs
        self.best_loss = float("inf")
        self.best_dice = 0
        self.phases = ["train", "val"]
        self.device = torch.device("cuda:0")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
        self.net = model
        self.criterion = MixedLoss(10.0, 2.0)

        if args.swa:
            # base_opt = torch.optim.SGD(self.net.parameters(), lr=args.max_lr, momentum=args.momentum, weight_decay=args.weight_decay)
            base_opt = RAdam(self.net.parameters(), lr=self.lr)
            self.optimizer = SWA(base_opt,
                                 swa_start=38,
                                 swa_freq=1,
                                 swa_lr=args.min_lr)
            # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, scheduler_step, args.min_lr)
        else:
            if args.optimizer.lower() == 'adam':
                self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
            elif args.optimizer.lower() == 'radam':
                self.optimizer = RAdam(
                    self.net.parameters(), lr=self.lr
                )  # betas=(args.beta1, args.beta2),weight_decay=args.weight_decay
            elif args.optimizer.lower() == 'sgd':
                self.optimizer = torch.optim.SGD(
                    self.net.parameters(),
                    lr=args.max_lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)

        if args.scheduler.lower() == 'reducelronplateau':
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               mode="min",
                                               patience=args.patience,
                                               verbose=True)
        elif args.scheduler.lower() == 'clr':
            self.scheduler = CyclicLR(self.optimizer,
                                      base_lr=self.lr,
                                      max_lr=args.max_lr)
        self.net = self.net.to(self.device)
        cudnn.benchmark = True
        self.dataloaders = {
            phase: provider(
                fold=args.fold,
                total_folds=5,
                data_folder=data_folder,
                df_path=train_rle_path,
                phase=phase,
                size=args.img_size_target,
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                batch_size=self.batch_size[phase],
                num_workers=self.num_workers,
            )
            for phase in self.phases
        }
        self.losses = {phase: [] for phase in self.phases}
        self.iou_scores = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.kaggle_metric = {phase: [] for phase in self.phases}

    def forward(self, images, targets):
        images = images.to(self.device)
        masks = targets.to(self.device)
        outputs = self.net(images)
        loss = self.criterion(
            outputs, masks
        )  # weighted_lovasz  # lovasz_hinge(outputs, masks) # self.criterion(outputs, masks)
        return loss, outputs

    def iterate(self, epoch, phase):
        meter = Meter(phase, epoch)
        start = time.strftime("%H:%M:%S")
        print(f"Starting epoch: {epoch} | phase: {phase} | ⏰: {start}")
        batch_size = self.batch_size[phase]
        start = time.time()
        self.net.train(phase == "train")
        dataloader = self.dataloaders[phase]
        running_loss = 0.0
        total_batches = len(dataloader)
        tk0 = tqdm(dataloader, total=total_batches)
        self.optimizer.zero_grad()
        for itr, batch in enumerate(tk0):
            images, targets = batch
            loss, outputs = self.forward(images, targets)
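            # scale the loss so gradients accumulated over `accumulation_steps`
            # mini-batches match a single full-batch update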
            loss = loss / self.accumulation_steps
            if phase == "train":
                loss.backward()
                if (itr + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            outputs = outputs.detach().cpu()
            meter.update(targets, outputs)
            tk0.set_postfix(loss=(running_loss / ((itr + 1))))
        if args.swa:
            self.optimizer.swap_swa_sgd()
        epoch_loss = (running_loss * self.accumulation_steps
                      ) / total_batches  # running_loss / total_batches
        dice, iou, scores, kaggle_metric = epoch_log(phase, epoch, epoch_loss,
                                                     meter,
                                                     start)  # kaggle_metric
        write_event(log, dice, loss=epoch_loss)
        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(dice)
        self.iou_scores[phase].append(iou)
        self.kaggle_metric[phase].append(kaggle_metric)
        torch.cuda.empty_cache()
        return epoch_loss, dice, iou, scores, kaggle_metric  # kaggle_metric

    def start(self):
        # start each run with a fresh TensorBoard log directory
        log_dir = args.log_path + 'v' + str(args.version) + '/' + str(args.fold)
        if os.path.exists(log_dir):
            shutil.rmtree(log_dir)
        os.makedirs(log_dir)
        writer = SummaryWriter(log_dir)

        num_snapshot = 0
        best_acc = 0
        model_path = args.weights_path + 'v' + str(
            args.version) + '/' + save_model_name
        if os.path.exists(model_path):
            state = torch.load(model_path,
                               map_location=lambda storage, loc: storage)
            model.load_state_dict(state["state_dict"])  # ["state_dict"]
            epoch = state['epoch']
            self.best_loss = state['best_loss']
            self.best_dice = state['best_dice']
            state['state_dict'] = state['state_dict']
            state['optimizer'] = state['optimizer']
        else:
            epoch = 1
            self.best_loss = float('inf')
            self.best_dice = 0

        for epoch in range(epoch, self.num_epochs + 1):
            print('-' * 30, 'Epoch:', epoch, '-' * 30)
            train_loss, train_dice, train_iou, train_scores, train_kaggle_metric = self.iterate(
                epoch, "train")  # train_kaggle_metric
            state = {
                "epoch": epoch,
                "best_loss": self.best_loss,
                "best_dice": self.best_dice,
                "state_dict": self.net.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            try:
                val_loss, val_dice, val_iou, val_scores, val_kaggle_metric = self.iterate(
                    epoch, "val")  # val_kaggle_metric
                self.scheduler.step(val_loss)
                if val_loss < self.best_loss:
                    print("******** New optimal found, saving state ********")
                    state["best_loss"] = self.best_loss = val_loss
                    torch.save(state, model_path)
                    try:
                        scores = val_scores
                    except:
                        scores = 'None'
                if val_dice > self.best_dice:
                    print("******** Best Dice Score, saving state ********")
                    state["best_dice"] = self.best_dice = val_dice
                    best_dice__path = args.weights_path + 'v' + str(
                        args.version
                    ) + '/' + 'best_dice_' + basic_name + '.pth'
                    torch.save(state, best_dice__path)
                # if val_dice > best_acc:
                #     print("******** New optimal found, saving state ********")
                #     # state["best_acc"] = self.best_acc = val_dice
                #     best_acc = val_dice
                #     best_param = self.net.state_dict()

                # if (epoch + 1) % scheduler_step == 0:
                #     # torch.save(best_param, args.save_weight + args.weight_name + str(idx) + str(num_snapshot) + '.pth')
                #     save_model_name = basic_name + '.pth' # '_' +str(num_snapshot)
                #     torch.save(best_param, args.weights_path + 'v' + str(args.version) + '/' + save_model_name)
                #     # state
                #     try:
                #         scores = val_scores
                #     except:
                #         scores = 'None'
                #     optimizer = torch.optim.SGD(self.net.parameters(), lr=args.max_lr, momentum=args.momentum,
                #                                 weight_decay=args.weight_decay)
                #     self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, args.min_lr)
                #     num_snapshot += 1
                #     best_acc = 0
                writer.add_scalars('loss', {
                    'train': train_loss,
                    'val': val_loss
                }, epoch)
                writer.add_scalars('dice_score', {
                    'train': train_dice,
                    'val': val_dice
                }, epoch)
                writer.add_scalars('IoU', {
                    'train': train_iou,
                    'val': val_iou
                }, epoch)
                writer.add_scalars('New_Dice', {
                    'train': train_kaggle_metric,
                    'val': val_kaggle_metric
                }, epoch)
            except KeyboardInterrupt:
                print('Ctrl+C, saving snapshot')
                torch.save(
                    state, args.weights_path + 'v' + str(args.version) + '/' +
                    save_model_name)
                print('done.')
            # writer.add_scalars('Accuracy', {'train': train_kaggle_metric, 'val': val_kaggle_metric}, epoch)

        # writer.export_scalars_to_json(args.log_path + 'v' + str(args.version) + '/' + basic_name + '.json')
        writer.close()
        return scores
Example #5
    # save checkpoint
    if ((epoch + 1) % check_point == 0 or epoch == (num_epoch - 1)
            or epoch + 1 > 90 or bleu_score > 4):
        model_check_point = '%s/model_trainable_%d.pk' % (save_folder,
                                                          epoch + 1)
        optim_check_point = '%s/optim_trainable_%d.pkl' % (save_folder,
                                                           epoch + 1)
        loss_check_point = '%s/loss_trainable_%d.pkl' % (save_folder,
                                                         epoch + 1)
        epoch_check_point = '%s/epoch_trainable_%d.pkl' % (save_folder,
                                                           epoch + 1)
        bleu_check_point = '%s/bleu_trainable_%d.pkl' % (save_folder,
                                                         epoch + 1)
        torch.save(model.state_dict(), model_check_point)
        torch.save(optimizer.state_dict(), optim_check_point)
        torch.save(loss_values, loss_check_point)
        torch.save(epoch_values, epoch_check_point)
        torch.save(bleu_values, bleu_check_point)

    # save current best result
    if bleu_score > best_bleu:
        best_bleu = bleu_score
        print('current best bleu: %.4f' % best_bleu)
        model_check_point = '%s/model_best_%d.pk' % (save_folder, epoch + 1)
        optim_check_point = '%s/optim_best_%d.pkl' % (save_folder, epoch + 1)
        loss_check_point = '%s/loss_best_%d.pkl' % (save_folder, epoch + 1)
        epoch_check_point = '%s/epoch_best_%d.pkl' % (save_folder, epoch + 1)
        bleu_check_point = '%s/bleu_best_%d.pkl' % (save_folder, epoch + 1)
        torch.save(model.state_dict(), model_check_point)
        torch.save(optimizer.state_dict(), optim_check_point)
        torch.save(loss_values, loss_check_point)
        torch.save(epoch_values, epoch_check_point)
        torch.save(bleu_values, bleu_check_point)
Example #6
class Trainer:
    def __init__(self,
                 model,
                 train_loader,
                 test_loader,
                 epochs=200,
                 batch_size=60,
                 run_id=0,
                 logs_dir='logs',
                 device='cpu',
                 saturation_device=None,
                 optimizer='None',
                 plot=True,
                 compute_top_k=False,
                 data_parallel=False,
                 conv_method='channelwise',
                 thresh=.99,
                 half_precision=False,
                 downsampling=None):
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k

        if 'cuda' in device:
            cudnn.benchmark = True

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()
        print('Selecting optimizer: {}'.format(optimizer))
        #optimizer = str(optimizer)
        if optimizer == "adam":
            print('Using adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == "adam_lr":
            print("Using adam with higher learning rate")
            self.optimizer = optim.Adam(model.parameters(), lr=0.01)
        elif optimizer == 'adam_lr2':
            print('Using adam with a lower learning rate (1e-4)')
            self.optimizer = optim.Adam(model.parameters(), lr=0.0001)
        elif optimizer == "SGD":
            print('Using SGD')
            self.optimizer = optim.SGD(model.parameters(),
                                       momentum=0.9,
                                       weight_decay=5e-4)
        elif optimizer == "LRS":
            print('Using LRS')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.1,
                                       momentum=0.9,
                                       weight_decay=5e-4)
            self.lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, self.epochs // 3)
        elif optimizer == "radam":
            print('Using radam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer
        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_dspl{downsampling}_t{int(thresh*1000)}_id{run_id}.csv'
        )
        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))

            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    'Logs for an identical experiment with the same run_id were '
                    'found; training will be skipped. Consider using another run_id.'
                )
        if os.path.exists((self.savepath.replace('.csv', '.pt'))):
            self.model.load_state_dict(
                torch.load(self.savepath.replace('.csv',
                                                 '.pt'))['model_state_dict'])
            if data_parallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
            if half_precision:
                self.model = self.model.half()
            self.optimizer.load_state_dict(
                torch.load(self.savepath.replace('.csv', '.pt'))['optimizer'])
            self.start_epoch = torch.load(self.savepath.replace(
                '.csv', '.pt'))['epoch'] + 1
            initial_epoch = self._infer_initial_epoch(self.savepath)
            print('Resuming existing run, starting at epoch', self.start_epoch,
                  'from', self.savepath.replace('.csv', '.pt'))
        else:
            if half_precision:
                self.model = self.model.half()
            self.start_epoch = 0
            initial_epoch = 0
            self.parallel = data_parallel
            if data_parallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_accuracy')
        writer2 = NPYWriter(self.savepath.replace('.csv', ''))
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        self.half = half_precision

        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''), [writer],
                                   model,
                                   ignore_layer_names='convolution',
                                   stats=['lsat', 'idim'],
                                   sat_threshold=thresh,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None,
                                   initial_epoch=initial_epoch,
                                   interpolation_strategy='nearest'
                                   if downsampling is not None else None,
                                   interpolation_downsampling=4)

    def _infer_initial_epoch(self, savepath):
        if not os.path.exists(savepath):
            return 0
        else:
            df = pd.read_csv(savepath, sep=';', index_col=0)
            print(len(df) + 1)
            return len(df)

    def train(self):
        if self.experiment_done:
            return
        for epoch in range(self.start_epoch, self.epochs):
            #self.test(epoch=epoch)

            print('Start training epoch', epoch)
            print(
                "{} Epoch {}, training loss: {}, training accuracy: {}".format(
                    now(), epoch, *self.train_epoch()))
            self.test(epoch=epoch)
            if self.opt_name == "LRS":
                print('LRS step')
                self.lr_scheduler.step()
            self.stats.add_saturations()
            #self.stats.save()
            #if self.plot:
            #    plot_saturation_level_from_results(self.savepath, epoch)
        self.stats.close()
        return self.savepath

    def train_epoch(self):
        self.model.train()
        correct = 0
        total = 0
        running_loss = 0
        old_time = time()
        top5_accumulator = 0
        for batch, data in enumerate(self.train_loader):
            if batch % 10 == 0 and batch != 0:
                print(
                    batch, 'of', len(self.train_loader), 'processing time',
                    time() - old_time,
                    "top5_acc:" if self.compute_top_k else 'acc:',
                    round(top5_accumulator /
                          (batch), 3) if self.compute_top_k else correct /
                    total)
                old_time = time()
            inputs, labels = data
            if self.half:
                inputs, labels = inputs.to(self.device).half(), labels.to(
                    self.device)
            else:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            if self.compute_top_k:
                top5_accumulator += accuracy(outputs, labels, (5, ))[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)

            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            correct += (predicted == labels.long()).sum().item()

            running_loss += loss.item()
        self.stats.add_scalar('training_loss', running_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('training_accuracy',
                                  (top5_accumulator / (batch + 1)))
        else:
            self.stats.add_scalar('training_accuracy', correct / total)
        return running_loss / total, correct / total

    def test(self, epoch, save=True):
        self.model.eval()
        correct = 0
        total = 0
        test_loss = 0
        top5_accumulator = 0
        with torch.no_grad():
            for batch, data in enumerate(self.test_loader):
                if batch % 10 == 0:
                    print('Processing eval batch', batch, 'of',
                          len(self.test_loader))
                inputs, labels = data
                if self.half:
                    inputs, labels = inputs.to(self.device).half(), labels.to(
                        self.device)
                else:
                    inputs, labels = inputs.to(self.device), labels.to(
                        self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.long()).sum().item()
                if self.compute_top_k:
                    top5_accumulator += accuracy(outputs, labels, (5, ))[0]
                test_loss += loss.item()

        self.stats.add_scalar('test_loss', test_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('test_accuracy',
                                  top5_accumulator / (batch + 1))
            print('{} Test Top5-Accuracy on {} images: {:.4f}'.format(
                now(), total, top5_accumulator / (batch + 1)))

        else:
            self.stats.add_scalar('test_accuracy', correct / total)
            print('{} Test Accuracy on {} images: {:.4f}'.format(
                now(), total, correct / total))
        if save:
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'epoch': epoch,
                    'test_loss': test_loss / total
                }, self.savepath.replace('.csv', '.pt'))
        return correct / total, test_loss / total