Example #1
import os

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange


def find_bounds_clr(model, loader, optimizer, criterion, device, dtype, min_lr=8e-6, max_lr=8e-5, step_size=2000,
                    mode='triangular', save_path='.'):
    """LR range test: sweep the learning rate from min_lr to max_lr with a cyclic
    schedule, recording per-batch accuracy to help choose CLR bounds."""
    model.train()
    # CyclicLR here is a third-party scheduler that exposes batch_step();
    # it is not torch.optim.lr_scheduler.CyclicLR.
    scheduler = CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr, step_size=step_size, mode=mode)
    epoch_count = step_size // len(loader)  # assumes step_size is a multiple of the batches per epoch
    accuracy = []
    for _ in trange(epoch_count):
        for batch_idx, (data, target) in enumerate(tqdm(loader)):
            if scheduler is not None:
                scheduler.batch_step()  # advance the learning rate once per batch
            data, target = data.to(device=device, dtype=dtype), target.to(device=device)

            optimizer.zero_grad()
            output = model(data)

            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            corr = correct(output, target)  # project helper returning top-k correct counts
            accuracy.append(corr[0] / data.shape[0])

    lrs = np.linspace(min_lr, max_lr, len(accuracy))  # one learning-rate value per recorded batch
    plt.plot(lrs, accuracy)
    plt.savefig(os.path.join(save_path, 'find_bounds_clr.png'))  # save before show(), which clears the figure
    plt.show()
    np.save(os.path.join(save_path, 'acc.npy'), accuracy)
    return
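
A minimal sketch of how this range test might be driven, assuming a classification model and a DataLoader already exist; the names `model` and `train_loader` below are hypothetical and only illustrate the call:

import torch
import torch.nn as nn

# Hypothetical setup: `model` and `train_loader` are assumed to be defined elsewhere.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=8e-6, momentum=0.9)

# Sweep the learning rate, then inspect find_bounds_clr.png to pick base_lr/max_lr.
find_bounds_clr(model, train_loader, optimizer, criterion, device, dtype,
                min_lr=1e-6, max_lr=1e-2, step_size=2 * len(train_loader), save_path='.')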
train_time = 0
best_rank1 = -np.inf
best_epoch = 0

if evaluation:
    print("Evaluation in Progress")
    test(model, queryloader, galleryloader)
    sys.exit(0)

print("Training of model in progress")

for epoch in range(start_epoch, num_epochs):
    start_train_time = time.time()
    train(epoch, model, optim, trainloader)

    scheduler.batch_step()

    # Evaluate at selected epochs and track the best rank-1 accuracy
    if epoch in (0, 100, 180, 250, 350, 500, 650, 750, 850, 950, 1100, 1200):
        print("Testing of model in progress")
        rank1 = test(model, queryloader, galleryloader)
        best = rank1 > best_rank1
        if best:
            best_rank1 = rank1
            best_epoch = epoch + 1

        state_dict = model.state_dict()
        save_checkpoint(
            {
                'state_dict': state_dict,
Example #3
class Solver(object):
    def __init__(self, config):
        self.n_classes = config['n_classes']
        self.model = UNET(1, self.n_classes)
        if self.n_classes > 1:
            self.criterion = nn.CrossEntropyLoss()
        else:
            self.criterion = nn.BCEWithLogitsLoss()

        self.lr = config['lr']
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.lr,
                                    weight_decay=1e-3,
                                    betas=(0.8, 0.9))
        self.device = config['device']
        self.num_epochs = config['num_epochs']
        if config['N_subimgs'] % config['batch_size'] != 0:
            self.train_step = config['N_subimgs'] // config['batch_size'] + 1
        else:
            self.train_step = config['N_subimgs'] // config['batch_size']

        self.model_save_dir = config['save_pth']
        self.best_loss = 10
        # step_size of roughly 2x the batches per epoch, as commonly recommended for CLR
        self.scheduler = CyclicLR(self.optimizer,
                                  step_size=2 *
                                  (config['N_subimgs'] // config['batch_size']),
                                  mode='triangular2')

        if self.device is not None:
            self.device = torch.device(self.device)
            self.model.to(self.device)

    def restore_best(self):
        model_pth = os.path.join(self.model_save_dir,
                                 'BEST_checkpoint.pth.tar')
        checkpoint = torch.load(model_pth)
        self.model.load_state_dict(checkpoint['model'])  # actually restore the saved weights
        best_loss = checkpoint['loss']
        epoch = checkpoint['epoch']
        return epoch + 1, best_loss

    def restore_model(self):
        model_pth = os.path.join(self.model_save_dir, 'checkpoint.pth.tar')
        checkpoint = torch.load(model_pth)
        self.model.load_state_dict(checkpoint['model'])  # actually restore the saved weights
        epoch = checkpoint['epoch']
        return epoch + 1

    def save_checkpoint(self, state, path):
        torch.save(state, os.path.join(path, 'BEST_checkpoint.pth.tar'))

    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def train(self, prefetcher, resume=True, best=True):
        if best and resume:
            start_epoch, best_loss = self.restore_best()
            self.best_loss = best_loss  # stored as a plain float in the checkpoint
            print('Resuming from epoch %d; best loss so far: %.6f'
                  % (start_epoch, best_loss))
        elif resume:
            start_epoch = self.restore_model()
            print('Resuming from epoch %d' % start_epoch)
        else:
            start_epoch = 0
        # not really an epoch; consider calling this a step instead
        for i in range(start_epoch, self.num_epochs):
            epoch_loss = 0
            self.model.train()
            self.scheduler.batch_step()
            for j in range(self.train_step):
                self.optimizer.zero_grad()
                img, label = prefetcher.next()
                # Variable is deprecated since PyTorch 0.4; tensors can be used directly
                img = img.to(self.device, dtype=torch.float32)
                label = label.to(self.device, dtype=torch.float32)
                output = self.model(img)
                loss = self.criterion(output, label)
                epoch_loss += loss.item()  # accumulate a Python float so the graph is not retained across batches
                loss.backward()
                self.optimizer.step()

                if loss.item() < self.best_loss:
                    state = {}
                    state['loss'] = loss.item()
                    state['model'] = self.model.state_dict()
                    state['epoch'] = i
                    print('loss decreased, saving model...')
                    self.save_checkpoint(state, self.model_save_dir)
                    self.best_loss = loss.item()

            aver_loss = epoch_loss / self.train_step
            print('epoch %d: average training loss %.6f' % (i, aver_loss))
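
A minimal usage sketch for this Solver, assuming a config dict with the keys read in __init__ and a prefetcher object whose next() yields (img, label) batches; the concrete values below are hypothetical:

# Hypothetical configuration; only the keys read by Solver.__init__ are set.
config = {
    'n_classes': 1,          # single class -> BCEWithLogitsLoss branch
    'lr': 1e-3,
    'device': 'cuda',
    'num_epochs': 50,
    'N_subimgs': 190000,
    'batch_size': 32,
    'save_pth': './checkpoints',
}

solver = Solver(config)
# `prefetcher` is assumed to expose next() returning an (img, label) batch.
solver.train(prefetcher, resume=False, best=False)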