# Validation split of the TuSimple lane dataset; no augmentations at eval time.
val_dataset = tusimpleLoader('/home/tejus/Downloads/train_set/',
                             split="val",
                             augmentations=None)
valloader = torch.utils.data.DataLoader(val_dataset,
                                        batch_size=VAL_BATCH,
                                        shuffle=True,  # harmless for validation, though not required
                                        num_workers=VAL_BATCH,  # worker count tied to the batch size
                                        pin_memory=True)
# dataset = kittiLoader('/mnt/data/tejus/kitti_road/data_road/', split="train", augmentations=augmentations)
# trainloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1, pin_memory=True)

running_metrics_val = runningScore(2)  # pixel-wise confusion-matrix metrics over 2 classes (lane / background)
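
# A minimal sketch of how running_metrics_val is typically consumed once per
# epoch, assuming the runningScore helper from pytorch-semseg (update()
# accumulates a confusion matrix, get_scores() derives accuracy/IoU, reset()
# clears it); the actual validation loop runs later in the script:
#
#     net.eval()
#     with torch.no_grad():
#         for images, labels in valloader:
#             outputs = net(images.to(device))
#             preds = outputs.argmax(dim=1).cpu().numpy()
#             running_metrics_val.update(labels.numpy(), preds)
#     score, class_iou = running_metrics_val.get_scores()
#     running_metrics_val.reset()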

import math

best_val_loss = math.inf  # best validation loss seen so far
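
# A minimal sketch of how best_val_loss and the plateau scheduler are typically
# used at the end of each epoch (illustrative only; the actual loop lives
# further down, and logdir is created below):
#
#     scheduler.step(val_loss)  # ReduceLROnPlateau lowers the LR when val loss stalls
#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         torch.save(net.state_dict(), os.path.join(logdir, 'best_val_model.pkl'))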

EPOCHS = 50  # total number of training epochs; the epoch loop itself comes later
# Checkpoint to restore from when resume_training is set below.
checkpoint_dir = '/home/tejus/lane-seg-experiments/Segmentation/CAN_logger/frontend_only/runs/2018-10-06_14-51-26/best_val_model_tested.pkl'
run_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
logdir = os.path.join('runs', run_id)
writer = SummaryWriter(log_dir=logdir)
print('RUNDIR: {}'.format(logdir))
logger = get_logger(logdir)
logger.info('Let the party begin')
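
# A minimal sketch of how the writer and logger are typically fed during
# training (illustrative; add_scalar is the standard SummaryWriter call):
#
#     writer.add_scalar('loss/train', loss.item(), global_step)
#     writer.add_scalar('loss/val', val_loss, epoch)
#     logger.info('Epoch {}: val loss {:.4f}'.format(epoch, val_loss))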

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(102)  # fix the RNG seed for reproducibility

# Network definition

net = CAN()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                    net.parameters()),
                             lr=0.0001)
# optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=0.0001, momentum=0.9)  # 0.00001
# loss_fn = nn.BCEWithLogitsLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       'min',
                                                       patience=2,
                                                       verbose=True,
                                                       min_lr=0.000001)
loss_fn = nn.CrossEntropyLoss()
net.to(device)
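
# A minimal sketch of a single optimisation step with this loss/optimizer
# pairing (illustrative; CrossEntropyLoss expects raw logits of shape
# [N, 2, H, W] and integer labels of shape [N, H, W]):
#
#     net.train()
#     images, labels = images.to(device), labels.to(device)
#     optimizer.zero_grad()
#     outputs = net(images)
#     loss = loss_fn(outputs, labels)
#     loss.backward()
#     optimizer.step()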

if not resume_training:
    # Initialise the frontend from ImageNet-pretrained VGG16, dropping the
    # fully connected classifier weights.
    pretrained_state_dict = models.vgg16(pretrained=True).state_dict()
    pretrained_state_dict = {k: v for k, v in pretrained_state_dict.items() if 'classifier' not in k}
    # Rename parameters: layer indices shift because this network omits some of
    # VGG16's max-pool layers.
    pretrained_state_dict['features.23.weight'] = pretrained_state_dict.pop('features.24.weight')
    pretrained_state_dict['features.23.bias'] = pretrained_state_dict.pop('features.24.bias')
    pretrained_state_dict['features.25.weight'] = pretrained_state_dict.pop('features.26.weight')
    pretrained_state_dict['features.25.bias'] = pretrained_state_dict.pop('features.26.bias')
    pretrained_state_dict['features.27.weight'] = pretrained_state_dict.pop('features.28.weight')