예제 #1
0
        assert(old_opt.lr == opt.lr)
        assert(old_opt.decay == opt.decay)
        assert(old_opt.period == opt.period)
        assert(old_opt.t_mult == opt.t_mult)
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        vis.load_state_dict(checkpoint['vis'])
        start_epoch = checkpoint['epoch'] + 1
    elif opt.pretrain is not None:
        checkpoint = torch.load(opt.pretrain)
        old_opt = checkpoint['opt']
        #assert(old_opt.channels == opt.channels)
        #assert(old_opt.bands == opt.bands)
        assert(old_opt.arch == opt.arch)
        assert(old_opt.blend == opt.blend)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        assert(False)

    # Per-epoch cycle: one train() call, then test() on the validation and
    # test loaders.
    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)

        # Record this epoch's pair of scores on the visualiser, then plot.
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()

        # Write a checkpoint whenever (epoch + 1) is a multiple of opt.period.
        if not (epoch + 1) % opt.period:
            # The payload carries the state needed to resume: epoch index,
            # options, and the network / optimizer / scheduler / visualiser
            # state dicts.
            ckpt_state = {
                'epoch': epoch,
                'opt': opt,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'vis': vis.state_dict(),
            }
            # File is named "<epoch>.pth" inside opt.out_path.
            ckpt_file = Path(opt.out_path) / '{}.pth'.format(epoch)
            torch.save(ckpt_state, ckpt_file)
        print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)
예제 #2
0
    elif opt.pretrain is not None:
        checkpoint = torch.load(opt.pretrain)
        old_opt = checkpoint['opt']
        assert (old_opt.channels == opt.channels)
        assert (old_opt.bands == opt.bands)
        assert (old_opt.arch == opt.arch)
        assert (old_opt.blend == opt.blend)
        net.load_state_dict(checkpoint['state_dict'])
    else:
        assert (False)

    # Main loop: each epoch runs train() once and test() on both the
    # validation loader and the test loader.
    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)

        # Append the epoch index and its [val, test] scores, then plot.
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()

        # A snapshot is due when (epoch + 1) is a multiple of opt.period.
        save_due = (epoch + 1) % opt.period == 0
        if save_due:
            # Bundle the epoch index, the options and every state dict into
            # one dictionary before writing it out.
            snapshot = dict([
                ('epoch', epoch),
                ('opt', opt),
                ('state_dict', net.state_dict()),
                ('optimizer', optimizer.state_dict()),
                ('scheduler', scheduler.state_dict()),
                ('vis', vis.state_dict()),
            ])
            # Destination is "<epoch>.pth" under opt.out_path.
            snapshot_file = Path(opt.out_path).joinpath(str(epoch) + '.pth')
            torch.save(snapshot, snapshot_file)
        print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)