def __init__(self, hparams, teacher_path=''):
    """Build the student network and, optionally, a frozen teacher.

    Args:
        hparams: hyper-parameters as an attribute-style object or a plain
            dict. A dict is turned into an ``argparse.Namespace`` so that
            attribute lookups such as ``hparams.heatmap_radius`` work.
        teacher_path: checkpoint of a ``MapModel`` teacher. When empty, no
            teacher is loaded and ``self.teacher`` is not assigned.
    """
    super().__init__()

    # HACK: hparams may arrive as a plain dict; normalise it to a namespace
    # so the attribute accesses below work in both cases.
    if isinstance(hparams, dict):
        import argparse
        hparams = argparse.Namespace(**hparams)
    self.hparams = hparams

    self.to_heatmap = ToHeatmap(hparams.heatmap_radius)

    if teacher_path:
        # The checkpoint path is handed over as a str, not a path-like object.
        self.teacher = MapModel.load_from_checkpoint(str(teacher_path))
        self.teacher.freeze()

    self.net = SegmentationModel(
        10, 4, hack=hparams.hack, temperature=hparams.temperature)
    self.converter = Converter()
    self.controller = RawController(4)
def __init__(self, hparams, teacher_root='/home/bradyzhou/code/carla_random/'):
    """Build the student network and load the teacher ``MapModel``.

    Args:
        hparams: hyper-parameters; ``hparams.teacher_path`` names the teacher
            checkpoint relative to ``teacher_root``.
        teacher_root: directory that a relative ``hparams.teacher_path`` is
            resolved against. It defaults to the previously hard-coded
            location, so existing callers behave the same. If
            ``hparams.teacher_path`` is absolute, pathlib joining ignores
            this root.
    """
    super().__init__()
    self.hparams = hparams
    self.net = SegmentationModel(4, 4)

    teacher_checkpoint = pathlib.Path(teacher_root) / hparams.teacher_path
    # Pass a str rather than a Path, matching the other constructor in this
    # file that converts the checkpoint path with str() before loading.
    self.teacher = MapModel.load_from_checkpoint(str(teacher_checkpoint))
    # NOTE(review): the teacher is neither frozen nor switched to eval mode
    # here (an eval() call was commented out). Confirm that its parameters
    # are excluded from optimisation elsewhere.
    self.converter = Converter()
def run(args):
    """Visualise ground-truth and predicted maps for the test split.

    Loads the state dict in ``args.checkpoint_file`` into a fresh
    ``MapModel`` and runs it over ``<args.h5_path>/test``. For every sample
    it calls ``draw`` twice: labels go to ``view-image-label.png`` and
    predictions to ``view-image-pred.png``. Both files are rewritten for each
    sample.
    """
    model = MapModel(args).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only device.
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_set = MapDataset(os.path.join(args.h5_path, 'test'))
    test_data_loader = DataLoader(dataset=test_set, num_workers=2,
                                  batch_size=10, shuffle=True)

    # Pure inference: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for screens, distances, objects in test_data_loader:
            screens = screens.to(device)
            distances = distances.to(device)
            objects = objects.to(device)

            pred_objects, pred_distances = model(screens)
            # Keep only the arg-max index along dim 1 of each output.
            _, pred_objects = pred_objects.max(1)
            _, pred_distances = pred_distances.max(1)

            for dist, obj, pred_dist, pred_obj in zip(
                    distances, objects, pred_distances, pred_objects):
                draw(dist, obj, 'view-image-label.png')
                draw(pred_dist, pred_obj, 'view-image-pred.png')
def _save_checkpoint(model, optimizer, path):
    """Save the model state to ``path`` and the optimizer state to ``path + '_optimizer.pth'``."""
    torch.save(model.state_dict(), path)
    torch.save(optimizer.state_dict(), path + '_optimizer.pth')


