def main(path, dataset, datadir, model, gpu, num_cls):
    """Evaluate a segmentation model on the 'val' split and print its metrics.

    The network comes from ``get_model(model, num_cls=num_cls,
    weights_init=path)``. Every validation batch is predicted, folded into a
    ``num_cls x num_cls`` histogram with ``fast_hist``, and ``result_stats``
    turns that histogram into the running numbers shown on the progress bar
    and printed at the end (mIoU, fwIoU, pixel acc, mean per-class acc).

    Args:
        path: passed to ``get_model`` as ``weights_init``.
        dataset: dataset name handed to ``get_fcn_dataset``.
        datadir: dataset root handed to ``get_fcn_dataset``.
        model: model name handed to ``get_model``.
        gpu: value assigned to the ``CUDA_VISIBLE_DEVICES`` environment variable.
        num_cls: number of classes (size of the confusion histogram).

    Returns:
        None. Results are printed; nothing is evaluated for an empty loader.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    net = get_model(model, num_cls=num_cls, weights_init=path)
    net.eval()
    ds = get_fcn_dataset(dataset, datadir, split='val',
                         transform=net.transform, target_transform=to_tensor_raw)
    classes = ds.classes
    loader = torch.utils.data.DataLoader(ds, num_workers=8)

    errs = []  # never appended to in this function; printed for output compatibility
    hist = np.zeros((num_cls, num_cls))
    if len(loader) == 0:
        print('Empty data loader')
        return
    # Iterate the loader directly so tqdm knows the total number of batches.
    iterations = tqdm(loader)
    # Pure inference: without no_grad autograd records every forward pass and
    # keeps its activations alive even though backward is never called.
    with torch.no_grad():
        for im, label in iterations:
            im = Variable(im.cuda())
            score = net(im).data
            _, preds = torch.max(score, 1)
            hist += fast_hist(label.numpy().flatten(),
                              preds.cpu().numpy().flatten(),
                              num_cls)
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            iterations.set_postfix({'mIoU': ' {:0.2f}  fwIoU: {:0.2f} pixel acc: {:0.2f} per cls acc: {:0.2f}'.format(
                np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))})
    print()
    print(','.join(classes))
    print(fmt_array(iu))
    print(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))
    print()
    print('Errors:', errs)
Exemplo n.º 2
0
def main(path, dataset, datadir, model, gpu, num_cls):
    """Evaluate on the 'val' split and save a colourised prediction per image.

    For every validation sample the predicted class map is written to
    ``./image/<ds.ids[i]>.png``. Class ids 0-3 get a fixed RGB colour and every
    other id is drawn black. A running confusion histogram (``fast_hist``) feeds
    ``result_stats``, whose values are shown on the progress bar.

    Args:
        path: passed to ``get_model`` as ``weights_init``.
        dataset: dataset name handed to ``get_fcn_dataset``.
        datadir: dataset root handed to ``get_fcn_dataset``.
        model: model name handed to ``get_model``.
        gpu: value assigned to ``CUDA_VISIBLE_DEVICES``.
        num_cls: number of classes (size of the confusion histogram).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    net = get_model(model, num_cls=num_cls, weights_init=path)
    net.eval()
    ds = get_fcn_dataset(dataset,
                         datadir,
                         split='val',
                         transform=net.transform,
                         target_transform=to_tensor_raw)
    loader = torch.utils.data.DataLoader(ds, num_workers=8)

    # Row k is the RGB colour for predicted class id k; ids >= len(palette)
    # stay black.
    palette = np.array([(151, 126, 171),
                        (232, 250, 80),
                        (55, 181, 57),
                        (187, 70, 156)], dtype=np.uint8)
    out_dir = './image'

    errs = []  # never appended to in this function; printed for output compatibility
    hist = np.zeros((num_cls, num_cls))
    if len(loader) == 0:
        print('Empty data loader')
        return
    # Image.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    iterations = tqdm(enumerate(loader), total=len(loader))
    with torch.no_grad():
        for im_i, (im, label) in iterations:
            im = Variable(im.cuda())
            score = net(im).data
            _, preds = torch.max(score, 1)
            pred_np = preds.cpu().numpy()

            # The DataLoader uses its default batch_size of 1, so pred_np[0] is
            # this sample's (H, W) class map. Its size is taken from the
            # prediction, not hard-coded.
            class_map = pred_np[0]
            colored = np.zeros(class_map.shape + (3,), dtype=np.uint8)
            known = class_map < len(palette)
            colored[known] = palette[class_map[known]]
            Image.fromarray(colored).save(
                os.path.join(out_dir, ds.ids[im_i] + '.png'))

            hist += fast_hist(label.numpy().flatten(), pred_np.flatten(),
                              num_cls)
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            iterations.set_postfix({
                'mIoU':
                ' {:0.2f}  fwIoU: {:0.2f} pixel acc: {:0.2f} per cls acc: {:0.2f}'.
                format(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))
            })
    print('Errors:', errs)
