Example #1
import os

import torch
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter  # older projects may use tensorboardX instead
from tqdm import tqdm

# KITTI (dataset), Trainer and log_images are assumed to come from this project's own modules.


def main(args):
    batch_size = args.batch_size
    num_epoch = args.epoch

    raw_text_dir = args.raw_text_dir
    raw_images_dir = args.raw_images_dir
    weight_dir = args.weight_dir

    # Train on the first GPU when available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # KITTI stereo training set, data loader, trainer and TensorBoard writer
    trainset = KITTI(raw_text_dir, raw_images_dir)
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4)
    trainer = Trainer(device,
                      args.decay,
                      batchnorm=True,
                      pretrained=False,
                      lr=args.lr,
                      momentum=args.momentum)
    writer = SummaryWriter()

    # Optionally resume from a saved checkpoint
    if args.resume:
        assert os.path.isfile(args.resume), "{} is not a file.".format(
            args.resume)
        state = torch.load(args.resume)
        trainer.load(state)
        it = state["iterations"]
        print("Checkpoint is loaded at {} | Iterations: {}".format(
            args.resume, it))
    else:
        it = 0

    for e in range(1, num_epoch + 1):  # run num_epoch epochs, numbered from 1
        prog_bar = tqdm(dataloader, desc="Epoch {}".format(e))
        for data in prog_bar:
            # Left and right images of the stereo pair
            images_l = data[0].to(device)
            images_r = data[1].to(device)

            # Total loss and its component terms, normalised per sample
            loss, ap, lr, ds = trainer(images_l, images_r)
            loss = loss.item() / batch_size
            ap = ap.item() / batch_size
            lr = lr.item() / batch_size
            ds = ds.item() / batch_size

            prog_bar.set_postfix(Loss=loss)
            # Log each loss term to TensorBoard against the global iteration counter
            writer.add_scalar('Total Loss', loss, it)
            writer.add_scalar('AP Loss', ap, it)
            writer.add_scalar('LR Loss', lr, it)
            writer.add_scalar('DS Loss', ds, it)

            it += 1

            # Every 2000 iterations, save a checkpoint and log sample disparity maps
            if it % 2000 == 0:
                print('Saving checkpoint...')
                trainer.save(weight_dir, it)
                print("Checkpoint is saved at {} | Iterations: {}".format(
                    weight_dir, it))

                disp_images = trainer.get_disp_images()
                log_images(writer, images_l, disp_images[0], it)
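
The args namespace is expected to carry the fields read above (batch_size, epoch, raw_text_dir, raw_images_dir, weight_dir, decay, lr, momentum, resume). A minimal argparse entry point could look like the sketch below; the flag names mirror those attribute names, but the default values are placeholders, not taken from the project.

import argparse

if __name__ == "__main__":
    # Hypothetical entry point: the flag names follow the attributes read by
    # main(args); the defaults are assumptions, not values from the project.
    parser = argparse.ArgumentParser(description="Train on KITTI stereo pairs")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epoch", type=int, default=50)
    parser.add_argument("--raw_text_dir", required=True)
    parser.add_argument("--raw_images_dir", required=True)
    parser.add_argument("--weight_dir", default="weights")
    parser.add_argument("--decay", type=float, default=1e-4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--resume", default="", help="path to a checkpoint to resume from")
    main(parser.parse_args())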