Example #1
def main():
    device = torch.device("cuda:0")
    num_datapoints = 1984
    views = 10
    load_points = True
    prepareDir(vis_directory)
    # TODO: replace "sundermeyer-random" with a sampling method that accepts
    # predetermined poses, so that each arc can be plotted for visualization
    datagen = DatasetGenerator("",
                            "./data/cad-files/ply-files/obj_10.ply",
                            375,
                            num_datapoints,
                            "not_used",
                            device,
                            "sundermeyer-random",
                            random_light=False)

    # load model
    encoder = Encoder("./data/obj1-18/encoder.npy").to(device)
    encoder.eval()
    checkpoint = torch.load("./output/paper-models/10views/obj10/models/model-epoch199.pt")

    model = Model(num_views=views).cuda()
    model.load_state_dict(checkpoint['model'])
    model = model.eval()
    pipeline = Pipeline(encoder, model, device)

    # around x
    shiftx = np.eye(3, dtype=float)  # np.float was removed in NumPy >= 1.24
    theta = np.pi / num_datapoints
    #theta = np.pi / 3
    shiftx[1,1] = np.cos(theta)
    shiftx[1,2] = -np.sin(theta)
    shiftx[2,2] = np.cos(theta)
    shiftx[2,1] = np.sin(theta)
    # around y
    shifty = np.eye(3, dtype=float)
    theta = np.pi / num_datapoints
    shifty[0,0] = np.cos(theta)
    shifty[0,2] = -np.sin(theta)
    shifty[2,2] = np.cos(theta)
    shifty[2,0] = np.sin(theta)
    # around z
    shiftz = np.eye(3, dtype=float)
    theta = np.pi / num_datapoints
    shiftz[0,0] = np.cos(theta)
    shiftz[0,1] = -np.sin(theta)
    shiftz[1,1] = np.cos(theta)
    shiftz[1,0] = np.sin(theta)
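    # Optional sanity check (not part of the original script): each shift
    # matrix should be a proper rotation, i.e. orthogonal with determinant +1.
    for name, m in (("shiftx", shiftx), ("shifty", shifty), ("shiftz", shiftz)):
        assert np.allclose(m @ m.T, np.eye(3)), name
        assert np.isclose(np.linalg.det(m), 1.0), name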
    predicted_poses = []
    predicted_poses_raw = []
    R_conv = np.eye(3, dtype=float)
    #R_conv = np.array([[ 0.5435,  0.1365,  0.8283],
    #                   [ 0.6597,  0.5406, -0.5220],
    #                   [-0.5190,  0.8301,  0.2037]])
    #R_conv = np.array([[-0.7132,  0.0407,  0.6998],
    #                   [ 0.1696, -0.9586,  0.2287],
    #                   [ 0.6802,  0.2818,  0.6767]])
    #R_conv = np.array([[-0.9959,  0.0797,  0.0423],
    #                   [ 0.0444,  0.0249,  0.9987],
    #                   [ 0.0786,  0.9965, -0.0283]])

    if load_points:
        # Try with points from the sphere
        points = np.load('./output/depth/spherical_mapping_obj10_1_500/points.npy', allow_pickle=True)
        num_datapoints = len(points)

        Rin = []
        for point in points:
            Rin.append(pointToMat(point))
            #Rin.append(np.matmul(pointToMat(point), shiftx))

    else:
        Rin = []
        for i in range(num_datapoints):
            # get data from fixed R and T vectors
            R_conv = np.matmul(R_conv, shiftx)
            R = torch.from_numpy(R_conv)
            Rin.append(R)

    t = torch.tensor([0.0, 0.0, 375])
    # Generate images
    data = datagen.generate_image_batch(Rin = Rin, tin = t, augment = False)

    # run images through model
    # Predict poses
    output = pipeline.process(data["images"])

    # evaluate how output confidence and each view changes with input pose
    plot_confidences(output.detach().cpu().numpy())
    if load_points:
        plot_flat_landscape(points, output[:,0:views].detach().cpu().numpy())

    rotation_matrices = []
    for i in range(views):
        start = views + i*6
        end = views + (i + 1)*6
        curr_poses = output[:,start:end]
        matrices = compute_rotation_matrix_from_ortho6d(curr_poses)
        euler_angles = compute_euler_angles_from_rotation_matrices(matrices)
        print(matrices.shape)
        print(matrices[0:3])
        print(euler_angles.shape)
        print(euler_angles[0:3])
        exit()
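
The pointToMat helper called above is not included in this listing. A minimal
sketch of one plausible implementation, assuming each point is a 3D unit
vector and the returned rotation should orient the camera's z-axis toward
that point (an assumption, not the project's actual code):

import numpy as np

def pointToMat(point):
    # HYPOTHETICAL sketch: build a rotation whose third row (z-axis) points
    # at `point` on the unit sphere; the project's real helper may differ.
    z = np.asarray(point, dtype=np.float64)
    z = z / np.linalg.norm(z)
    # pick a reference vector not parallel to z, then build an orthonormal basis
    up = np.array([0.0, 0.0, 1.0]) if abs(z[2]) < 0.999 else np.array([1.0, 0.0, 0.0])
    x = np.cross(up, z)
    x = x / np.linalg.norm(x)
    y = np.cross(z, x)
    return np.stack([x, y, z], axis=0)  # orthonormal rows, det = +1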
Example #2
dataloader_eval = DataLoader(torch_dataset,
                             batch_size=mini_batch_size,
                             shuffle=False,
                             num_workers=0)

# determine if CUDA is available at the compute node
if (torch.backends.cudnn.version() is not None) and USE_CUDA:

    # rebuild the dataloader with the dataset moved to CUDA
    # (this assumes torch_dataset is a tensor exposing .cuda())
    dataloader_eval = DataLoader(torch_dataset.cuda(),
                                 batch_size=mini_batch_size,
                                 shuffle=False)

# VISUALIZE LATENT SPACE REPRESENTATION
# set networks in evaluation mode (don't apply dropout)
encoder_eval.eval()
decoder_eval.eval()

# init batch count
batch_count = 0

# iterate over epoch mini batches
for enc_transactions_batch in dataloader_eval:

    # determine latent space representation of all transactions
    z_enc_transactions_batch = encoder_eval(enc_transactions_batch)

    # case: initial batch
    if batch_count == 0:

        # collect latent space representations of batch
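        # (HYPOTHETICAL continuation: the listing is cut off at this point;
        #  the usual collect-then-concatenate pattern would proceed as below.)
        z_enc_transactions_all = z_enc_transactions_batch

    # case: non-initial batch
    else:

        # concatenate batch representations
        z_enc_transactions_all = torch.cat(
            (z_enc_transactions_all, z_enc_transactions_batch), dim=0)

    # increment batch count
    batch_count += 1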
Example #3
def main():
    global optimizer, lr_reducer, views, epoch, pipeline
    # Read configuration file
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    arguments = parser.parse_args()

    cfg_file_path = os.path.join("./experiments", arguments.experiment_name)
    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    seed = args.getint('Training', 'RANDOM_SEED', fallback=None)  # fallback needed for the None check below
    if(seed is not None):
        torch.manual_seed(seed)
        #torch.use_deterministic_algorithms(True) # Requires pytorch>=1.8.0
        #torch.backends.cudnn.deterministic = True
        np.random.seed(seed=seed)
        ia.seed(seed)
        random.seed(seed)

    model_seed = args.getint('Training', 'MODEL_RANDOM_SEED', fallback=None)
    if(model_seed is not None):
        torch.manual_seed(model_seed)

    # Prepare rotation matrices for multi view loss function
    eulerViews = json.loads(args.get('Rendering', 'VIEWS'))
    views = prepareViews(eulerViews)

    # Set the cuda device
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Handle loading of multiple object paths
    try:
        model_path_loss = json.loads(args.get('Dataset', 'MODEL_PATH_LOSS'))
    except ValueError:  # the option holds a plain path, not a JSON list
        model_path_loss = [args.get('Dataset', 'MODEL_PATH_LOSS')]

    # Set up batch renderer
    br = BatchRender(model_path_loss,
                     device,
                     batch_size=args.getint('Training', 'BATCH_SIZE'),
                     faces_per_pixel=args.getint('Rendering', 'FACES_PER_PIXEL'),
                     render_method=args.get('Rendering', 'SHADER'),
                     image_size=args.getint('Rendering', 'IMAGE_SIZE'),
                     norm_verts=args.getboolean('Rendering', 'NORMALIZE_VERTICES'))

    # Set size of model output depending on pose representation - deprecated?
    pose_rep = args.get('Training', 'POSE_REPRESENTATION')
    if(pose_rep == '6d-pose'):
        pose_dim = 6
    elif(pose_rep == 'quat'):
        pose_dim = 4
    elif(pose_rep == 'axis-angle'):
        pose_dim = 4
    elif(pose_rep == 'euler'):
        pose_dim = 3
    else:
        print("Unknown pose representation specified: ", pose_rep)
        pose_dim = -1
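    # (Equivalently: pose_dim = {'6d-pose': 6, 'quat': 4, 'axis-angle': 4,
    #                            'euler': 3}.get(pose_rep, -1), with -1
    #  signalling an unknown representation.)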

    # Initialize a model using the renderer, mesh and reference image
    model = Model(num_views=len(views),
                  weight_init_name=args.get('Training', 'WEIGHT_INIT_NAME', fallback=""))
    model.to(device)

    # Create an optimizer. Here we are using Adam and we pass in the parameters of the model
    low_lr = args.getfloat('Training', 'LEARNING_RATE_LOW')
    high_lr = args.getfloat('Training', 'LEARNING_RATE_HIGH')
    optimizer = torch.optim.Adam(model.parameters(), lr=low_lr)
    lr_reducer = OneCycleLR(optimizer, num_steps=args.getfloat('Training', 'NUM_ITER'), lr_range=(low_lr, high_lr))

    # Prepare output directories
    output_path = args.get('Training', 'OUTPUT_PATH')
    prepareDir(output_path)
    shutil.copy(cfg_file_path, os.path.join(output_path, cfg_file_path.split("/")[-1]))

    # Setup early stopping if enabled
    early_stopping = args.getboolean('Training', 'EARLY_STOPPING', fallback=False)
    if early_stopping:
        window = args.getint('Training', 'STOPPING_WINDOW', fallback=10)
        time_limit = args.getint('Training', 'STOPPING_TIME_LIMIT', fallback=10)
        window_means = []
        lowest_mean = np.inf
        lowest_x = 0
        timer = 0

    # Load checkpoint for last epoch if it exists
    model_path = latestCheckpoint(os.path.join(output_path, "models/"))
    if(model_path is not None):
        model, optimizer, epoch, lr_reducer = loadCheckpoint(model_path)

    if early_stopping:
        validation_csv=os.path.join(output_path, "validation-loss.csv")
        if os.path.exists(validation_csv):
            with open(validation_csv) as f:
                val_reader = csv.reader(f, delimiter='\n')
                val_loss = list(val_reader)
            val_losses = np.array(val_loss, dtype=np.float32).flatten()
            for epoch in range(window, len(val_losses)):
                timer += 1
                w_mean = np.mean(val_losses[epoch-window:epoch])
                window_means.append(w_mean)
                if w_mean < lowest_mean:
                    lowest_mean = w_mean
                    lowest_x = epoch
                    timer = 0


    # Prepare pipeline
    encoder = Encoder(args.get('Dataset', 'ENCODER_WEIGHTS')).to(device)
    encoder.eval()
    pipeline = Pipeline(encoder, model, device)

    # Handle loading of multiple object paths and translations
    try:
        model_path_data = json.loads(args.get('Dataset', 'MODEL_PATH_DATA'))
        translations = np.array(json.loads(args.get('Rendering', 'T')))
    except ValueError:  # plain path / single translation, not JSON lists
        model_path_data = [args.get('Dataset', 'MODEL_PATH_DATA')]
        translations = [np.array(json.loads(args.get('Rendering', 'T')))]

    # Prepare datasets
    bg_path = "../../autoencoder_ws/data/VOC2012/JPEGImages/"
    training_data = DatasetGenerator(args.get('Dataset', 'BACKGROUND_IMAGES'),
                                     model_path_data,
                                     translations,
                                     args.getint('Training', 'BATCH_SIZE'),
                                     "not_used",
                                     device,
                                     sampling_method = args.get('Training', 'VIEW_SAMPLING'),
                                     max_rel_offset = args.getfloat('Training', 'MAX_REL_OFFSET', fallback=0.2),
                                     augment_imgs = args.getboolean('Training', 'AUGMENT_IMGS', fallback=True),
                                     seed=args.getint('Training', 'RANDOM_SEED'))
    training_data.max_samples = args.getint('Training', 'NUM_SAMPLES')

    # Load the validationset
    validation_data = loadDataset(json.loads(args.get('Dataset', 'VALID_DATA_PATH')),
                                  args.getint('Training', 'BATCH_SIZE'))
    print("Loaded validation set!")

    # Start training
    while(epoch < args.getint('Training', 'NUM_ITER')):
        # Train on synthetic data
        model = model.train() # Set model to train mode
        loss = runEpoch(br, training_data, model, device, output_path,
                          t=translations, config=args)
        append2file([loss], os.path.join(output_path, "train-loss.csv"))
        append2file([lr_reducer.get_lr()], os.path.join(output_path, "learning-rate.csv"))

        # Test on validation data
        model = model.eval() # Set model to eval mode
        val_loss = runEpoch(br, validation_data, model, device, output_path,
                          t=translations, config=args)
        append2file([val_loss], os.path.join(output_path, "validation-loss.csv"))

        # Plot losses
        val_losses = plotLoss(os.path.join(output_path, "train-loss.csv"),
                 os.path.join(output_path, "train-loss.png"),
                 validation_csv=os.path.join(output_path, "validation-loss.csv"))
        print("-"*20)
        print("Epoch: {0} - train loss: {1} - validation loss: {2}".format(epoch,loss,val_loss))
        print("-"*20)
        if early_stopping and epoch >= window:
            timer += 1
            if timer > time_limit:
                print()
                print("-"*60)
                print("Validation loss seems to have plateaued, stopping early.")
                print("Best mean loss value over an epoch window of size {} was found at epoch {} ({:.8f} mean loss)".format(window, lowest_x, lowest_mean))
                print("-"*60)
                break
            w_mean = np.mean(val_losses[epoch-window:epoch])
            window_means.append(w_mean)
            if w_mean < lowest_mean:
                lowest_mean = w_mean
                lowest_x = epoch
                timer = 0
        epoch = epoch+1
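
The early-stopping criterion above appears twice: once replaying the
validation CSV after a checkpoint reload, and once live during training. For
readability, the same sliding-window logic as a standalone function (an
illustrative refactoring, not part of the original script):

import numpy as np

def window_has_plateaued(val_losses, window, time_limit):
    """True once the windowed mean loss hasn't improved for `time_limit` epochs."""
    lowest_mean, timer = np.inf, 0
    for epoch in range(window, len(val_losses)):
        timer += 1
        w_mean = np.mean(val_losses[epoch - window:epoch])
        if w_mean < lowest_mean:
            lowest_mean, timer = w_mean, 0
        if timer > time_limit:
            return True
    return False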
Example #4
def main():
    visualize = True

    # Read configuration file
    parser = argparse.ArgumentParser()
    parser.add_argument("-mp", help="path to the model checkpoint")
    parser.add_argument("-ep", help="path to the encoder weights")
    parser.add_argument("-pi", help="path to the pickle input file")
    parser.add_argument("-op",
                        help="path to the CAD model for the object",
                        default=None)
    parser.add_argument("-o", help="output path", default="./output.csv")
    args = parser.parse_args()

    # Load dataset
    data = pickle.load(open(args.pi, "rb"), encoding="latin1")

    # Prepare our model if needed
    if ("Rs_predicted" not in data):

        # Set the cuda device
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)

        # Initialize a model
        model = Model().to(device)

        # Load model checkpoint
        model, optimizer, epoch, learning_rate = loadCheckpoint(
            args.mp, device)
        model.to(device)
        model.eval()

        # Load and prepare encoder
        encoder = Encoder(args.ep).to(device)
        encoder.eval()

        # Setup the pipeline
        pipeline = Pipeline(encoder, model, device)

    # Prepare renderer if defined
    obj_path = args.op
    if (obj_path is not None):
        obj_model = inout.load_ply(obj_path.replace(".obj", ".ply"))
        img_size = 128
        K = np.array([
            1075.65091572, 0.0, 128.0 / 2.0, 0.0, 1073.90347929, 128.0 / 2.0,
            0.0, 0.0, 1.0
        ]).reshape(3, 3)
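        # (K is a standard pinhole intrinsic matrix: focal lengths on the
        #  diagonal, principal point at the image center, 128 / 2 = 64.)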
        renderer = Renderer(obj_model, (img_size, img_size),
                            K,
                            surf_color=(1, 1, 1),
                            mode='rgb',
                            random_light=False)
    else:
        renderer = None

    # Store results in a dict
    results = {
        "scene_id": [],
        "im_id": [],
        "obj_id": [],
        "score": [],
        "R": [],
        "t": [],
        "time": []
    }

    # Loop through dataset
    for i, img in enumerate(data["images"]):
        print("Current image: {0}/{1}".format(i + 1, len(data["images"])))

        if ("Rs_predicted" in data):
            R_predicted = data["Rs_predicted"][i]
        else:

            # Run through model
            predicted_poses = pipeline.process([img])

            # Find best pose
            num_views = predicted_poses.shape[1] // (6 + 1)  # 1 confidence + 6 pose values per view
            pose_start = num_views
            pose_end = pose_start + 6
            best_pose = 0.0
            R_predicted = None

            for k in range(num_views):
                # Extract current pose and move to next one
                curr_pose = predicted_poses[:, pose_start:pose_end]
                print(curr_pose)
                Rs_predicted = compute_rotation_matrix_from_ortho6d(curr_pose)
                Rs_predicted = Rs_predicted.detach().cpu().numpy()[0]
                pose_start = pose_end
                pose_end = pose_start + 6

                conf = predicted_poses[:, k].detach().cpu().numpy()[0]
                if (conf > best_pose):
                    R_predicted = Rs_predicted
                    best_pose = conf

            # Invert xy axes
            xy_flip = np.eye(3, dtype=float)
            xy_flip[0, 0] = -1.0
            xy_flip[1, 1] = -1.0
            R_predicted = R_predicted.dot(xy_flip)

            # Inverse rotation matrix
            R_predicted = np.transpose(R_predicted)

        results["scene_id"].append(data["scene_ids"][i])
        results["im_id"].append(data["img_ids"][i])
        results["obj_id"].append(data["obj_ids"][i])
        results["score"].append(-1)
        results["R"].append(arr2str(R_predicted))
        results["t"].append(arr2str(data["ts"][i]))
        results["time"].append(-1)

        if (renderer is None):
            visualize = False

        if (visualize):
            t_gt = np.array(data["ts"][i])
            t = np.array([0, 0, t_gt[2]])

            # Render predicted pose
            R_predicted = correct_trans_offset(R_predicted, t_gt)
            ren_predicted = renderer.render(R_predicted, t)

            # Render groundtruth pose
            R_gt = data["Rs"][i]
            R_gt = correct_trans_offset(R_gt, t_gt)
            ren_gt = renderer.render(R_gt, t)

            cv2.imshow("gt render", np.flip(ren_gt, axis=2))
            cv2.imshow("predict render", np.flip(ren_predicted, axis=2))

            cv2.imshow("input image", np.flip(img, axis=2))
            if ("codebook_images" in data):
                cv2.imshow("codebook image",
                           np.flip(data["codebook_images"][i], axis=2))

            print(ren_gt.shape)
            print(ren_predicted.shape)
            print(img.shape)
            numpy_horizontal_concat = np.concatenate(
                (np.flip(ren_gt, axis=2), np.flip(
                    ren_predicted, axis=2), np.flip(img, axis=2)),
                axis=1)
            cv2.imshow("gt - prediction - input", numpy_horizontal_concat)
            key = cv2.waitKey(0)
            if (key == ord("q")):
                exit()
                visualize = False
                #break
                continue

    # Save to CSV
    output_path = args.o
    print("Saving to: ", output_path)
    with open(output_path, "w") as f:
        col_names = list(results.keys())
        w = csv.DictWriter(f, results.keys())
        w.writeheader()
        num_lines = len(results[col_names[0]])

        for i in np.arange(num_lines):
            row_dict = {}
            for c in col_names:
                row_dict[c] = results[c][i]
            w.writerow(row_dict)
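
Both Example #1 and Example #4 slice the network output the same way: for n
views, the first n columns are per-view confidences and the remaining 6*n
columns hold one 6D rotation parameterization per view. The slicing as a
small standalone helper (illustrative only; the examples keep it inline):

def split_model_output(predicted_poses):
    # predicted_poses: tensor of shape (batch, n + 6*n)
    num_views = predicted_poses.shape[1] // (6 + 1)
    confidences = predicted_poses[:, :num_views]
    poses_6d = [predicted_poses[:, num_views + 6 * k:num_views + 6 * (k + 1)]
                for k in range(num_views)]
    return confidences, poses_6d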
Example #5
def main():
    device = torch.device("cuda:0")

    datagen = DatasetGenerator("./data/VOC2012/JPEGImages/",
                               ["./data/cad-files/ply-files/obj_10.ply"],
                               [[0.0, 0.0, 375.0]],
                               batch_size,
                               "not_used",
                               device,
                               "sundermeyer-random",
                               random_light=True,
                               num_bgs=200)
    datagen.max_rel_offset = 0.0
    datagen.pad_factor = 3.0
    datagen.img_size = 256
    datagen.backgrounds = datagen.load_bg_images("backgrounds",
                                                 "./data/VOC2012/JPEGImages/",
                                                 200, 256, 256)
    # create output directory
    if not os.path.exists('embedding-output/'):
        os.makedirs('embedding-output/')

    # load model
    encoder = Encoder("./data/obj1-18/encoder.npy").to(device)
    encoder.eval()

    shifts = [
        (0.0, 0.0),  #0
        (-1.0, -1.0),  #1
        (0.0, -1.0),  #2
        (1.0, -1.0),  #3
        (1.0, 0.0),  #4
        (1.0, 1.0),  #5
        (0.0, 1.0),  #6
        (-1.0, 1.0),  #7
        (-1.0, 0.0)
    ]  #8

    Rin = torch.from_numpy(np.vstack([np.eye(3)] * batch_size))
    tin = torch.from_numpy(np.array([0.0, 0.0, 375.0]))

    print(Rin.shape)

    labels = []
    codes = []
    #np.random.seed(seed=10)
    data = datagen.generate_image_batch(augment=True)
    curr_img = 0
    for i in range(iterations):
        org_img = data["images"][curr_img]

        curr_img += 1
        if (curr_img == batch_size):
            #np.random.seed(seed=10)
            data = datagen.generate_image_batch(augment=True)
            curr_img = 0

        # Sample the same random shift
        obj_bb = (50, 50, 156, 156)
        rand_trans_x = np.random.uniform(0.25, 0.5) * 156
        rand_trans_y = np.random.uniform(0.25, 0.5) * 156

        for k, s in enumerate(shifts):
            print("img{0}-shift{1}.png".format(i, k))
            # Apply shift
            obj_bb_off = obj_bb + np.array(
                [s[0] * rand_trans_x, s[1] * rand_trans_y, 0, 0])

            cropped = extract_square_patch(org_img,
                                           obj_bb_off,
                                           pad_factor=1.2,
                                           resize=(128, 128))

            # Normalize image
            img = cropped
            img_max = np.max(img)
            img_min = np.min(img)
            img = (img - img_min) / (img_max - img_min)

            # Run image through encoder
            img = torch.from_numpy(img).unsqueeze(0).permute(0, 3, 1,
                                                             2).to(device)
            code = encoder(img.float())
            code = code.detach().cpu().numpy()[0]
            norm_code = code / np.linalg.norm(code)

            # Add labels
            codes.append(norm_code)
            labels.append(k)
            #labels.append(curr_img)

            if (visualize):
                fig = plt.figure()
                plt.imshow(cropped)
                fig.savefig("embedding-output/test{0}-shift{1}.png".format(
                    i, k),
                            dpi=fig.dpi)
                plt.close()
            #break

    # Apply t-SNE
    X = np.array(codes)
    X_embedded = TSNE(n_components=2,
                      n_jobs=-1,
                      learning_rate=200.0,
                      perplexity=30.0).fit_transform(X)
    print(X_embedded.shape)

    fig = plt.figure()
    # plot the 2-D t-SNE embedding (X_embedded), not the raw codes
    plt.scatter(X_embedded[:, 0], X_embedded[:, 1], c=np.array(labels), cmap="tab10")
    fig.savefig("embedding-output/embedding-{0}samples.png".format(iterations),
                dpi=fig.dpi)
    plt.close()
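
The extract_square_patch helper used above is not part of this listing. A
minimal sketch of the behavior its call site implies (crop a square region
around the bounding box enlarged by pad_factor, then resize), offered as an
assumption rather than the project's actual code:

import cv2
import numpy as np

def extract_square_patch(img, obj_bb, pad_factor=1.2, resize=(128, 128)):
    # HYPOTHETICAL sketch; obj_bb is (x, y, w, h)
    x, y, w, h = obj_bb
    size = int(max(w, h) * pad_factor)
    cx, cy = int(x + w / 2), int(y + h / 2)
    left, top = cx - size // 2, cy - size // 2
    # copy the (possibly clipped) crop into a zero-padded square canvas
    patch = np.zeros((size, size, img.shape[2]), dtype=img.dtype)
    x0, y0 = max(left, 0), max(top, 0)
    x1, y1 = min(left + size, img.shape[1]), min(top + size, img.shape[0])
    patch[y0 - top:y1 - top, x0 - left:x1 - left] = img[y0:y1, x0:x1]
    return cv2.resize(patch, resize)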
Example #6
def main():
    global optimizer, lr_reducer, views, epoch, pipeline
    # Read configuration file
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name")
    arguments = parser.parse_args()

    cfg_file_path = os.path.join("./experiments", arguments.experiment_name)
    args = configparser.ConfigParser()
    args.read(cfg_file_path)

    # Prepare rotation matrices for multi view loss function
    eulerViews = json.loads(args.get('Rendering', 'VIEWS'))
    views = prepareViews(eulerViews)

    # Set the cuda device
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Handle loading of multiple object paths
    try:
        model_path_loss = json.loads(args.get('Dataset', 'MODEL_PATH_LOSS'))
    except ValueError:  # the option holds a plain path, not a JSON list
        model_path_loss = [args.get('Dataset', 'MODEL_PATH_LOSS')]

    # Set up batch renderer
    br = BatchRender(model_path_loss,
                     device,
                     batch_size=args.getint('Training', 'BATCH_SIZE'),
                     faces_per_pixel=args.getint('Rendering',
                                                 'FACES_PER_PIXEL'),
                     render_method=args.get('Rendering', 'SHADER'),
                     image_size=args.getint('Rendering', 'IMAGE_SIZE'),
                     norm_verts=args.getboolean('Rendering',
                                                'NORMALIZE_VERTICES'))

    # Set size of model output depending on pose representation - deprecated?
    pose_rep = args.get('Training', 'POSE_REPRESENTATION')
    if (pose_rep == '6d-pose'):
        pose_dim = 6
    elif (pose_rep == 'quat'):
        pose_dim = 4
    elif (pose_rep == 'axis-angle'):
        pose_dim = 4
    elif (pose_rep == 'euler'):
        pose_dim = 3
    else:
        print("Unknown pose representation specified: ", pose_rep)
        pose_dim = -1

    # Initialize a model using the renderer, mesh and reference image
    model = Model(num_views=len(views))
    model.to(device)

    # Create an optimizer. Here we are using Adam and we pass in the parameters of the model
    low_lr = args.getfloat('Training', 'LEARNING_RATE_LOW')
    high_lr = args.getfloat('Training', 'LEARNING_RATE_HIGH')
    optimizer = torch.optim.Adam(model.parameters(), lr=low_lr)
    lr_reducer = ExponentialLR(optimizer, high_lr,
                               args.getfloat('Training', 'NUM_ITER'))

    # Prepare output directories
    output_path = args.get('Training', 'OUTPUT_PATH')
    prepareDir(output_path)
    shutil.copy(cfg_file_path,
                os.path.join(output_path,
                             cfg_file_path.split("/")[-1]))

    # Prepare pipeline
    encoder = Encoder(args.get('Dataset', 'ENCODER_WEIGHTS')).to(device)
    encoder.eval()
    pipeline = Pipeline(encoder, model, device)

    # Handle loading of multiple object paths and translations
    try:
        model_path_data = json.loads(args.get('Dataset', 'MODEL_PATH_DATA'))
        translations = np.array(json.loads(args.get('Rendering', 'T')))
    except ValueError:  # plain path / single translation, not JSON lists
        model_path_data = [args.get('Dataset', 'MODEL_PATH_DATA')]
        translations = [np.array(json.loads(args.get('Rendering', 'T')))]

    # Prepare datasets
    bg_path = "../../autoencoder_ws/data/VOC2012/JPEGImages/"
    training_data = DatasetGenerator(args.get('Dataset', 'BACKGROUND_IMAGES'),
                                     model_path_data, translations,
                                     args.getint('Training', 'BATCH_SIZE'),
                                     "not_used", device,
                                     args.get('Training', 'VIEW_SAMPLING'))
    training_data.max_samples = args.getint('Training', 'NUM_SAMPLES')

    # Start training
    np.random.seed(seed=args.getint('Training', 'RANDOM_SEED'))
    while (epoch < args.getint('Training', 'NUM_ITER')):
        # Train on synthetic data
        model = model.train()  # Set model to train mode
        loss = runEpoch(br,
                        training_data,
                        model,
                        device,
                        output_path,
                        t=translations,
                        config=args)
        append2file([loss], os.path.join(output_path, "train-loss.csv"))
        append2file([lr_reducer.get_lr()],
                    os.path.join(output_path, "learning-rate.csv"))

        # Plot losses
        val_losses = plotLoss(
            os.path.join(output_path, "train-loss.csv"),
            os.path.join(output_path, "train-loss.png"),
            validation_csv=os.path.join(output_path, "train-loss.csv"),
        )
        print("-" * 20)
        print("Epoch: {0} - train loss: {1}".format(epoch, loss))
        print("-" * 20)
        epoch = epoch + 1
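
For reference, a minimal experiment file consistent with the configparser
keys these scripts read (section and key names are taken from the code above;
all values are illustrative placeholders, not the authors' settings):

[Training]
RANDOM_SEED = 42
BATCH_SIZE = 32
NUM_ITER = 200
NUM_SAMPLES = 5000
LEARNING_RATE_LOW = 1e-5
LEARNING_RATE_HIGH = 1e-4
POSE_REPRESENTATION = 6d-pose
VIEW_SAMPLING = sundermeyer-random
OUTPUT_PATH = ./output/my-experiment/

[Rendering]
VIEWS = [[0.0, 0.0, 0.0]]
SHADER = soft-phong
FACES_PER_PIXEL = 16
IMAGE_SIZE = 128
NORMALIZE_VERTICES = false
T = [0.0, 0.0, 375.0]

[Dataset]
MODEL_PATH_LOSS = ./data/cad-files/ply-files/obj_10.ply
MODEL_PATH_DATA = ./data/cad-files/ply-files/obj_10.ply
ENCODER_WEIGHTS = ./data/obj1-18/encoder.npy
BACKGROUND_IMAGES = ./data/VOC2012/JPEGImages/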