def main(path, dataset, datadir, model, gpu, num_cls, max_samples=26):
    """Predict the first ``max_samples`` validation batches and save them.

    Predictions (argmax over the class dimension) are collected first and then
    passed one by one to ``save_seg_results(im, prediction, label, index)``.

    Args:
        path: checkpoint path. It is passed to ``get_model`` as
            ``weights_init`` and also loaded with ``load_state_dict``.
        dataset: dataset name handed to ``get_fcn_dataset``.
        datadir: dataset root handed to ``get_fcn_dataset``.
        model: model name handed to ``get_model``.
        gpu: value assigned to ``CUDA_VISIBLE_DEVICES``.
        num_cls: number of classes.
        max_samples: how many batches to process. The default of 26 matches
            the previously hard-coded limit.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    net = get_model(model, num_cls=num_cls, weights_init=path)
    # NOTE(review): the checkpoint is given to get_model *and* loaded again
    # here; kept because get_model's weights_init handling is not visible.
    net.load_state_dict(torch.load(path))
    net.eval()
    ds = get_fcn_dataset(dataset,
                         datadir,
                         split='val',
                         transform=net.transform,
                         target_transform=to_tensor_raw)
    loader = torch.utils.data.DataLoader(ds, num_workers=8)

    if len(loader) == 0:
        print('Empty data loader')
        return

    res = []
    with torch.no_grad():
        for sample_idx, (im, label) in enumerate(tqdm(loader)):
            # Check the limit before copying the batch to the GPU.
            if sample_idx >= max_samples:
                break
            im = Variable(im.cuda())
            source_out = torch.argmax(net(im).data, dim=1)
            res.append((im, source_out, label, sample_idx))

        for im, score, label, sample_idx in res:
            save_seg_results(im, score, label, sample_idx)
    '''
Exemplo n.º 4
0
def main(path, dataset, datadir, model, gpu, num_cls, batch_size, loadsize,
         finesize, max_batches=33):
    """Evaluate a (possibly multi-GPU) segmentation model and record mIoU.

    Prints per-class IoU, overall metrics from ``result_stats`` and a mean IoU
    that excludes the 'terrain', 'truck' and 'train' classes. It then stores
    ``[mIoU, reduced mIoU]`` under the checkpoint's file name in a
    ``result.json`` next to the checkpoint, merging with any existing file.

    Args:
        path: checkpoint path, passed to ``get_model`` as ``weights_init``.
        dataset, datadir: forwarded to ``get_fcn_dataset``.
        model: model name handed to ``get_model``.
        gpu: comma-separated GPU ids. The string is also assigned to
            ``CUDA_VISIBLE_DEVICES``; non-negative ids enable DataParallel.
        num_cls: number of classes.
        batch_size: DataLoader batch size.
        loadsize, finesize: when both are not None, images are resized to
            ``[loadsize, finesize * 1.8]`` before the model's own transforms.
        max_batches: stop after this many batches. The default of 33 matches
            the previously hard-coded cap; pass None to evaluate everything.

    Raises:
        ValueError: if ``max_batches`` is neither None nor a positive integer.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    if max_batches is not None and max_batches < 1:
        raise ValueError('max_batches must be a positive integer or None')
    loadSize = loadsize
    fineSize = finesize
    net = get_model(model, num_cls=num_cls, weights_init=path)

    gpu_ids = [gpu_id for gpu_id in map(int, gpu.split(',')) if gpu_id >= 0]

    # set gpu ids
    if gpu_ids:
        torch.cuda.set_device(gpu_ids[0])
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)

    net.eval()
    # DataParallel exposes the wrapped model as `.module`. Without GPU ids the
    # net is not wrapped, so use it directly instead of crashing.
    base_net = net.module if hasattr(net, 'module') else net

    # Both sizes must be given; `(a and b) is not None` misfires when loadSize
    # is a falsy non-None value such as 0.
    if loadSize is not None and fineSize is not None:
        print("Loading Center Crop DataLoader Transform")
        resize_hw = [int(loadSize), int(int(fineSize) * 1.8)]
        data_transform = torchvision.transforms.Compose([
            transforms.Resize(resize_hw, interpolation=Image.BICUBIC),
            base_net.transform.transforms[0],
            base_net.transform.transforms[1]])
        target_transform = torchvision.transforms.Compose([
            transforms.Resize(resize_hw, interpolation=Image.NEAREST),
            transforms.Lambda(lambda img: to_tensor_raw(img))])
    else:
        data_transform = base_net.transform
        target_transform = torchvision.transforms.Compose(
            [transforms.Lambda(lambda img: to_tensor_raw(img))])

    ds = get_fcn_dataset(dataset, datadir, num_cls=num_cls, split='val',
                         transform=data_transform,
                         target_transform=target_transform)
    classes = ds.classes

    loader = torch.utils.data.DataLoader(ds, num_workers=16,
                                         batch_size=batch_size)

    errs = []  # never appended to in this function; printed for output compatibility
    hist = np.zeros((num_cls, num_cls))
    if len(loader) == 0:
        print('Empty data loader')
        return
    iterations = tqdm(enumerate(loader))
    with torch.no_grad():
        for im_i, (im, label) in iterations:
            if im_i == 0:
                print(im.size())
                print(label.size())

            if max_batches is not None and im_i >= max_batches:
                break

            im = Variable(im.cuda())
            score = net(im).data
            _, preds = torch.max(score, 1)
            hist += fast_hist(label.numpy().flatten(),
                              preds.cpu().numpy().flatten(), num_cls)
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            iterations.set_postfix({'mIoU': ' {:0.2f}  fwIoU: {:0.2f} pixel acc: {:0.2f} per cls acc: {:0.2f}'.format(
                np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))})
    print()

    # Reduced class set: every class except these three. For the 19-class
    # label set this leaves the 16 classes the original divisor assumed.
    excluded = ('terrain', 'truck', 'train')
    synthia_ius = []
    for index, class_name in enumerate(classes):
        print(class_name, " {:0.1f}".format(iu[index]))
        if class_name not in excluded:
            synthia_ius.append(iu[index])
    synthia_miou = (sum(synthia_ius) / len(synthia_ius)
                    if synthia_ius else float('nan'))

    print(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))
    print("{} Class-Wise mIOU is {}".format(len(synthia_ius), synthia_miou))
    print('Errors:', errs)

    # Key the result by the checkpoint's file name, in its own directory.
    # (str.replace would also strip earlier occurrences of the file name.)
    parent_path, cur_path = os.path.split(path)
    results_dict_path = os.path.join(parent_path, 'result.json')
    results_dict = {cur_path: [float(np.nanmean(iu)), float(synthia_miou)]}

    if os.path.exists(results_dict_path):
        with open(results_dict_path, 'r') as fp:
            merged = json.load(fp)
        merged.update(results_dict)
        results_dict = merged

    with open(results_dict_path, 'w') as fp:
        json.dump(results_dict, fp)
