Example #1
def set_loader_tgt(self):
    batch_size = self.batch_size
    shuffle = self.shuffle
    num_workers = self.num_workers
    if self.crop_size is not None:
        collate_fn = lambda batch: augment_collate(batch, crop=self.crop_size,
                                                   halfcrop=self.half_crop, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    self.loader_tgt = torch.utils.data.DataLoader(self.target,
            batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
            collate_fn=collate_fn, pin_memory=True)
    self.iters_tgt = iter(self.loader_tgt)
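augment_collate is used throughout these examples but never defined in them. A hedged sketch of one plausible implementation, assuming each batch element is an (image, label) pair of tensors shaped (C, H, W) and (H, W); the names and behavior here are assumptions, not the original code:

import random
import torch

def augment_collate(batch, crop=None, halfcrop=False, flip=False):
    # Hypothetical collate: random-crop and random-flip each (image, label)
    # pair identically, then stack into batch tensors. The halfcrop option
    # of the original is not reproduced in this sketch.
    ims, labels = [], []
    for im, label in batch:
        if crop is not None:
            _, h, w = im.shape
            top = random.randint(0, max(0, h - crop))
            left = random.randint(0, max(0, w - crop))
            im = im[:, top:top + crop, left:left + crop]
            label = label[top:top + crop, left:left + crop]
        if flip and random.random() < 0.5:
            im = torch.flip(im, dims=[2])        # flip along width
            label = torch.flip(label, dims=[1])
        ims.append(im)
        labels.append(label)
    return torch.stack(ims), torch.stack(labels)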
Example #2
def main(output, dataset, target_name, datadir, batch_size, lr, iterations,
         momentum, snapshot, downscale, augmentation, fyu, crop_size, weights,
         model, gpu, num_cls, nthreads, model_weights, data_flag,
         serial_batches, resize_to, start_step, preprocessing, small,
         rundir_flag):
    if weights is not None:
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    logdir_flag = data_flag
    if rundir_flag != "":
        logdir_flag += "_{}".format(rundir_flag)

    logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset),
                                          logdir_flag)
    writer = SummaryWriter(log_dir=logdir)
    if model == 'fcn8s':
        net = get_model(model,
                        num_cls=num_cls,
                        weights_init=model_weights,
                        output_last_ft=True)
    else:
        net = get_model(model,
                        num_cls=num_cls,
                        finetune=True,
                        weights_init=model_weights)
    net.cuda()

    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        gpu_id = int(str_id)  # avoid shadowing the builtin `id`
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)

    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    transform = []
    target_transform = []

    if preprocessing:
        transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)],
                interpolation=Image.BICUBIC)
        ])
        target_transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)],
                interpolation=Image.NEAREST)
        ])

    transform.extend([net.module.transform])
    target_transform.extend([to_tensor_raw])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)

    datasets = [
        get_dataset(name,
                    os.path.join(datadir, name),
                    num_cls=num_cls,
                    transform=transform,
                    target_transform=target_transform,
                    data_flag=data_flag,
                    small=small) for name in dataset
    ]

    target_dataset = get_dataset(target_name,
                                 os.path.join(datadir, target_name),
                                 num_cls=num_cls,
                                 transform=transform,
                                 target_transform=target_transform,
                                 data_flag=data_flag,
                                 small=small)

    if weights is not None:
        weights = np.loadtxt(weights)

    if augmentation:
        collate_fn = lambda batch: augment_collate(
            batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    loaders = [
        torch.utils.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=not serial_batches,
                                    num_workers=nthreads,
                                    collate_fn=collate_fn,
                                    pin_memory=True,
                                    drop_last=True) for dataset in datasets
    ]

    target_loader = torch.utils.data.DataLoader(target_dataset,
                                                batch_size=batch_size,
                                                shuffle=not serial_batches,
                                                num_workers=nthreads,
                                                collate_fn=collate_fn,
                                                pin_memory=True,
                                                drop_last=True)
    iteration = start_step
    losses = deque(maxlen=10)
    losses_domain_syn = deque(maxlen=10)
    losses_domain_gta = deque(maxlen=10)
    losses_task = deque(maxlen=10)

    for loader in loaders:
        loader.dataset.__getitem__(0, debug=True)

    input_dim = 2048
    configs = {
        'input_dim': input_dim,
        'hidden_layers': [1000, 500, 100],
        'num_classes': 2,
        'num_domains': 2,
        'mode': 'dynamic',
        'mu': 1e-2,
        'gamma': 10.0
    }

    mdan = MDANet(configs).to(gpu_ids[0])
    mdan = torch.nn.DataParallel(mdan, gpu_ids)
    mdan.train()

    opt = torch.optim.Adam(itertools.chain(mdan.module.parameters(),
                                           net.module.parameters()),
                           lr=1e-4)

    for (im_syn, label_syn), (im_gta, label_gta), (im_cs, label_cs) \
            in multi_source_infinite(loaders, target_loader):
        # Clear out gradients
        opt.zero_grad()

        # load data/label
        im_syn = make_variable(im_syn, requires_grad=False)
        label_syn = make_variable(label_syn, requires_grad=False)

        im_gta = make_variable(im_gta, requires_grad=False)
        label_gta = make_variable(label_gta, requires_grad=False)

        im_cs = make_variable(im_cs, requires_grad=False)
        label_cs = make_variable(label_cs, requires_grad=False)

        if iteration == 0:
            print("im_syn size: {}".format(im_syn.size()))
            print("label_syn size: {}".format(label_syn.size()))

            print("im_gta size: {}".format(im_gta.size()))
            print("label_gta size: {}".format(label_gta.size()))

            print("im_cs size: {}".format(im_cs.size()))
            print("label_cs size: {}".format(label_cs.size()))

        if not (im_syn.size() == im_gta.size() == im_cs.size()):
            print(im_syn.size())
            print(im_gta.size())
            print(im_cs.size())

        # forward pass and compute loss
        preds_syn, ft_syn = net(im_syn)
        # pooled_ft_syn = avg_pool(ft_syn)

        preds_gta, ft_gta = net(im_gta)
        # pooled_ft_gta = avg_pool(ft_gta)

        preds_cs, ft_cs = net(im_cs)
        # pooled_ft_cs = avg_pool(ft_cs)

        loss_synthia = supervised_loss(preds_syn, label_syn)
        loss_gta = supervised_loss(preds_gta, label_gta)

        loss = loss_synthia + loss_gta
        losses_task.append(loss.item())

        logprobs, sdomains, tdomains = mdan(ft_syn, ft_gta, ft_cs)

        slabels = torch.ones(batch_size, dtype=torch.long, device=gpu_ids[0])
        tlabels = torch.zeros(batch_size, dtype=torch.long, device=gpu_ids[0])

        # TODO: increase task loss
        # Compute prediction accuracy on multiple training sources.
        domain_losses = torch.stack([
            F.nll_loss(sdomains[j], slabels) +
            F.nll_loss(tdomains[j], tlabels)
            for j in range(configs['num_domains'])
        ])
        losses_domain_syn.append(domain_losses[0].item())
        losses_domain_gta.append(domain_losses[1].item())

        # Different final loss function depending on different training modes.
        if configs['mode'] == "maxmin":
            loss = torch.max(loss) + configs['mu'] * torch.min(domain_losses)
        elif configs['mode'] == "dynamic":
            loss = torch.log(
                torch.sum(
                    torch.exp(configs['gamma'] *
                              (loss + configs['mu'] * domain_losses)))
            ) / configs['gamma']

        # backward pass
        loss.backward()
        losses.append(loss.item())

        torch.nn.utils.clip_grad_norm_(net.module.parameters(), 10)
        torch.nn.utils.clip_grad_norm_(mdan.module.parameters(), 10)
        # step gradients
        opt.step()

        # log results
        if iteration % 10 == 0:
            logging.info(
                'Iteration {}:\t{:.3f} Domain SYN: {:.3f} Domain GTA: {:.3f} Task: {:.3f}'
                .format(iteration, np.mean(losses), np.mean(losses_domain_syn),
                        np.mean(losses_domain_gta), np.mean(losses_task)))
            writer.add_scalar('loss', np.mean(losses), iteration)
            writer.add_scalar('domain_syn', np.mean(losses_domain_syn),
                              iteration)
            writer.add_scalar('domain_gta', np.mean(losses_domain_gta),
                              iteration)
            writer.add_scalar('task', np.mean(losses_task), iteration)
        iteration += 1

        if iteration % 500 == 0:
            os.makedirs(output, exist_ok=True)
            torch.save(net.module.state_dict(),
                       '{}/net-itercurr.pth'.format(output))

        if iteration % snapshot == 0:
            torch.save(net.module.state_dict(),
                       '{}/iter_{}.pth'.format(output, iteration))

        if iteration >= iterations:
            logging.info('Optimization complete.')
            break
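In the 'dynamic' mode above, the task and per-domain losses are aggregated with a soft maximum, loss = log(sum_j exp(gamma * (task + mu * domain_j))) / gamma. A minimal sketch (with made-up loss values) showing that this is the same quantity as torch.logsumexp, which is the numerically safer way to compute it when gamma is large:

import torch

gamma, mu = 10.0, 1e-2
task_loss = torch.tensor(1.5)             # stand-in for loss_synthia + loss_gta
domain_losses = torch.tensor([0.7, 0.4])  # stand-in for the stacked domain losses

# literal transcription of the expression in the example
naive = torch.log(torch.sum(torch.exp(gamma * (task_loss + mu * domain_losses)))) / gamma
# equivalent, but stable against exp() overflow for large gamma
stable = torch.logsumexp(gamma * (task_loss + mu * domain_losses), dim=0) / gamma
assert torch.allclose(naive, stable)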
Example #3
def main(output, dataset, datadir, batch_size, lr, step, iterations, momentum,
         snapshot, downscale, augmentation, fyu, crop_size, weights, model,
         gpu, num_cls):
    if weights is not None:
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()

    logdir = 'runs/{:s}/{:s}'.format(model, '-'.join(dataset))
    writer = SummaryWriter(log_dir=logdir)
    net = get_model(model, num_cls=num_cls, finetune=True)
    net.cuda()
    transform = []
    target_transform = []
    if downscale is not None:
        # torchvision.transforms.Scale was renamed to Resize and has been
        # removed from current torchvision releases
        transform.append(torchvision.transforms.Resize(1024 // downscale))
        target_transform.append(
            torchvision.transforms.Resize(1024 // downscale,
                                          interpolation=Image.NEAREST))
    transform.extend([torchvision.transforms.Resize(1024), net.transform])
    target_transform.extend([
        torchvision.transforms.Resize(1024, interpolation=Image.NEAREST),
        to_tensor_raw
    ])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)

    datasets = [
        get_dataset(name,
                    os.path.join(datadir, name),
                    transform=transform,
                    target_transform=target_transform) for name in dataset
    ]

    if weights is not None:
        weights = np.loadtxt(weights)
    opt = torch.optim.SGD(net.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=0.0005)

    if augmentation:
        collate_fn = lambda batch: augment_collate(
            batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    print(datasets)
    loaders = [
        torch.utils.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=2,
                                    collate_fn=collate_fn,
                                    pin_memory=True) for dataset in datasets
    ]
    iteration = 0
    losses = deque(maxlen=10)
    for im, label in roundrobin_infinite(*loaders):
        # Clear out gradients
        opt.zero_grad()

        # load data/label
        im = make_variable(im, requires_grad=False)
        label = make_variable(label, requires_grad=False)

        # forward pass and compute loss
        preds = net(im)
        loss = supervised_loss(preds, label)

        # backward pass
        loss.backward()
        losses.append(loss.item())

        # step gradients
        opt.step()

        # log results
        if iteration % 10 == 0:
            logging.info('Iteration {}:\t{}'.format(iteration,
                                                    np.mean(losses)))
            writer.add_scalar('loss', np.mean(losses), iteration)
        iteration += 1
        if step is not None and iteration % step == 0:
            logging.info('Decreasing learning rate by 0.1.')
            step_lr(opt, 0.1)
        if iteration % snapshot == 0:
            torch.save(net.state_dict(),
                       '{}-iter{}.pth'.format(output, iteration))
        if iteration >= iterations:
            logging.info('Optimization complete.')
            break
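roundrobin_infinite is not shown in these examples either. A minimal sketch matching how it is used above, assuming it cycles through the loaders forever and restarts each one when it is exhausted:

def roundrobin_infinite(*loaders):
    # Yield one batch from each loader in turn, forever; an exhausted loader
    # is restarted, so the stream never ends (hence the explicit `break` on
    # the iteration count in the training loop).
    iters = [iter(loader) for loader in loaders]
    while True:
        for i, loader in enumerate(loaders):
            try:
                yield next(iters[i])
            except StopIteration:
                iters[i] = iter(loader)
                yield next(iters[i])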
Example #4
def main(output, dataset, datadir, batch_size, lr, step, iterations, momentum,
         snapshot, downscale, augmentation, fyu, crop_size, weights, model,
         gpu, num_cls, nthreads, model_weights, data_flag, serial_batches,
         resize_to, start_step, preprocessing, small, rundir_flag, force_split,
         adam):
    if weights is not None:
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    logdir_flag = data_flag
    if rundir_flag != "":
        logdir_flag += "_{}".format(rundir_flag)

    logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset),
                                          logdir_flag)
    writer = SummaryWriter(log_dir=logdir)
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, weights_init=model_weights)
    else:
        net = get_model(model,
                        num_cls=num_cls,
                        finetune=True,
                        weights_init=model_weights)
    net.cuda()

    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        gpu_id = int(str_id)  # avoid shadowing the builtin `id`
        if gpu_id >= 0:
            gpu_ids.append(gpu_id)

    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    transform = []
    target_transform = []

    if preprocessing:
        transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)])
        ])
        target_transform.extend([
            torchvision.transforms.Resize(
                [int(resize_to), int(int(resize_to) * 1.8)],
                interpolation=Image.NEAREST)
        ])

    transform.extend([net.module.transform])
    target_transform.extend([to_tensor_raw])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)

    if force_split:
        datasets = []
        datasets.append(
            get_dataset(dataset[0],
                        os.path.join(datadir, dataset[0]),
                        num_cls=num_cls,
                        transform=transform,
                        target_transform=target_transform,
                        data_flag=data_flag))
        datasets.append(
            get_dataset(dataset[1],
                        os.path.join(datadir, dataset[1]),
                        num_cls=num_cls,
                        transform=transform,
                        target_transform=target_transform))
    else:
        datasets = [
            get_dataset(name,
                        os.path.join(datadir, name),
                        num_cls=num_cls,
                        transform=transform,
                        target_transform=target_transform,
                        data_flag=data_flag) for name in dataset
        ]

    if weights is not None:
        weights = np.loadtxt(weights)

    if adam:
        print("Using Adam")
        opt = torch.optim.Adam(net.module.parameters(), lr=1e-4)
    else:
        print("Using SGD")
        opt = torch.optim.SGD(net.module.parameters(),
                              lr=lr,
                              momentum=momentum,
                              weight_decay=0.0005)

    if augmentation:
        collate_fn = lambda batch: augment_collate(
            batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    loaders = [
        torch.utils.data.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=not serial_batches,
                                    num_workers=nthreads,
                                    collate_fn=collate_fn,
                                    pin_memory=True) for dataset in datasets
    ]
    iteration = start_step
    losses = deque(maxlen=10)

    for loader in loaders:
        loader.dataset.__getitem__(0, debug=True)

    for im, label in roundrobin_infinite(*loaders):
        # Clear out gradients
        opt.zero_grad()

        # load data/label
        im = make_variable(im, requires_grad=False)
        label = make_variable(label, requires_grad=False)

        if iteration == 0:
            print("im size: {}".format(im.size()))
            print("label size: {}".format(label.size()))

        # forward pass and compute loss
        preds = net(im)
        loss = supervised_loss(preds, label)

        # backward pass
        loss.backward()
        losses.append(loss.item())

        # step gradients
        opt.step()

        # log results
        if iteration % 10 == 0:
            logging.info('Iteration {}:\t{}'.format(iteration,
                                                    np.mean(losses)))
            writer.add_scalar('loss', np.mean(losses), iteration)
        iteration += 1
        if step is not None and iteration % step == 0:
            logging.info('Decreasing learning rate by 0.1.')
            step_lr(opt, 0.1)

        if iteration % snapshot == 0:
            torch.save(net.module.state_dict(),
                       '{}/iter_{}.pth'.format(output, iteration))

        if iteration >= iterations:
            logging.info('Optimization complete.')
            break
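step_lr is also undefined in these examples. A minimal sketch, assuming it simply decays every parameter group's learning rate in place:

def step_lr(optimizer, factor):
    # Multiply each param group's learning rate by `factor`
    # (factor=0.1 matches the "Decreasing learning rate by 0.1" log line).
    for group in optimizer.param_groups:
        group['lr'] *= factor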
Example #5
def main(config_path):
    config = None

    config_file = config_path.split('/')[-1]
    version = config_file.split('.')[0][1:]

    with open(config_path, 'r') as f:
        config = json.load(f)

    config["version"] = version
    config_logging()

    # Initialize SummaryWriter - For tensorboard visualizations
    logdir = 'runs/{:s}/{:s}/{:s}/{:s}'.format(config["model"],
                                               config["dataset"],
                                               'v{}'.format(config["version"]),
                                               'tflogs')
    logdir = logdir + "/"

    checkpointdir = join('runs', config["model"], config["dataset"],
                         'v{}'.format(config["version"]), 'checkpoints')

    print("Logging directory: {}".format(logdir))
    print("Checkpoint directory: {}".format(checkpointdir))

    versionpath = join('runs', config["model"], config["dataset"],
                       'v{}'.format(config["version"]))

    if not exists(versionpath):
        os.makedirs(versionpath)
        os.makedirs(checkpointdir)
        os.makedirs(logdir)
    elif exists(versionpath) and config["force"]:
        shutil.rmtree(versionpath)
        os.makedirs(versionpath)
        os.makedirs(checkpointdir)
        os.makedirs(logdir)
    else:
        print(
            "Version {} already exists! Please run with different version number"
            .format(config["version"]))
        logging.info(
            "Version {} already exists! Please run with different version number"
            .format(config["version"]))
        sys.exit(-1)

    writer = SummaryWriter(logdir)
    # Get appropriate model based on config parameters
    net = get_model(config["model"], num_cls=config["num_cls"])
    # assumption: an optional "load" key in the config; the original
    # referenced an undefined `args` here
    if config.get("load"):
        net.load_state_dict(torch.load(config["load"]))
        print("============ Loading Model ===============")

    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    dataset = config["dataset"]
    num_workers = config["num_workers"]
    pin_memory = config["pin_memory"]
    dataset = dataset[0]

    datasets_train = get_fcn_dataset(config["dataset"],
                                     config["data_type"],
                                     join(config["datadir"],
                                          config["dataset"]),
                                     split='train')
    datasets_val = get_fcn_dataset(config["dataset"],
                                   config["data_type"],
                                   join(config["datadir"], config["dataset"]),
                                   split='val')
    datasets_test = get_fcn_dataset(config["dataset"],
                                    config["data_type"],
                                    join(config["datadir"], config["dataset"]),
                                    split='test')

    if config["weights"] is not None:
        weights = np.loadtxt(config["weights"])
    opt = torch.optim.SGD(net.parameters(),
                          lr=config["lr"],
                          momentum=config["momentum"],
                          weight_decay=0.0005)

    if config["augmentation"]:
        collate_fn = lambda batch: augment_collate(
            batch, crop=config["crop_size"], flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    train_loader = torch.utils.data.DataLoader(datasets_train,
                                               batch_size=config["batch_size"],
                                               shuffle=True,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn,
                                               pin_memory=pin_memory)

    # val_loader is used in the validation epoch below, so it must be built here
    val_loader = torch.utils.data.DataLoader(datasets_val,
                                             batch_size=config["batch_size"],
                                             shuffle=True,
                                             num_workers=num_workers,
                                             collate_fn=collate_fn,
                                             pin_memory=pin_memory)

    test_loader = torch.utils.data.DataLoader(datasets_test,
                                              batch_size=config["batch_size"],
                                              shuffle=False,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn,
                                              pin_memory=pin_memory)

    data_metric = {'train': None, 'val': None, 'test': None}
    Q_size = len(train_loader) / config["batch_size"]

    metrics = {'losses': list(), 'ious': list(), 'recalls': list()}

    # deepcopy, not copy: a shallow copy would share the metric lists across splits
    data_metric['train'] = deepcopy(metrics)
    data_metric['val'] = deepcopy(metrics)
    data_metric['test'] = deepcopy(metrics)
    num_cls = config["num_cls"]
    hist = np.zeros((num_cls, num_cls))
    iteration = 0

    for epoch in range(config["num_epochs"] + 1):
        if config["phase"] == 'train':
            net.train()
            iterator = tqdm(iter(train_loader))

            # Epoch train
            print("Train Epoch!")
            for im, label in iterator:
                if torch.isnan(im).any() or torch.isnan(label).any():
                    import pdb
                    pdb.set_trace()
                iteration += 1
                # Clear out gradients
                opt.zero_grad()
                # load data/label
                im = make_variable(im, requires_grad=False)
                label = make_variable(label, requires_grad=False)
                #print(im.size())

                # forward pass and compute loss
                preds = net(im)
                #score = preds.data
                #_, pred = torch.max(score, 1)

                #hist += fast_hist(label.cpu().numpy().flatten(), pred.cpu().numpy().flatten(),num_cls)

                #acc_overall, acc_percls, iu, fwIU = result_stats(hist)
                loss = supervised_loss(preds, label)
                # iou = jaccard_score(preds, label)
                precision, rc, fscore, support, iou = sklearnScores(
                    preds, label.type(torch.IntTensor))
                #print(acc_overall, np.nanmean(acc_percls), np.nanmean(iu), fwIU)
                # backward pass
                loss.backward()

                # TODO: this is a running average, not a true average; a true
                # average would be memory-intensive, so keep the running
                # average for the moment.
                data_metric['train']['losses'].append(loss.item())
                data_metric['train']['ious'].append(iou)
                data_metric['train']['recalls'].append(rc)
                # step gradients
                opt.step()

                # Train visualizations - each iteration
                if iteration % config["train_tf_interval"] == 0:
                    vizz = preprocess_viz(im, preds, label)
                    writer.add_scalar('train/loss', loss, iteration)
                    writer.add_scalar('train/IOU', iou, iteration)
                    writer.add_scalar('train/recall', rc, iteration)
                    imutil = vutils.make_grid(torch.from_numpy(vizz),
                                              nrow=3,
                                              normalize=True,
                                              scale_each=True)
                    writer.add_image('{}_image_data'.format('train'), imutil,
                                     iteration)

                iterator.set_description("TRAIN V: {} | Epoch: {}".format(
                    config["version"], epoch))
                iterator.refresh()

                if iteration % 20000 == 0:
                    torch.save(
                        net.state_dict(),
                        join(checkpointdir,
                             'iter_{}_{}.pth'.format(iteration, epoch)))

            # clean before test/val
            opt.zero_grad()

            # Train visualizations - per epoch
            vizz = preprocess_viz(im, preds, label)
            writer.add_scalar('trainepoch/loss',
                              np.mean(data_metric['train']['losses']),
                              global_step=epoch)
            writer.add_scalar('trainepoch/IOU',
                              np.mean(data_metric['train']['ious']),
                              global_step=epoch)
            writer.add_scalar('trainepoch/recall',
                              np.mean(data_metric['train']['recalls']),
                              global_step=epoch)
            imutil = vutils.make_grid(torch.from_numpy(vizz),
                                      nrow=3,
                                      normalize=True,
                                      scale_each=True)
            writer.add_image('{}_image_data'.format('trainepoch'),
                             imutil,
                             global_step=epoch)

            print("Loss :{}".format(np.mean(data_metric['train']['losses'])))
            print("IOU :{}".format(np.mean(data_metric['train']['ious'])))
            print("recall :{}".format(np.mean(
                data_metric['train']['recalls'])))

            if epoch % config["checkpoint_interval"] == 0:
                torch.save(net.state_dict(),
                           join(checkpointdir, 'iter{}.pth'.format(epoch)))

            # Train epoch done. Free up lists
            for key in data_metric['train'].keys():
                data_metric['train'][key] = list()

            if epoch % config["val_epoch_interval"] == 0:
                net.eval()
                print("Val_epoch!")
                iterator = tqdm(iter(val_loader))
                for im, label in iterator:
                    # load data/label
                    im = make_variable(im, requires_grad=False)
                    label = make_variable(label, requires_grad=False)

                    # forward pass and compute loss
                    preds = net(im)
                    loss = supervised_loss(preds, label)
                    precision, rc, fscore, support, iou = sklearnScores(
                        preds, label.type(torch.IntTensor))

                    data_metric['val']['losses'].append(loss.item())
                    data_metric['val']['ious'].append(iou)
                    data_metric['val']['recalls'].append(rc)

                    iterator.set_description("VAL V: {} | Epoch: {}".format(
                        config["version"], epoch))
                    iterator.refresh()

                # Val visualizations
                vizz = preprocess_viz(im, preds, label)
                writer.add_scalar('valepoch/loss',
                                  np.mean(data_metric['val']['losses']),
                                  global_step=epoch)
                writer.add_scalar('valepoch/IOU',
                                  np.mean(data_metric['val']['ious']),
                                  global_step=epoch)
                writer.add_scalar('valepoch/Recall',
                                  np.mean(data_metric['val']['recalls']),
                                  global_step=epoch)
                imutil = vutils.make_grid(torch.from_numpy(vizz),
                                          nrow=3,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('{}_image_data'.format('val'),
                                 imutil,
                                 global_step=epoch)

                # Val epoch done. Free up lists
                for key in data_metric['val'].keys():
                    data_metric['val'][key] = list()

            # Epoch Test
            if epoch % config["test_epoch_interval"] == 0:
                net.eval()
                print("Test_epoch!")
                iterator = tqdm(iter(test_loader))
                for im, label in iterator:
                    # load data/label
                    im = make_variable(im, requires_grad=False)
                    label = make_variable(label, requires_grad=False)

                    # forward pass and compute loss
                    preds = net(im)
                    loss = supervised_loss(preds, label)
                    precision, rc, fscore, support, iou = sklearnScores(
                        preds, label.type(torch.IntTensor))

                    data_metric['test']['losses'].append(loss.item())
                    data_metric['test']['ious'].append(iou)
                    data_metric['test']['recalls'].append(rc)

                    iterator.set_description("TEST V: {} | Epoch: {}".format(
                        config["version"], epoch))
                    iterator.refresh()

                # Test visualizations
                writer.add_scalar('testepoch/loss',
                                  np.mean(data_metric['test']['losses']),
                                  global_step=epoch)
                writer.add_scalar('testepoch/IOU',
                                  np.mean(data_metric['test']['ious']),
                                  global_step=epoch)
                writer.add_scalar('testepoch/Recall',
                                  np.mean(data_metric['test']['recalls']),
                                  global_step=epoch)

                # Test epoch done. Free up lists
                for key in data_metric['test'].keys():
                    data_metric['test'][key] = list()

            if config["step"] is not None and epoch % config["step"] == 0:
                logging.info('Decreasing learning rate by 0.1 factor')
                step_lr(opt, 0.1)

    logging.info('Optimization complete.')
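One pitfall this example originally had: building the per-split metric dicts with copy(metrics) shares the inner lists between train, val, and test, so appending to one split appends to all three. A small demonstration of why the deep copy used above matters:

from copy import copy, deepcopy

metrics = {'losses': []}

shallow = copy(metrics)
shallow['losses'].append(1.0)
print(metrics['losses'])   # [1.0] -- the shallow copy mutated the original list

deep = deepcopy(metrics)
deep['losses'].append(2.0)
print(metrics['losses'])   # still [1.0] -- the deep copy has its own list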