Example #1
0
    def __init__(self, hparams, teacher_path=''):
        """Build the student network, an optional teacher, and helper modules.

        Args:
            hparams: hyper-parameters as an ``argparse.Namespace``-like object
                or a plain ``dict``. A dict is converted to a Namespace so the
                attribute accesses below (``heatmap_radius``, ``hack``,
                ``temperature``) work either way.
            teacher_path: checkpoint of a ``MapModel`` to use as teacher. When
                empty (the default) no ``self.teacher`` attribute is created.
        """
        super().__init__()

        # Accept a plain dict as well as a Namespace: every field below is
        # read with attribute access.
        if isinstance(hparams, dict):
            import argparse
            hparams = argparse.Namespace(**hparams)

        self.hparams = hparams
        self.to_heatmap = ToHeatmap(hparams.heatmap_radius)

        if teacher_path:
            # str() so path objects such as pathlib.Path are accepted.
            # NOTE(review): presumably load_from_checkpoint in the Lightning
            # version used here requires a str path — confirm.
            self.teacher = MapModel.load_from_checkpoint(str(teacher_path))
            # Lightning's freeze() disables gradients for the teacher.
            self.teacher.freeze()

        self.net = SegmentationModel(10,
                                     4,
                                     hack=hparams.hack,
                                     temperature=hparams.temperature)
        self.converter = Converter()
        self.controller = RawController(4)
Example #2
0
    def __init__(self, hparams, teacher_root='/home/bradyzhou/code/carla_random/'):
        """Build the student network, load the teacher model and a Converter.

        Args:
            hparams: hyper-parameter namespace; must provide ``teacher_path``,
                which is resolved relative to ``teacher_root`` (an absolute
                ``teacher_path`` is used as-is, per ``pathlib`` joining rules).
            teacher_root: directory that relative teacher checkpoint paths are
                resolved against. Defaults to the previously hard-coded
                location, so existing callers behave exactly as before.
        """
        super().__init__()

        self.hparams = hparams
        self.net = SegmentationModel(4, 4)

        teacher_checkpoint = pathlib.Path(teacher_root) / hparams.teacher_path
        # Pass a str, not a Path. NOTE(review): some Lightning versions
        # reject pathlib.Path in load_from_checkpoint — confirm for the
        # pinned version.
        self.teacher = MapModel.load_from_checkpoint(str(teacher_checkpoint))
        # NOTE(review): the teacher is neither frozen nor switched to eval
        # mode here (an eval() call was previously commented out); confirm
        # this is intended.

        self.converter = Converter()
Example #3
0
def run(args):
    """Draw MapModel predictions next to the labels for the test split.

    Loads the weights in ``args.checkpoint_file`` into a fresh
    ``MapModel(args)`` and iterates over ``<args.h5_path>/test``. For every
    sample, it draws the label to ``view-image-label.png`` and the prediction
    to ``view-image-pred.png``. Both files are overwritten for each sample,
    so this is presumably meant to be stepped through (e.g. under a
    debugger) — NOTE(review): confirm.
    """
    model = MapModel(args).to(device)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded onto whatever ``device`` is in use now.
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_set = MapDataset(os.path.join(args.h5_path, 'test'))
    test_data_loader = DataLoader(dataset=test_set,
                                  num_workers=2,
                                  batch_size=10,
                                  shuffle=True)

    # Inference only: do not build autograd graphs.
    with torch.no_grad():
        for screens, distances, objects in test_data_loader:
            screens = screens.to(device)
            distances = distances.to(device)
            objects = objects.to(device)

            pred_objects, pred_distances = model(screens)
            # Keep only the arg-max index along dim 1 of each prediction.
            _, pred_objects = pred_objects.max(1)
            _, pred_distances = pred_distances.max(1)

            # Iterating a tensor walks its first (batch) dimension.
            for dist, obj, pred_dist, pred_obj in zip(
                    distances, objects, pred_distances, pred_objects):
                draw(dist, obj, 'view-image-label.png')
                draw(pred_dist, pred_obj, 'view-image-pred.png')
                # NOTE(review): looks like a debug/breakpoint marker.
                print(1)