Exemplo n.º 5
0
def main(path, dataset, data_type, datadir, model, num_cls, mode):
    """Score a checkpoint per batch and save its 10 best / 10 worst batches.

    Every batch (batch_size 16, fixed order) is scored with ``IoU`` and
    ``sklearnScores``, and running means are printed. Batches are ranked by
    IoU; the top and bottom ten are re-predicted and written with ``saveImg``
    into ``<checkpoint-stem>_best_10/`` and ``<checkpoint-stem>_worst_10/``
    next to the checkpoint.

    Args:
        path: checkpoint file loaded with ``load_state_dict``.
        dataset: dataset name; also the sub-directory of ``datadir``.
        data_type: forwarded to ``get_fcn_dataset``.
        datadir: root containing ``dataset``.
        model: model name handed to ``get_model``.
        num_cls: number of classes.
        mode: dataset split to evaluate.
    """
    net = get_model(model, num_cls=num_cls)
    net.load_state_dict(torch.load(path))
    net.eval()
    ds = get_fcn_dataset(dataset,
                         data_type,
                         os.path.join(datadir, dataset),
                         split=mode)

    # shuffle=False keeps batch order stable, so batch indices from the scoring
    # pass refer to the same batches in the saving pass.
    loader = torch.utils.data.DataLoader(
        ds,
        num_workers=0,
        batch_size=16,
        shuffle=False,
        pin_memory=True,
        collate_fn=torch.utils.data.dataloader.default_collate)

    ious = []
    recalls = []
    precisions = []
    fscores = []

    if len(loader) == 0:
        print('Empty data loader')
        return

    # Output prefix = checkpoint path up to the first '.' of its file name.
    # os.path.join avoids the former '/<name>' (filesystem root) result when
    # path has no directory component.
    ckpt_dir, ckpt_file = os.path.split(path)
    folderPath = os.path.join(ckpt_dir, ckpt_file.split('.')[0])
    worst_dir = folderPath + '_worst_10'
    best_dir = folderPath + '_best_10'
    os.makedirs(worst_dir, exist_ok=True)
    os.makedirs(best_dir, exist_ok=True)

    with torch.no_grad():
        for im, label in tqdm(iter(loader)):
            im = make_variable(im, requires_grad=False)
            label = make_variable(label, requires_grad=False)
            p = net(im)

            iou = IoU(p, label)
            # NOTE(review): sklearnScores is unpacked into 4 values here but
            # into 5 values elsewhere in this file; confirm its return arity.
            pr, rc, fs, _ = sklearnScores(p, label)

            ious.append(iou.item())
            recalls.append(rc)
            precisions.append(pr)
            fscores.append(fs)

            print("iou: ", np.mean(ious))
            print("recalls: ", np.mean(recalls))
            print("precision: ", np.mean(precisions))
            print("f1: ", np.mean(fscores))

        # Max, Min 10 (sets give O(1) membership tests in the second pass)
        order = np.argsort(ious)
        best = set(order[-10:].tolist())
        worst = set(order[:10].tolist())

        for i, (im, label) in enumerate(tqdm(iter(loader))):
            targets = []
            if i in best:
                targets.append(best_dir)
            if i in worst:
                targets.append(worst_dir)
            if not targets:
                continue

            # One forward pass even when a batch is in both the best and the
            # worst set (possible when there are fewer than 20 batches).
            im = make_variable(im, requires_grad=False)
            label = make_variable(label, requires_grad=False)
            score = net(im)
            for target_dir in targets:
                saveImg(im, label, score,
                        target_dir + "/img_" + str(i) + ".png")

    print("=" * 100 + "\niou: ", np.mean(ious))
