예제 #1
0
def train(data_size='all'):
    """Train a SegNet model on the whole dataset or on a single image.

    Args:
        data_size: ``'all'`` trains over the loader returned by
            ``get_dataloader(opt.paths, opt, device)`` via ``batch_step`` and
            saves the weights to ``model_all.pth``. Any other value trains on
            the single image at ``opt.img_path`` via ``one_step`` and saves
            the weights to ``model_single.pth``.

    Relies on the module-level ``opt`` namespace and the project helpers
    ``SegNet``, ``DiscriminativeLoss``, ``SGD``, ``get_dataloader``,
    ``batch_step`` and ``one_step``.
    """
    # ``is_available`` must be *called*: the bare function object is always
    # truthy, which previously selected 'cuda' even on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = SegNet(opt, 3).to(device)
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    criterion_d = DiscriminativeLoss()
    optimizer = SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum)

    if data_size == 'all':
        dataloader = get_dataloader(opt.paths, opt, device)
        model = batch_step(opt, optimizer, model, dataloader, criterion,
                           criterion_d, device)
        torch.save(model.state_dict(), 'model_all.pth')
        return

    # Single-image mode. Pixel values are divided by 255 (presumably 8-bit
    # channels -> [0, 1]), then the H x W x C array is reordered to
    # C x H x W and given a leading batch dimension of 1.
    im = np.array(Image.open(opt.img_path), dtype=np.float32) / 255
    image = np.transpose(im, (2, 0, 1))
    data = torch.from_numpy(image).unsqueeze(0).to(device)

    # SLIC superpixels over the image; for every superpixel label collect the
    # flat (row-major) pixel indices that belong to it.
    labels = segmentation.slic(im,
                               compactness=opt.compactness,
                               n_segments=opt.num_superpixels).reshape(-1)
    label_indices = [np.where(labels == label)[0]
                     for label in np.unique(labels)]

    model = one_step(opt, optimizer, model, data, label_indices, criterion,
                     criterion_d, device)
    torch.save(model.state_dict(), 'model_single.pth')
예제 #2
0
            writer.scalar_summary('train_loss', loss, batches_done)

            # Determine approximate time left for epoch
            epoch_batches_left = len(train_dataloader) - (index + 1)
            time_left = datetime.timedelta(seconds=epoch_batches_left *
                                           (time.time() - t_start) /
                                           (index + 1))
            print('epoch: {}\tbatches: {}\tloss: {:.8f}\tremaining time: {}'.
                  format(epoch, batches_done, loss, time_left))

        writer.scalar_summary('train_loss_epoch', epoch_loss.avg, epoch + 1)
        logger.info('{} epoch loss: {}'.format(epoch + 1, epoch_loss.avg))
        is_better = epoch_loss.avg < prev_loss
        if is_better:
            prev_loss = epoch_loss.avg
            torch.save(model.state_dict(),
                       f"checkpoints/best_ckpt_%d.pth" % epoch)
        else:
            torch.save(model.state_dict(),
                       f"checkpoints/segnet_ckpt_%d.pth" % epoch)

        if epoch % args.eval_interval == 0:
            val_loss, score, class_iou = validate(model=model,
                                                  val_path=val_path,
                                                  img_path=val_img_path,
                                                  mask_path=val_mask_path,
                                                  batch_size=8)
            writer.scalar_summary('val_loss', val_loss, epoch + 1)
            logger.info('epoch {} val loss: {}'.format(epoch + 1, val_loss))
            for k, v in score.items():
                print(k, v)
