def run(args):
    """Load a trained MapModel from ``args.checkpoint_file`` and visualize
    its predictions on the test split.

    For each test batch, writes the ground-truth map to
    ``view-image-label.png`` and the predicted map to
    ``view-image-pred.png`` (each file is overwritten per sample).

    Args:
        args: namespace with at least ``checkpoint_file`` (path to a saved
            state dict) and ``h5_path`` (root directory of the h5 datasets).
    """
    model = MapModel(args).to(device)
    state_dict = torch.load(args.checkpoint_file)
    model.load_state_dict(state_dict)
    model.eval()

    test_set = MapDataset(os.path.join(args.h5_path, 'test'))
    test_data_loader = DataLoader(dataset=test_set, num_workers=2,
                                  batch_size=10, shuffle=True)

    # Inference only: disable autograd so the forward pass does not build
    # graphs and hold activations in memory.
    with torch.no_grad():
        for batch, (screens, distances, objects) in enumerate(test_data_loader):
            screens, distances, objects = screens.to(device), distances.to(
                device), objects.to(device)
            pred_objects, pred_distances = model(screens)
            # Collapse the class dimension to hard predictions (argmax over dim 1).
            _, pred_objects = pred_objects.max(1)
            _, pred_distances = pred_distances.max(1)
            for i in range(len(distances)):
                draw(distances[i], objects[i], 'view-image-label.png')
                draw(pred_distances[i], pred_objects[i], 'view-image-pred.png')