def train(args):
    """Train a ``MapModel`` on ``<args.h5_path>/train``.

    - Optionally resumes from ``args.load``, plus ``args.load + '_optimizer.pth'``
      when that file exists.
    - Prints running loss and accuracy every 10 batches.
    - Checkpoints to ``args.checkpoint_file`` every 1000 batches and every
      ``args.checkpoint_rate`` epochs.
    - After each epoch, evaluates on the ``val`` split with ``test()`` and logs
      scalars to ``train_writer`` / ``val_writer``.
    - Finally evaluates on the ``test`` split.
    """
    train_set = MapDataset(os.path.join(args.h5_path, 'train'))
    train_data_loader = DataLoader(dataset=train_set, num_workers=4,
                                   batch_size=args.batch_size, shuffle=True)
    test_set = MapDataset(os.path.join(args.h5_path, 'test'))
    test_data_loader = DataLoader(dataset=test_set, num_workers=4,
                                  batch_size=10, shuffle=False)
    validation_set = MapDataset(os.path.join(args.h5_path, 'val'))
    validation_data_loader = DataLoader(dataset=validation_set, num_workers=4,
                                        batch_size=10, shuffle=False)

    model = MapModel(args).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    if args.load is not None and os.path.isfile(args.load):
        print("loading model parameters {}".format(args.load))
        # map_location makes GPU-saved checkpoints loadable on any device.
        model.load_state_dict(torch.load(args.load, map_location=device))
        optimizer_path = args.load + '_optimizer.pth'
        if os.path.isfile(optimizer_path):
            optimizer.load_state_dict(
                torch.load(optimizer_path, map_location=device))
        else:
            print("optimizer state {} not found; using a fresh optimizer"
                  .format(optimizer_path))

    batches_per_print = 10
    for epoch in range(args.epoch_num):
        # test() is not visible here and may switch the model to eval mode;
        # re-enable training mode at the start of every epoch.
        model.train()

        epoch_loss_obj = epoch_loss_dist = 0.0
        epoch_accuracy_obj = epoch_accuracy_dist = 0.0
        running_loss_obj = running_loss_dist = 0.0
        running_accuracy_obj = running_accuracy_dist = 0.0
        batch_time = time.time()
        batch = 0
        for batch, (screens, distances, objects) in enumerate(train_data_loader):
            screens = screens.to(device)
            distances = distances.to(device)
            objects = objects.to(device)

            optimizer.zero_grad()
            pred_objects, pred_distances = model(screens)
            loss_obj = objects_criterion(pred_objects, objects)
            loss_dist = distances_criterion(pred_distances, distances)
            loss = loss_obj + loss_dist
            loss.backward()
            optimizer.step()

            loss_obj_value = loss_obj.item()
            loss_dist_value = loss_dist.item()
            running_loss_obj += loss_obj_value
            running_loss_dist += loss_dist_value
            epoch_loss_obj += loss_obj_value
            epoch_loss_dist += loss_dist_value

            # Accuracies are accumulated as Python floats (via .item()) so the
            # running/epoch sums do not hold device tensors.
            _, pred_objects = pred_objects.max(1)
            accuracy_obj = (pred_objects == objects).float().mean().item()
            _, pred_distances = pred_distances.max(1)
            accuracy_dist = (pred_distances == distances).float().mean().item()
            running_accuracy_obj += accuracy_obj
            epoch_accuracy_obj += accuracy_obj
            running_accuracy_dist += accuracy_dist
            epoch_accuracy_dist += accuracy_dist

            if batch % 1000 == 999:
                _save_checkpoint(model, optimizer, args.checkpoint_file)

            if batch % batches_per_print == batches_per_print - 1:
                print(
                    '[{:d}, {:5d}] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}, time: {:.6f}'
                    .format(epoch + 1, batch + 1,
                            running_loss_obj / batches_per_print,
                            running_loss_dist / batches_per_print,
                            running_accuracy_obj / batches_per_print,
                            running_accuracy_dist / batches_per_print,
                            (time.time() - batch_time) / batches_per_print))
                running_loss_obj, running_loss_dist = 0.0, 0.0
                running_accuracy_obj, running_accuracy_dist = 0.0, 0.0
                batch_time = time.time()

        # With an empty loader batch stays 0, so the zero sums are divided by 1.
        batch_num = batch + 1
        epoch_loss_obj /= batch_num
        epoch_loss_dist /= batch_num
        epoch_accuracy_obj /= batch_num
        epoch_accuracy_dist /= batch_num

        if epoch % args.checkpoint_rate == args.checkpoint_rate - 1:
            _save_checkpoint(model, optimizer, args.checkpoint_file)

        val_loss, val_accuracy = test(model, validation_data_loader)
        print(
            '[{:d}] TRAIN loss: {:.3f}, {:.3f} accuracy: {:.3f}, {:.3f}, VAL loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'
            .format(epoch + 1, epoch_loss_obj, epoch_loss_dist,
                    epoch_accuracy_obj, epoch_accuracy_dist,
                    *val_loss, *val_accuracy))

        train_writer.add_scalar('map/loss_obj', epoch_loss_obj, epoch)
        train_writer.add_scalar('map/loss_dist', epoch_loss_dist, epoch)
        train_writer.add_scalar('map/accuracy_obj', epoch_accuracy_obj, epoch)
        train_writer.add_scalar('map/accuracy_dist', epoch_accuracy_dist, epoch)
        val_writer.add_scalar('map/loss_obj', val_loss[0], epoch)
        val_writer.add_scalar('map/loss_dist', val_loss[1], epoch)
        val_writer.add_scalar('map/accuracy_obj', val_accuracy[0], epoch)
        val_writer.add_scalar('map/accuracy_dist', val_accuracy[1], epoch)

    test_loss, test_accuracy = test(model, test_data_loader)
    print('[TEST] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'.format(
        *test_loss, *test_accuracy))
import sys import cv2 import torch import numpy as np from PIL import Image, ImageDraw from dataset import CarlaDataset from converter import Converter from map_model import MapModel import common net = MapModel.load_from_checkpoint(sys.argv[1]) net.cuda() net.eval() data = CarlaDataset(sys.argv[2]) converter = Converter() for i in range(len(data)): rgb, topdown, points, heatmap, heatmap_img, meta = data[i] points_unnormalized = (points + 1) / 2 * 256 points_cam = converter(points_unnormalized) heatmap_flipped = torch.FloatTensor(heatmap.numpy()[:, :, ::-1].copy()) with torch.no_grad(): points_pred = net(torch.cat((topdown, heatmap), 0).cuda()[None]).cpu().squeeze() points_pred_flipped = net( torch.cat((topdown, heatmap_flipped),