Exemplo n.º 6
0
def main(path, dataset, datadir, model, gpu, num_cls):
    """Evaluate a CalibratorNet on the 'val' split and print its metrics.

    The net is built via ``get_model('CalibratorNet', ...)``. Each batch goes
    through ``net.calibrator_T``; that output is added to the input, the sum is
    clamped to [-3, 3] and fed to ``net.src_net``. Predictions accumulate in a
    confusion histogram summarised by ``result_stats``.

    Args:
        path: passed to ``get_model`` as ``weights_init``.
        dataset, datadir: forwarded to ``get_fcn_dataset``.
        model: source model name handed to ``get_model`` as ``model``.
        gpu: value assigned to ``CUDA_VISIBLE_DEVICES``.
        num_cls: number of classes.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu

    net = get_model('CalibratorNet',
                    model=model,
                    num_cls=num_cls,
                    weights_init=path,
                    cali_model='resnet_9blocks',
                    task='segmentation')
    net.eval()

    # src_net may be wrapped (e.g. DataParallel exposes `.module`).
    src = net.src_net
    transform = src.module.transform if hasattr(src, 'module') else src.transform

    ds = get_fcn_dataset(dataset,
                         datadir,
                         split='val',
                         transform=transform,
                         target_transform=to_tensor_raw)
    classes = ds.classes
    loader = torch.utils.data.DataLoader(ds, num_workers=8)

    errs = []  # never appended to in this function; printed for output compatibility
    hist = np.zeros((num_cls, num_cls))
    if len(loader) == 0:
        print('Empty data loader')
        return
    iterations = tqdm(loader)

    with torch.no_grad():
        for im, label in iterations:
            im = Variable(im.cuda())
            pert = net.calibrator_T(im)
            score = net.src_net(torch.clamp(im + pert, -3, 3)).data

            _, preds = torch.max(score, 1)
            hist += fast_hist(label.numpy().flatten(),
                              preds.cpu().numpy().flatten(), num_cls)
            acc_overall, acc_percls, iu, fwIU = result_stats(hist)
            iterations.set_postfix({
                'mIoU':
                ' {:0.2f}  fwIoU: {:0.2f} pixel acc: {:0.2f} per cls acc: {:0.2f}'
                .format(np.nanmean(iu), fwIU, acc_overall,
                        np.nanmean(acc_percls))
            })

    print()
    print(','.join(classes))
    print(fmt_array(iu))
    print(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))
    print()
    print('Errors:', errs)
Exemplo n.º 7
0
def _run_eval_epoch(net, loader, metric_store, tag, version, epoch):
    """Run one gradient-free pass of ``net`` over ``loader``.

    Per-batch loss, IoU and recall are appended to ``metric_store``, a dict
    with 'losses', 'ious' and 'recalls' lists. Returns the last
    ``(im, preds, label)`` batch for visualisation, or None if the loader
    yielded nothing.
    """
    net.eval()
    last_batch = None
    iterator = tqdm(iter(loader))
    with torch.no_grad():
        for im, label in iterator:
            # load data/label
            im = make_variable(im, requires_grad=False)
            label = make_variable(label, requires_grad=False)

            # forward pass and compute loss
            preds = net(im)
            loss = supervised_loss(preds, label)
            _, rc, _, _, iou = sklearnScores(preds,
                                             label.type(torch.IntTensor))

            metric_store['losses'].append(loss.item())
            metric_store['ious'].append(iou)
            metric_store['recalls'].append(rc)

            iterator.set_description("{} V: {} | Epoch: {}".format(
                tag, version, epoch))
            iterator.refresh()
            last_batch = (im, preds, label)
    return last_batch


def main(config_path):
    """Train a segmentation model described by a JSON config file.

    The version is the config file name up to its first '.', minus the first
    character (e.g. ``v3.json`` -> ``3``). Logs, TensorBoard files and
    checkpoints go under ``runs/<model>/<dataset>/v<version>/``. An existing
    version directory is removed only when ``config["force"]`` is set;
    otherwise the process exits with -1.

    Reads the module-level ``args.load``: when set, that checkpoint is loaded
    before training.
    """
    config_file = os.path.basename(config_path)
    version = config_file.split('.')[0][1:]

    with open(config_path, 'r') as f:
        config = json.load(f)

    config["version"] = version
    config_logging()

    version_tag = 'v{}'.format(config["version"])

    # Initialize SummaryWriter - For tensorboard visualizations
    logdir = 'runs/{:s}/{:s}/{:s}/{:s}'.format(config["model"],
                                               config["dataset"],
                                               version_tag,
                                               'tflogs')
    logdir = logdir + "/"

    checkpointdir = join('runs', config["model"], config["dataset"],
                         version_tag, 'checkpoints')

    print("Logging directory: {}".format(logdir))
    print("Checkpoint directory: {}".format(checkpointdir))

    versionpath = join('runs', config["model"], config["dataset"], version_tag)

    if exists(versionpath):
        if not config["force"]:
            msg = ("Version {} already exists! Please run with different version number"
                   .format(config["version"]))
            print(msg)
            logging.info(msg)
            sys.exit(-1)
        shutil.rmtree(versionpath)
    os.makedirs(versionpath)
    os.makedirs(checkpointdir)
    os.makedirs(logdir)

    writer = SummaryWriter(logdir)
    # Get appropriate model based on config parameters
    net = get_model(config["model"], num_cls=config["num_cls"])
    if args.load:
        net.load_state_dict(torch.load(args.load))
        print("============ Loading Model ===============")

    num_params = sum(np.prod(p.size()) for p in net.parameters()
                     if p.requires_grad)
    logging.info("Trainable parameters: %d", num_params)

    num_workers = config["num_workers"]
    pin_memory = config["pin_memory"]
    data_root = join(config["datadir"], config["dataset"])

    datasets_train = get_fcn_dataset(config["dataset"], config["data_type"],
                                     data_root, split='train')
    datasets_val = get_fcn_dataset(config["dataset"], config["data_type"],
                                   data_root, split='val')
    datasets_test = get_fcn_dataset(config["dataset"], config["data_type"],
                                    data_root, split='test')

    opt = torch.optim.SGD(net.parameters(),
                          lr=config["lr"],
                          momentum=config["momentum"],
                          weight_decay=0.0005)

    if config["augmentation"]:
        collate_fn = lambda batch: augment_collate(
            batch, crop=config["crop_size"], flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate

    train_loader = torch.utils.data.DataLoader(datasets_train,
                                               batch_size=config["batch_size"],
                                               shuffle=True,
                                               num_workers=num_workers,
                                               collate_fn=collate_fn,
                                               pin_memory=pin_memory)

    # NOTE(review): with config["augmentation"] the random crop/flip collate is
    # also used by the evaluation loaders (as the test loader always did);
    # confirm that is intended.
    val_loader = torch.utils.data.DataLoader(datasets_val,
                                             batch_size=config["batch_size"],
                                             shuffle=False,
                                             num_workers=num_workers,
                                             collate_fn=collate_fn,
                                             pin_memory=pin_memory)

    test_loader = torch.utils.data.DataLoader(datasets_test,
                                              batch_size=config["batch_size"],
                                              shuffle=False,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn,
                                              pin_memory=pin_memory)

    # Each split gets its own independent lists; a shallow copy of one
    # template dict would make train/val/test share (and pollute) them.
    metric_keys = ('losses', 'ious', 'recalls')
    data_metric = {split: {key: [] for key in metric_keys}
                   for split in ('train', 'val', 'test')}
    iteration = 0

    for epoch in range(config["num_epochs"] + 1):
        if config["phase"] != 'train':
            continue

        net.train()
        iterator = tqdm(iter(train_loader))
        train_metrics = data_metric['train']

        # Epoch train
        print("Train Epoch!")
        for im, label in iterator:
            if torch.isnan(im).any() or torch.isnan(label).any():
                import pdb
                pdb.set_trace()
            iteration += 1
            # Clear out gradients
            opt.zero_grad()
            # load data/label
            im = make_variable(im, requires_grad=False)
            label = make_variable(label, requires_grad=False)

            # forward pass and compute loss
            preds = net(im)
            loss = supervised_loss(preds, label)
            _, rc, _, _, iou = sklearnScores(preds,
                                             label.type(torch.IntTensor))
            # backward pass
            loss.backward()

            # Running per-batch values; averaged at the end of the epoch.
            train_metrics['losses'].append(loss.item())
            train_metrics['ious'].append(iou)
            train_metrics['recalls'].append(rc)
            # step gradients
            opt.step()

            # Train visualizations - each iteration
            if iteration % config["train_tf_interval"] == 0:
                vizz = preprocess_viz(im, preds, label)
                writer.add_scalar('train/loss', loss, iteration)
                writer.add_scalar('train/IOU', iou, iteration)
                writer.add_scalar('train/recall', rc, iteration)
                imutil = vutils.make_grid(torch.from_numpy(vizz),
                                          nrow=3,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('{}_image_data'.format('train'), imutil,
                                 iteration)

            iterator.set_description("TRAIN V: {} | Epoch: {}".format(
                config["version"], epoch))
            iterator.refresh()

            if iteration % 20000 == 0:
                torch.save(
                    net.state_dict(),
                    join(checkpointdir,
                         'iter_{}_{}.pth'.format(iteration, epoch)))

        # clean before test/val
        opt.zero_grad()

        # Train visualizations - per epoch (uses the last training batch)
        vizz = preprocess_viz(im, preds, label)
        writer.add_scalar('trainepoch/loss',
                          np.mean(train_metrics['losses']),
                          global_step=epoch)
        writer.add_scalar('trainepoch/IOU',
                          np.mean(train_metrics['ious']),
                          global_step=epoch)
        writer.add_scalar('trainepoch/recall',
                          np.mean(train_metrics['recalls']),
                          global_step=epoch)
        imutil = vutils.make_grid(torch.from_numpy(vizz),
                                  nrow=3,
                                  normalize=True,
                                  scale_each=True)
        writer.add_image('{}_image_data'.format('trainepoch'),
                         imutil,
                         global_step=epoch)

        print("Loss :{}".format(np.mean(train_metrics['losses'])))
        print("IOU :{}".format(np.mean(train_metrics['ious'])))
        print("recall :{}".format(np.mean(train_metrics['recalls'])))

        if epoch % config["checkpoint_interval"] == 0:
            torch.save(net.state_dict(),
                       join(checkpointdir, 'iter{}.pth'.format(epoch)))

        # Train epoch done. Free up lists
        data_metric['train'] = {key: [] for key in metric_keys}

        if epoch % config["val_epoch_interval"] == 0:
            print("Val_epoch!")
            last_val = _run_eval_epoch(net, val_loader, data_metric['val'],
                                       "VAL", config["version"], epoch)

            # Val visualizations
            writer.add_scalar('valepoch/loss',
                              np.mean(data_metric['val']['losses']),
                              global_step=epoch)
            writer.add_scalar('valepoch/IOU',
                              np.mean(data_metric['val']['ious']),
                              global_step=epoch)
            writer.add_scalar('valepoch/Recall',
                              np.mean(data_metric['val']['recalls']),
                              global_step=epoch)
            if last_val is not None:
                vizz = preprocess_viz(*last_val)
                imutil = vutils.make_grid(torch.from_numpy(vizz),
                                          nrow=3,
                                          normalize=True,
                                          scale_each=True)
                writer.add_image('{}_image_data'.format('val'),
                                 imutil,
                                 global_step=epoch)

            # Val epoch done. Free up lists
            data_metric['val'] = {key: [] for key in metric_keys}

        # Epoch Test
        if epoch % config["test_epoch_interval"] == 0:
            print("Test_epoch!")
            _run_eval_epoch(net, test_loader, data_metric['test'], "TEST",
                            config["version"], epoch)

            # Test visualizations
            writer.add_scalar('testepoch/loss',
                              np.mean(data_metric['test']['losses']),
                              global_step=epoch)
            writer.add_scalar('testepoch/IOU',
                              np.mean(data_metric['test']['ious']),
                              global_step=epoch)
            writer.add_scalar('testepoch/Recall',
                              np.mean(data_metric['test']['recalls']),
                              global_step=epoch)

            # Test epoch done. Free up lists
            data_metric['test'] = {key: [] for key in metric_keys}

        # NOTE(review): epoch 0 also satisfies this condition, so the first
        # decay happens after the very first epoch; confirm that is intended.
        if config["step"] is not None and epoch % config["step"] == 0:
            logging.info('Decreasing learning rate by 0.1 factor')
            step_lr(opt, 0.1)

    # Flush pending TensorBoard events and release the writer's resources.
    writer.close()
    logging.info('Optimization complete.')