def train(args):
    """Train a MapModel on the train split, validating each checkpoint epoch.

    Per epoch: runs one pass over the training data, logs running loss and
    accuracy every 10 mini-batches, checkpoints every 1000 mini-batches and
    every ``args.checkpoint_rate`` epochs, evaluates on the validation split,
    and logs epoch metrics to the tensorboard writers. After all epochs,
    reports final metrics on the test split.

    Args:
        args: namespace with ``h5_path``, ``batch_size``, ``epoch_num``,
            ``checkpoint_rate``, ``checkpoint_file``, and optional ``load``
            (checkpoint path to resume from).
    """
    train_set = MapDataset(os.path.join(args.h5_path, 'train'))
    train_data_loader = DataLoader(dataset=train_set, num_workers=4,
                                   batch_size=args.batch_size, shuffle=True)
    test_set = MapDataset(os.path.join(args.h5_path, 'test'))
    test_data_loader = DataLoader(dataset=test_set, num_workers=4,
                                  batch_size=10, shuffle=False)
    validation_set = MapDataset(os.path.join(args.h5_path, 'val'))
    validation_data_loader = DataLoader(dataset=validation_set, num_workers=4,
                                        batch_size=10, shuffle=False)

    model = MapModel(args).to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    # Optionally resume model and optimizer state from a prior checkpoint.
    if args.load is not None and os.path.isfile(args.load):
        print("loading model parameters {}".format(args.load))
        state_dict = torch.load(args.load)
        model.load_state_dict(state_dict)
        optimizer_dict = torch.load(args.load + '_optimizer.pth')
        optimizer.load_state_dict(optimizer_dict)

    for epoch in range(args.epoch_num):
        # Re-assert train mode each epoch: the per-epoch test() call below
        # presumably switches the model to eval mode — idempotent otherwise.
        model.train()
        epoch_loss_obj = 0
        epoch_loss_dist = 0
        epoch_accuracy_obj = 0
        epoch_accuracy_dist = 0
        running_loss_obj = 0
        running_loss_dist = 0
        running_accuracy_obj = 0
        running_accuracy_dist = 0
        batch_time = time.time()
        batch = 0  # guard: keeps batch_num valid even if the loader is empty
        for batch, (screens, distances, objects) in enumerate(train_data_loader):
            screens, distances, objects = screens.to(device), distances.to(
                device), objects.to(device)

            optimizer.zero_grad()
            pred_objects, pred_distances = model(screens)
            # Joint objective: object-class loss + distance-class loss.
            loss_obj = objects_criterion(pred_objects, objects)
            loss_dist = distances_criterion(pred_distances, distances)
            loss = loss_obj + loss_dist
            loss.backward()
            optimizer.step()

            running_loss_obj += loss_obj.item()
            running_loss_dist += loss_dist.item()
            epoch_loss_obj += loss_obj.item()
            epoch_loss_dist += loss_dist.item()

            # Accumulate accuracies as Python floats (.item()), not 0-dim
            # device tensors — avoids keeping GPU tensors alive across the
            # epoch and feeding tensors to the scalar loggers.
            _, pred_objects = pred_objects.max(1)
            accuracy = (pred_objects == objects).float().mean().item()
            running_accuracy_obj += accuracy
            epoch_accuracy_obj += accuracy
            _, pred_distances = pred_distances.max(1)
            accuracy = (pred_distances == distances).float().mean().item()
            running_accuracy_dist += accuracy
            epoch_accuracy_dist += accuracy

            # Mid-epoch checkpoint every 1000 mini-batches.
            if batch % 1000 == 999:
                torch.save(model.state_dict(), args.checkpoint_file)
                torch.save(optimizer.state_dict(),
                           args.checkpoint_file + '_optimizer.pth')

            batches_per_print = 10
            if batch % batches_per_print == batches_per_print - 1:
                # print every batches_per_print mini-batches
                running_loss_obj /= batches_per_print
                running_loss_dist /= batches_per_print
                running_accuracy_obj /= batches_per_print
                running_accuracy_dist /= batches_per_print
                print(
                    '[{:d}, {:5d}] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}, time: {:.6f}'
                    .format(epoch + 1, batch + 1, running_loss_obj,
                            running_loss_dist, running_accuracy_obj,
                            running_accuracy_dist,
                            (time.time() - batch_time) / batches_per_print))
                running_loss_obj, running_loss_dist = 0, 0
                running_accuracy_obj, running_accuracy_dist = 0, 0
                batch_time = time.time()

        batch_num = batch + 1
        epoch_loss_obj /= batch_num
        epoch_loss_dist /= batch_num
        epoch_accuracy_obj /= batch_num
        epoch_accuracy_dist /= batch_num

        # Periodic end-of-epoch checkpoint (same file, overwritten).
        if epoch % args.checkpoint_rate == args.checkpoint_rate - 1:
            torch.save(model.state_dict(), args.checkpoint_file)
            torch.save(optimizer.state_dict(),
                       args.checkpoint_file + '_optimizer.pth')

        # Validation pass; test() returns ((loss_obj, loss_dist),
        # (acc_obj, acc_dist)) judging by the unpacking below — TODO confirm.
        val_loss, val_accuracy = test(model, validation_data_loader)
        print(
            '[{:d}] TRAIN loss: {:.3f}, {:.3f} accuracy: {:.3f}, {:.3f}, VAL loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'
            .format(epoch + 1, epoch_loss_obj, epoch_loss_dist,
                    epoch_accuracy_obj, epoch_accuracy_dist, *val_loss,
                    *val_accuracy))

        train_writer.add_scalar('map/loss_obj', epoch_loss_obj, epoch)
        train_writer.add_scalar('map/loss_dist', epoch_loss_dist, epoch)
        train_writer.add_scalar('map/accuracy_obj', epoch_accuracy_obj, epoch)
        train_writer.add_scalar('map/accuracy_dist', epoch_accuracy_dist, epoch)
        val_writer.add_scalar('map/loss_obj', val_loss[0], epoch)
        val_writer.add_scalar('map/loss_dist', val_loss[1], epoch)
        val_writer.add_scalar('map/accuracy_obj', val_accuracy[0], epoch)
        val_writer.add_scalar('map/accuracy_dist', val_accuracy[1], epoch)

    # Final evaluation on the held-out test split.
    test_loss, test_accuracy = test(model, test_data_loader)
    print('[TEST] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'.format(
        *test_loss, *test_accuracy))