Esempio n. 1
0
def run(args):
    """Visualise ground-truth vs. predicted maps for samples of the test split.

    Loads model weights from ``args.checkpoint_file``, iterates over the
    ``test`` split under ``args.h5_path`` in shuffled batches of 10 and, for
    every sample, calls ``draw`` twice: once with the labels (target file
    ``view-image-label.png``) and once with the model's predictions (target
    file ``view-image-pred.png``). The same two file names are reused for
    every sample.

    Args:
        args: namespace providing at least ``checkpoint_file`` and
            ``h5_path``; it is also passed to ``MapModel``.
    """
    model = MapModel(args).to(device)
    # map_location lets a checkpoint saved on one device (e.g. GPU) be loaded
    # on another (e.g. a CPU-only machine).
    state_dict = torch.load(args.checkpoint_file, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    test_set = MapDataset(os.path.join(args.h5_path, 'test'))
    test_data_loader = DataLoader(dataset=test_set,
                                  num_workers=2,
                                  batch_size=10,
                                  shuffle=True)

    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        for screens, distances, objects in test_data_loader:
            screens = screens.to(device)
            distances = distances.to(device)
            objects = objects.to(device)

            pred_objects, pred_distances = model(screens)
            # Keep only the indices of the maximum along dim 1
            # (presumably the class dimension — confirm against MapModel).
            _, pred_objects = pred_objects.max(1)
            _, pred_distances = pred_distances.max(1)

            for dist, obj, pred_dist, pred_obj in zip(
                    distances, objects, pred_distances, pred_objects):
                draw(dist, obj, 'view-image-label.png')
                draw(pred_dist, pred_obj, 'view-image-pred.png')
                # NOTE(review): debug marker kept from the original; it is
                # presumably a breakpoint anchor for inspecting each sample.
                print(1)
Esempio n. 2
0
def _make_loader(h5_path, split, batch_size, shuffle, num_workers=4):
    """Return a DataLoader over the ``split`` sub-directory of ``h5_path``."""
    return DataLoader(dataset=MapDataset(os.path.join(h5_path, split)),
                      num_workers=num_workers,
                      batch_size=batch_size,
                      shuffle=shuffle)


def _save_checkpoint(model, optimizer, checkpoint_file):
    """Write model weights to ``checkpoint_file`` and optimizer state beside it."""
    torch.save(model.state_dict(), checkpoint_file)
    torch.save(optimizer.state_dict(), checkpoint_file + '_optimizer.pth')


def _load_checkpoint(model, optimizer, checkpoint_file):
    """Restore state written by ``_save_checkpoint`` (optimizer part optional)."""
    print("loading model parameters {}".format(checkpoint_file))
    # map_location allows resuming on a different device than the one used
    # when the checkpoint was saved.
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))
    optimizer_file = checkpoint_file + '_optimizer.pth'
    if os.path.isfile(optimizer_file):
        optimizer.load_state_dict(
            torch.load(optimizer_file, map_location=device))
    else:
        # The weights alone are still usable; continue with a fresh optimizer
        # instead of aborting the resume.
        print("optimizer state {} not found, using a fresh optimizer".format(
            optimizer_file))


def _batch_accuracy(scores, targets):
    """Fraction of rows whose index of the max over dim 1 equals ``targets``."""
    _, predictions = scores.max(1)
    return (predictions == targets).float().mean().item()