예제 #3
0
def train_autoencoder(epoch_plus):
    """Train, or resume training, the SegNet-based text autoencoder.

    Args:
        epoch_plus: Number of epochs already completed. When > 0, the weights
            saved as ``autoencoder_{epoch_plus}.pth`` are loaded first.
            Training then runs ``400 - epoch_plus`` more epochs.

    Each epoch draws one random rotation angle that is applied to every
    sample of that epoch. Training minimises summed MSE between
    ``model(text)`` and ``gt - text``. After each epoch it saves a checkpoint
    and logs the loss and sample images to TensorBoard.

    Relies on module-level ``ngpu``, ``use_gpu``, ``device``, ``data_dir``,
    ``batch_size`` and the project helpers ``SegNet``, ``get_dataloader`` and
    ``AverageMeter``.
    """
    writer = SummaryWriter(log_dir='./runs_autoencoder_2')
    save_path = './autoencoder_models_2'
    # Create the checkpoint directory once. exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(save_path, exist_ok=True)

    num_epochs = 400 - epoch_plus
    lr = 0.001
    bta1 = 0.9
    bta2 = 0.999
    weight_decay = 0.001

    model = SegNet(3)
    if ngpu > 1:
        model = nn.DataParallel(model)
    if use_gpu:
        model = model.to(device, non_blocking=True)
    if epoch_plus > 0:
        # The checkpoint was saved from the (possibly DataParallel-wrapped)
        # model built the same way above, so the state-dict keys line up.
        model.load_state_dict(
            torch.load(os.path.join(save_path,
                                    'autoencoder_{}.pth'.format(epoch_plus))))
    criterion = nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(bta1, bta2),
                                 weight_decay=weight_decay)

    for epoch in range(num_epochs):
        step = epoch + 1 + epoch_plus
        # (degree, degree) makes RandomRotation use exactly this angle, so
        # the whole epoch shares a single rotation.
        degree = randint(-180, 180)

        transforms = torchvision.transforms.Compose([
            torchvision.transforms.CenterCrop((172, 200)),
            torchvision.transforms.Resize((172, 200)),
            torchvision.transforms.RandomRotation((degree, degree)),
            torchvision.transforms.ToTensor()
        ])

        dataloader = get_dataloader(data_dir,
                                    train=True,
                                    transform=transforms,
                                    batch_size=batch_size)

        model.train()
        epoch_losses = AverageMeter()
        # Reset per epoch so that an empty loader cannot leave us logging
        # stale tensors from an earlier epoch (or hit UnboundLocalError).
        gt = text = predicted = None

        # NOTE(review): progress total assumes a 1000-sample training set;
        # confirm, or derive it from the dataloader.
        with tqdm(total=(1000 - 1000 % batch_size)) as _tqdm:
            _tqdm.set_description('epoch: {}/{}'.format(
                step, num_epochs + epoch_plus))
            for gt, text in dataloader:
                if use_gpu:
                    gt = gt.to(device, non_blocking=True)
                    text = text.to(device, non_blocking=True)

                predicted = model(text)

                # Target is the residual gt - text: predicts extracted text
                # in white, all others in black.
                loss = criterion(predicted, gt - text)
                epoch_losses.update(loss.item(), len(gt))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                _tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                _tqdm.update(len(gt))

        torch.save(
            model.state_dict(),
            os.path.join(save_path, 'autoencoder_{}.pth'.format(step)))
        writer.add_scalar('Loss', epoch_losses.avg, step)

        if predicted is None:
            # The loader yielded no batches this epoch: no images to log.
            continue

        # Visualisation only: keep these tensors out of the autograd graph.
        with torch.no_grad():
            predicted = predicted.detach()
            gt_text = gt - text
            predicted_mask = text + predicted

        writer.add_image('text/text_image_{}'.format(step),
                         text[0].squeeze(), step)
        writer.add_image('gt/gt_image_{}'.format(step),
                         gt[0].squeeze(), step)
        writer.add_image('gt_text/gt_image_{}'.format(step),
                         gt_text[0].squeeze(), step)
        writer.add_image('predicted/predicted_image_{}'.format(step),
                         predicted_mask[0].squeeze(), step)
        writer.add_image('predicted_text/predicted_image_{}'.format(step),
                         predicted[0].squeeze(), step)

    writer.close()
예제 #4
0

if __name__ == "__main__":
    # Script entry point: train SegNet for 30 epochs with Adam. Test scores
    # are printed before training and after every epoch, and a checkpoint is
    # written to models/epoch-<n>.pth each epoch.
    import os

    model = SegNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.01,
                                 betas=(0.9, 0.999))

    train_dataset, test_dataset = load_dataset('dataset/TrainingData')

    # Baseline scores before any training.
    scores_avg = evaluate(test_dataset)
    print(scores_avg)

    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)

    for epoch in range(30):
        # evaluate() is defined elsewhere and may switch the model to eval
        # mode. Re-enable training mode every epoch; this matters for any
        # BatchNorm/Dropout layers.
        model.train()
        epoch_loss = 0
        for x, y in train_dataset.batch(16):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = loss_f(out, y)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        # NOTE(review): this sums one loss value per batch but divides by
        # len(train_dataset). If that is a sample count and loss_f returns a
        # per-batch mean, the printed value is scaled down; confirm.
        print(epoch, epoch_loss / len(train_dataset))
        scores_avg = evaluate(test_dataset)
        print(epoch, scores_avg)
        print()
        model_path = f'models/epoch-{epoch}.pth'
        torch.save(model.state_dict(), model_path)