Example #4
0
def _make_loader(h5_path, split, batch_size, shuffle, num_workers=4):
    """Return a DataLoader over the MapDataset stored in ``<h5_path>/<split>``."""
    return DataLoader(dataset=MapDataset(os.path.join(h5_path, split)),
                      num_workers=num_workers,
                      batch_size=batch_size,
                      shuffle=shuffle)


def _save_checkpoint(model, optimizer, checkpoint_file):
    """Save model weights to ``checkpoint_file`` and optimizer state beside it."""
    torch.save(model.state_dict(), checkpoint_file)
    torch.save(optimizer.state_dict(), checkpoint_file + '_optimizer.pth')


def train(args):
    """Train a MapModel on the object and distance targets of a MapDataset.

    Reads from ``args``:
        h5_path: root containing ``train``, ``test`` and ``val`` splits.
        batch_size: training batch size (validation/test use 10).
        epoch_num: number of epochs.
        load: optional checkpoint to resume from; its optimizer state is read
            from ``<load>_optimizer.pth``.
        checkpoint_file: where checkpoints are written (every 1000 batches
            and every ``checkpoint_rate`` epochs).
        checkpoint_rate: epoch interval between checkpoints.

    Prints running and per-epoch statistics, logs per-epoch scalars through
    the ``train_writer``/``val_writer`` objects defined outside this
    function, and finally prints loss/accuracy on the test split.
    """
    train_data_loader = _make_loader(args.h5_path, 'train', args.batch_size,
                                     shuffle=True)
    test_data_loader = _make_loader(args.h5_path, 'test', 10, shuffle=False)
    validation_data_loader = _make_loader(args.h5_path, 'val', 10,
                                          shuffle=False)

    model = MapModel(args).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    if args.load is not None and os.path.isfile(args.load):
        print("loading model parameters {}".format(args.load))
        # map_location: a checkpoint saved on one device (e.g. a GPU) can be
        # resumed on whatever ``device`` is in use now.
        model.load_state_dict(torch.load(args.load, map_location=device))
        optimizer.load_state_dict(
            torch.load(args.load + '_optimizer.pth', map_location=device))

    batches_per_print = 10  # print running statistics every N mini-batches

    for epoch in range(args.epoch_num):
        # Enter training mode at the start of every epoch. The validation
        # call at the end of the previous epoch (``test``, defined elsewhere)
        # may have switched the model to eval mode; calling train() again is
        # harmless if it did not.
        model.train()

        epoch_loss_obj = epoch_loss_dist = 0.0
        epoch_accuracy_obj = epoch_accuracy_dist = 0.0
        running_loss_obj = running_loss_dist = 0.0
        running_accuracy_obj = running_accuracy_dist = 0.0
        batch_time = time.time()
        # Pre-set so an empty loader still yields batch_num == 1 below.
        batch = 0
        for batch, (screens, distances,
                    objects) in enumerate(train_data_loader):
            screens = screens.to(device)
            distances = distances.to(device)
            objects = objects.to(device)

            optimizer.zero_grad()
            pred_objects, pred_distances = model(screens)
            loss_obj = objects_criterion(pred_objects, objects)
            loss_dist = distances_criterion(pred_distances, distances)
            loss = loss_obj + loss_dist
            loss.backward()
            optimizer.step()

            # .item() converts to Python floats so no accumulator keeps a
            # device tensor alive (losses and accuracies treated alike).
            loss_obj_value = loss_obj.item()
            loss_dist_value = loss_dist.item()
            running_loss_obj += loss_obj_value
            running_loss_dist += loss_dist_value
            epoch_loss_obj += loss_obj_value
            epoch_loss_dist += loss_dist_value

            _, pred_objects = pred_objects.max(1)
            accuracy_obj = (pred_objects == objects).float().mean().item()
            running_accuracy_obj += accuracy_obj
            epoch_accuracy_obj += accuracy_obj

            _, pred_distances = pred_distances.max(1)
            accuracy_dist = (pred_distances == distances).float().mean().item()
            running_accuracy_dist += accuracy_dist
            epoch_accuracy_dist += accuracy_dist

            if batch % 1000 == 999:
                _save_checkpoint(model, optimizer, args.checkpoint_file)

            if batch % batches_per_print == batches_per_print - 1:
                print(
                    '[{:d}, {:5d}] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}, time: {:.6f}'
                    .format(epoch + 1, batch + 1,
                            running_loss_obj / batches_per_print,
                            running_loss_dist / batches_per_print,
                            running_accuracy_obj / batches_per_print,
                            running_accuracy_dist / batches_per_print,
                            (time.time() - batch_time) / batches_per_print))
                running_loss_obj = running_loss_dist = 0.0
                running_accuracy_obj = running_accuracy_dist = 0.0
                batch_time = time.time()

        batch_num = batch + 1
        epoch_loss_obj /= batch_num
        epoch_loss_dist /= batch_num
        epoch_accuracy_obj /= batch_num
        epoch_accuracy_dist /= batch_num

        if epoch % args.checkpoint_rate == args.checkpoint_rate - 1:
            _save_checkpoint(model, optimizer, args.checkpoint_file)

        val_loss, val_accuracy = test(model, validation_data_loader)

        print(
            '[{:d}] TRAIN loss: {:.3f}, {:.3f} accuracy: {:.3f}, {:.3f}, VAL loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'
            .format(epoch + 1, epoch_loss_obj, epoch_loss_dist,
                    epoch_accuracy_obj, epoch_accuracy_dist, *val_loss,
                    *val_accuracy))

        train_writer.add_scalar('map/loss_obj', epoch_loss_obj, epoch)
        train_writer.add_scalar('map/loss_dist', epoch_loss_dist, epoch)
        train_writer.add_scalar('map/accuracy_obj', epoch_accuracy_obj, epoch)
        train_writer.add_scalar('map/accuracy_dist', epoch_accuracy_dist,
                                epoch)
        val_writer.add_scalar('map/loss_obj', val_loss[0], epoch)
        val_writer.add_scalar('map/loss_dist', val_loss[1], epoch)
        val_writer.add_scalar('map/accuracy_obj', val_accuracy[0], epoch)
        val_writer.add_scalar('map/accuracy_dist', val_accuracy[1], epoch)

    test_loss, test_accuracy = test(model, test_data_loader)
    print('[TEST] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'.format(
        *test_loss, *test_accuracy))
Example #5
0
import sys

import cv2
import torch
import numpy as np

from PIL import Image, ImageDraw

from dataset import CarlaDataset
from converter import Converter
from map_model import MapModel
import common

# Load the MapModel checkpoint named by the first command-line argument,
# move it to the GPU and put it in eval mode for inference.
# NOTE(review): sys.argv is not validated; a missing argument raises
# IndexError.
net = MapModel.load_from_checkpoint(sys.argv[1])
net.cuda()
net.eval()

# The dataset location is the second command-line argument.
data = CarlaDataset(sys.argv[2])
converter = Converter()

for i in range(len(data)):
    rgb, topdown, points, heatmap, heatmap_img, meta = data[i]
    points_unnormalized = (points + 1) / 2 * 256
    points_cam = converter(points_unnormalized)
    heatmap_flipped = torch.FloatTensor(heatmap.numpy()[:, :, ::-1].copy())

    with torch.no_grad():
        points_pred = net(torch.cat((topdown, heatmap),
                                    0).cuda()[None]).cpu().squeeze()
        points_pred_flipped = net(
            torch.cat((topdown, heatmap_flipped),