def train(args):
    """Train ``MapModel`` on the HDF5 splits under ``args.h5_path``.

    Each epoch makes one pass over the ``train`` split with AdamW (lr 5e-4),
    minimising ``objects_criterion + distances_criterion``. Progress is
    printed every 10 mini-batches, a checkpoint is written every 1000
    mini-batches and every ``args.checkpoint_rate`` epochs, and the model is
    evaluated on the ``val`` split with ``test()`` after every epoch. Epoch
    metrics go to ``train_writer`` / ``val_writer``. After the last epoch the
    model is evaluated once on the ``test`` split.

    Args:
        args: namespace providing ``h5_path``, ``batch_size``, ``epoch_num``,
            ``checkpoint_file``, ``checkpoint_rate`` and ``load`` (path of a
            checkpoint to resume from, or ``None``); it is also passed to
            ``MapModel``.
    """
    batches_per_print = 10  # mini-batches averaged into one progress line
    batches_per_checkpoint = 1000  # mid-epoch checkpoint interval

    train_data_loader = _make_loader(args.h5_path, 'train', args.batch_size,
                                     shuffle=True)
    test_data_loader = _make_loader(args.h5_path, 'test', 10, shuffle=False)
    validation_data_loader = _make_loader(args.h5_path, 'val', 10,
                                          shuffle=False)

    model = MapModel(args).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=5e-4)

    if args.load is not None and os.path.isfile(args.load):
        _load_checkpoint(model, optimizer, args.load)

    for epoch in range(args.epoch_num):
        # NOTE(review): test() is defined elsewhere; if it calls model.eval(),
        # calling model.train() only once before the loop would leave every
        # later epoch training in eval mode. Re-entering train mode each
        # epoch is correct either way.
        model.train()

        epoch_loss_obj = epoch_loss_dist = 0.0
        epoch_accuracy_obj = epoch_accuracy_dist = 0.0
        running_loss_obj = running_loss_dist = 0.0
        running_accuracy_obj = running_accuracy_dist = 0.0
        batch_time = time.time()
        batch_num = 0  # mini-batches processed so far in this epoch

        for screens, distances, objects in train_data_loader:
            batch_num += 1
            screens = screens.to(device)
            distances = distances.to(device)
            objects = objects.to(device)

            optimizer.zero_grad()
            pred_objects, pred_distances = model(screens)
            loss_obj = objects_criterion(pred_objects, objects)
            loss_dist = distances_criterion(pred_distances, distances)
            (loss_obj + loss_dist).backward()
            optimizer.step()

            # Convert metrics to Python floats so no device tensors are kept
            # alive in the accumulators.
            loss_obj, loss_dist = loss_obj.item(), loss_dist.item()
            accuracy_obj = _batch_accuracy(pred_objects, objects)
            accuracy_dist = _batch_accuracy(pred_distances, distances)

            running_loss_obj += loss_obj
            running_loss_dist += loss_dist
            running_accuracy_obj += accuracy_obj
            running_accuracy_dist += accuracy_dist
            epoch_loss_obj += loss_obj
            epoch_loss_dist += loss_dist
            epoch_accuracy_obj += accuracy_obj
            epoch_accuracy_dist += accuracy_dist

            if batch_num % batches_per_checkpoint == 0:
                _save_checkpoint(model, optimizer, args.checkpoint_file)

            if batch_num % batches_per_print == 0:
                print(
                    '[{:d}, {:5d}] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}, time: {:.6f}'
                    .format(epoch + 1, batch_num,
                            running_loss_obj / batches_per_print,
                            running_loss_dist / batches_per_print,
                            running_accuracy_obj / batches_per_print,
                            running_accuracy_dist / batches_per_print,
                            (time.time() - batch_time) / batches_per_print))
                running_loss_obj = running_loss_dist = 0.0
                running_accuracy_obj = running_accuracy_dist = 0.0
                batch_time = time.time()

        # Per-epoch averages; max(..., 1) avoids dividing by zero when the
        # loader yields no batches (all sums are then 0).
        denom = max(batch_num, 1)
        epoch_loss_obj /= denom
        epoch_loss_dist /= denom
        epoch_accuracy_obj /= denom
        epoch_accuracy_dist /= denom

        if epoch % args.checkpoint_rate == args.checkpoint_rate - 1:
            _save_checkpoint(model, optimizer, args.checkpoint_file)

        val_loss, val_accuracy = test(model, validation_data_loader)

        print(
            '[{:d}] TRAIN loss: {:.3f}, {:.3f} accuracy: {:.3f}, {:.3f}, VAL loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'
            .format(epoch + 1, epoch_loss_obj, epoch_loss_dist,
                    epoch_accuracy_obj, epoch_accuracy_dist, *val_loss,
                    *val_accuracy))

        train_writer.add_scalar('map/loss_obj', epoch_loss_obj, epoch)
        train_writer.add_scalar('map/loss_dist', epoch_loss_dist, epoch)
        train_writer.add_scalar('map/accuracy_obj', epoch_accuracy_obj, epoch)
        train_writer.add_scalar('map/accuracy_dist', epoch_accuracy_dist,
                                epoch)
        val_writer.add_scalar('map/loss_obj', val_loss[0], epoch)
        val_writer.add_scalar('map/loss_dist', val_loss[1], epoch)
        val_writer.add_scalar('map/accuracy_obj', val_accuracy[0], epoch)
        val_writer.add_scalar('map/accuracy_dist', val_accuracy[1], epoch)

    test_loss, test_accuracy = test(model, test_data_loader)
    print('[TEST] loss: {:.3f}, {:.3f}, accuracy: {:.3f}, {:.3f}'.format(
        *test_loss, *test_accuracy))