Example #1
def test_model(model, sess, test_set, model_args):
    batch_size = model_args["batch_size"]
    total_steps = int(test_set.shape[0] / batch_size)

    mrr_list = {5: [], 20: []}
    hr_list = {5: [], 20: []}
    ndcg_list = {5: [], 20: []}

    time_buffer = []
    for batch_step in range(total_steps):
        test_batch = test_set[batch_step * batch_size:(batch_step + 1) *
                              batch_size, :]

        tic = time.time()
        pred_probs = sess.run(model.probs_test,
                              feed_dict={model.input_test: test_batch})
        toc = time.time()
        time_buffer.append(toc - tic)

        ground_truth = test_batch[:, -1]

        top_5_rank, top_20_rank = sample_top_ks(pred_probs, [5, 20])

        indices_5 = [
            np.argwhere(line == item)
            for line, item in zip(top_5_rank, ground_truth)
        ]
        indices_20 = [
            np.argwhere(line == item)
            for line, item in zip(top_20_rank, ground_truth)
        ]

        mrr5_sub, hr5_sub, ndcg5_sub = get_metric(indices_5)
        mrr20_sub, hr20_sub, ndcg20_sub = get_metric(indices_20)

        mrr_list[5].extend(mrr5_sub), mrr_list[20].extend(mrr20_sub)
        hr_list[5].extend(hr5_sub), hr_list[20].extend(hr20_sub)
        ndcg_list[5].extend(ndcg5_sub), ndcg_list[20].extend(ndcg20_sub)

    logging.info("[Test] Time: {:.3f}s +- {:.3f}s per batch".format(
        np.mean(time_buffer), np.std(time_buffer)))

    ndcg_5, ndcg_20 = np.mean(ndcg_list[5]), np.mean(ndcg_list[20])
    mrr_5, mrr_20 = np.mean(mrr_list[5]), np.mean(mrr_list[20])
    hr_5, hr_20 = np.mean(hr_list[5]), np.mean(hr_list[20])

    logging.info("\t MRR@5: {:.4f},  HIT@5: {:.4f},  NDCG@5: {:.4f}".format(
        mrr_5, hr_5, ndcg_5))
    logging.info("\tMRR@20: {:.4f}, HIT@20: {:.4f}, NDCG@20: {:.4f}".format(
        mrr_20, hr_20, ndcg_20))

    return mrr_5
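
The helpers `sample_top_ks` and `get_metric` are used but not defined in this example. A minimal sketch of what they might look like, assuming `sample_top_ks` returns the top-k item indices per row (best first) and `get_metric` turns the `np.argwhere` hit positions into per-sample MRR, HIT, and NDCG values:

import numpy as np

def sample_top_ks(pred_probs, ks):
    # Hypothetical stand-in: indices of the k highest-scoring items per row, best first.
    order = np.argsort(-pred_probs, axis=1)
    return [order[:, :k] for k in ks]

def get_metric(indices):
    # Each entry is np.argwhere(top_k_row == ground_truth): empty if the true item
    # is missing from the top-k list, otherwise its 0-based rank within that list.
    mrr, hit, ndcg = [], [], []
    for idx in indices:
        if idx.size == 0:
            mrr.append(0.0)
            hit.append(0.0)
            ndcg.append(0.0)
        else:
            rank = int(idx[0][0])
            mrr.append(1.0 / (rank + 1))
            hit.append(1.0)
            ndcg.append(1.0 / np.log2(rank + 2))
    return mrr, hit, ndcg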
Example #2
def start(global_args):
    preset(global_args)

    model_args, train_set, test_set = get_data_and_config(global_args)

    ratio = global_args["occupy"]
    if ratio is None:
        gpu_config = get_proto_config()
        logging.info("Auto-growth GPU memory.")
    else:
        gpu_config = get_proto_config_with_occupy(ratio)
        logging.info("{:.1f}% GPU memory occupied.".format(ratio * 100))

    sess = tf.Session(config=gpu_config)

    with tf.variable_scope("policy_net"):
        policy_net = PolicyNetGumbelGru(model_args)
        policy_net.build_policy()

    with tf.variable_scope(tf.get_variable_scope()):
        model = NextItNetGumbel(model_args)
        model.build_train_graph(policy_action=policy_net.action_predict)
        model.build_test_graph(policy_action=policy_net.action_predict)

    variables = tf.contrib.framework.get_variables_to_restore()
    model_variables = [
        v for v in variables if not v.name.startswith("policy_net")
    ]
    policy_variables = [
        v for v in variables if v.name.startswith("policy_net")
    ]

    with tf.variable_scope(tf.get_variable_scope()):
        optimizer_finetune = tf.train.AdamOptimizer(
            learning_rate=model_args["lr"], name="Adam_finetune")
        train_model = optimizer_finetune.minimize(model.loss,
                                                  var_list=model_variables)
    with tf.variable_scope("policy_net"):
        optimizer_policy = tf.train.AdamOptimizer(
            learning_rate=model_args["lr"], name="Adam_policy")
        train_policy = optimizer_policy.minimize(model.loss,
                                                 var_list=policy_variables)

    init = tf.global_variables_initializer()
    sess.run(init)

    # restore if needed
    if global_args["use_pre"]:
        restore_op = tf.train.Saver(var_list=model_variables)
        restore_op.restore(sess, global_args["pre"])
        sess.run(tf.assign(policy_net.item_embedding, model.item_embedding))
        logging.info(">>>>> Parameters loaded from pre-trained model.")
    else:
        logging.info(">>>>> Training without pre-trained model.")

    logging.info("Start @ {}".format(strftime("%m.%d-%H:%M:%S", localtime())))

    saver = tf.train.Saver(max_to_keep=3)

    batch_size = model_args["batch_size"]
    log_meter = model_args["log_every"]
    total_iters = model_args["iterations"]
    total_steps = int(train_set.shape[0] / batch_size)
    test_steps = int(test_set.shape[0] / batch_size)

    model_save_path = global_args["store_path"]
    model_name = global_args["name"]

    logging.info("Batch size = {}, Batches = {}".format(
        batch_size, total_steps))

    best_mrr_at5 = 0.0

    for idx in range(total_iters):
        logging.info("-" * 30)
        logging.info("Iter: {} / {}".format(idx + 1, total_iters))
        num_iter = 1
        tic = time.time()

        train_usage_sample = []
        for batch_step in range(total_steps):
            train_batch = train_set[batch_step * batch_size:(batch_step + 1) *
                                    batch_size, :]
            _, _, loss, action = sess.run(
                [
                    train_model, train_policy, model.loss,
                    policy_net.action_predict
                ],
                feed_dict={
                    model.input_train: train_batch,
                    policy_net.input: train_batch,
                },
            )
            train_usage_sample.extend(np.array(action).tolist())

            if num_iter % log_meter == 0:
                logging.info("\t{:5d} /{:5d} Loss: {:.3f}".format(
                    batch_step + 1, total_steps, loss))
            num_iter += 1

        summary_block(train_usage_sample, len(model_args["dilations"]),
                      "Train")

        # 1. eval model
        mrr_list = {5: [], 20: []}
        hr_list = {5: [], 20: []}
        ndcg_list = {5: [], 20: []}

        test_usage_sample = []
        for batch_step in range(test_steps):
            test_batch = test_set[batch_step * batch_size:(batch_step + 1) *
                                  batch_size, :]

            action, pred_probs = sess.run(
                [policy_net.action_predict, model.probs],
                feed_dict={
                    model.input_test: test_batch,
                    policy_net.input: test_batch,
                },
            )

            test_usage_sample.extend(np.array(action).tolist())

            ground_truth = test_batch[:, -1]
            top_5_rank, top_20_rank = sample_top_ks(pred_probs, [5, 20])
            indices_5 = [
                np.argwhere(line == item)
                for line, item in zip(top_5_rank, ground_truth)
            ]
            indices_20 = [
                np.argwhere(line == item)
                for line, item in zip(top_20_rank, ground_truth)
            ]

            mrr5_sub, hr5_sub, ndcg5_sub = get_metric(indices_5)
            mrr20_sub, hr20_sub, ndcg20_sub = get_metric(indices_20)

            mrr_list[5].extend(mrr5_sub), mrr_list[20].extend(mrr20_sub)
            hr_list[5].extend(hr5_sub), hr_list[20].extend(hr20_sub)
            ndcg_list[5].extend(ndcg5_sub), ndcg_list[20].extend(ndcg20_sub)

        summary_block(test_usage_sample, len(model_args["dilations"]), "Test")

        ndcg_5, ndcg_20 = np.mean(ndcg_list[5]), np.mean(ndcg_list[20])
        mrr_5, mrr_20 = np.mean(mrr_list[5]), np.mean(mrr_list[20])
        hr_5, hr_20 = np.mean(hr_list[5]), np.mean(hr_list[20])

        logging.info("<Metric>::TestSet")
        logging.info(
            "\t MRR@5: {:.4f},  HIT@5: {:.4f},  NDCG@5: {:.4f}".format(
                mrr_5, hr_5, ndcg_5))
        logging.info(
            "\tMRR@20: {:.4f}, HIT@20: {:.4f}, NDCG@20: {:.4f}".format(
                mrr_20, hr_20, ndcg_20))

        mrr_at5 = mrr_5

        # 2. save model
        if mrr_at5 > best_mrr_at5:
            logging.info(
                ">>>>> Saving model due to better MRR@5: {:.4f} <<<<< ".format(
                    mrr_at5))
            saver.save(
                sess,
                os.path.join(model_save_path,
                             "{}_{}.tfkpt".format(model_name, num_iter)),
            )
            best_mrr_at5 = mrr_at5

        toc = time.time()
        logging.info("Iter: {} / {} finish. Time: {:.2f} min".format(
            idx + 1, total_iters, (toc - tic) / 60))

    sess.close()
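
`get_proto_config` and `get_proto_config_with_occupy` build the session config used above. A minimal sketch with the TF1 `ConfigProto` API, assuming the first enables memory growth and the second pins a fixed fraction of GPU memory:

import tensorflow as tf

def get_proto_config():
    # Grow GPU memory on demand instead of reserving it all up front.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return config

def get_proto_config_with_occupy(ratio):
    # Reserve a fixed fraction of GPU memory, e.g. ratio=0.5 for half the card.
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = ratio
    return config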
Example #3
    def run_epoch(self, phase, epoch, data_loader, no_aug_loader=None):
        model_with_loss = self.model
        if phase == 'train':
            model_with_loss.train()
        else:
            if len(self.opt.gpus) > 1:
                model_with_loss = self.model.module
            model_with_loss.eval()
            torch.cuda.empty_cache()
        opt = self.opt
        results = {}
        data_time, batch_time = AverageMeter(), AverageMeter()
        avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
        num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
        bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
        end = time.time()

        if opt.save_video:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            vid_pth = os.path.join(opt.save_dir, opt.exp_id + '_pred')
            gt_pth = os.path.join(opt.save_dir, opt.exp_id + '_gt')
            out_pred = cv2.VideoWriter('{}.mp4'.format(vid_pth), fourcc,
                                       opt.save_framerate,
                                       (opt.input_w, opt.input_h))
            out_gt = cv2.VideoWriter('{}.mp4'.format(gt_pth), fourcc,
                                     opt.save_framerate,
                                     (opt.input_w, opt.input_h))

        delta_max = opt.delta_max
        delta_min = opt.delta_min
        delta = delta_min
        umax = opt.umax
        a_thresh = opt.acc_thresh
        metric = get_metric(opt)
        iter_id = 0

        data_iter = iter(data_loader)
        update_lst = []
        acc_lst = []
        coco_res_lst = []
        while True:
            load_time, total_model_time, model_time, update_time, tot_time, display_time = 0, 0, 0, 0, 0, 0
            start_time = time.time()
            # data loading
            try:
                batch = next(data_iter)
            except StopIteration:
                break

            if iter_id >= num_iters:
                break

            loaded_time = time.time()
            load_time += (loaded_time - start_time)

            if opt.adaptive:
                if iter_id % delta == 0:
                    u = 0
                    update = True
                    while (update):
                        output, tmp_model_time = self.run_model(batch)
                        total_model_time += tmp_model_time
                        # save the stuff every iteration
                        acc = metric.get_score(batch, output, u)
                        print(acc)
                        if u < umax and acc < a_thresh:
                            update_time = self.update_model(batch)
                        else:
                            update = False
                        u += 1
                    if acc > a_thresh:
                        delta = min(delta_max, 2 * delta)
                    else:
                        delta = max(delta_min, delta / 2)
                    output, _ = self.run_model(
                        batch)  # run model with new weights
                    model_time = total_model_time / u
                    update_lst += [(iter_id, u)]
                    acc_lst += [(iter_id, acc)]
                    self.accum_coco.store_metric_coco(iter_id, batch, output,
                                                      opt)
                else:
                    update_lst += [(iter_id, 0)]
                    output, model_time = self.run_model(batch)
                    if opt.acc_collect and (iter_id % opt.acc_interval == 0):
                        acc = metric.get_score(batch, output, 0)
                        print(acc)
                        acc_lst += [(iter_id, acc)]
                        self.accum_coco.store_metric_coco(
                            iter_id, batch, output, opt)
            else:
                output, model_time = self.run_model(batch)
                if opt.acc_collect:
                    acc = metric.get_score(batch, output, 0, is_baseline=True)
                    print(acc)
                    acc_lst += [(iter_id, acc)]
                    self.accum_coco.store_metric_coco(iter_id,
                                                      batch,
                                                      output,
                                                      opt,
                                                      is_baseline=True)

            display_start = time.time()

            if opt.tracking:
                trackers, viz_pred = self.tracking(
                    batch, output,
                    iter_id)  # TODO: factor this into the other class
                out_pred.write(viz_pred)
            elif opt.save_video:
                pred, gt = self.debug(batch, output, iter_id)
                out_pred.write(pred)
                out_gt.write(gt)
            if opt.debug > 1:
                self.debug(batch, output, iter_id)

            display_end = time.time()
            display_time = (display_end - display_start)
            end_time = time.time()
            tot_time = (end_time - start_time)

            # add a bunch of stuff to the bar to print
            Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
                epoch,
                iter_id,
                num_iters,
                phase=phase,
                total=bar.elapsed_td,
                eta=bar.eta_td)  # add to the progress bar
            if opt.print_iter > 0:
                if iter_id % opt.print_iter == 0:
                    print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
            else:
                bar.next()
            if opt.display_timing:
                time_str = 'total {:.3f}s| load {:.3f}s | model_time {:.3f}s | update_time {:.3f}s | display {:.3f}s'.format(
                    tot_time, load_time, model_time, update_time, display_time)
                print(time_str)
            self.save_result(output, batch, results)
            del output
            iter_id += 1

        bar.finish()
        ret = {k: v.avg for k, v in avg_loss_stats.items()}
        ret['time'] = bar.elapsed_td.total_seconds() / 60.
        save_dict = {}
        if opt.adaptive:
            plt.scatter(*zip(*update_lst))
            plt.xlabel('iteration')
            plt.ylabel('number of updates')
            plt.savefig(opt.save_dir + '/update_frequency.png')
            save_dict['updates'] = update_lst
            plt.clf()
        if opt.acc_collect:
            plt.scatter(*zip(*acc_lst))
            plt.xlabel('iteration')
            plt.ylabel('mAP')
            plt.savefig(opt.save_dir + '/acc_figure.png')
            save_dict['acc'] = acc_lst
        if opt.adaptive and opt.acc_collect:
            x, y = zip(*filter(lambda x: x[1] > 0, update_lst))
            plt.scatter(x, y, c='r', marker='o')
            plt.xlabel('iteration')

        # save dict
        # gt_dict = self.accum_coco.get_gt()
        # dt_dict = self.accum_coco.get_dt()
        dt_dict = self.accum_coco_det.get_dt()

        # save_dict['gt_dict'] = gt_dict
        # save_dict['dt_dict'] = dt_dict
        save_dict['full_res_pred'] = dt_dict
        return ret, results, save_dict
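
The adaptive branch in `run_epoch` stretches or shrinks the update interval `delta` depending on whether accuracy clears `acc_thresh`. The schedule, pulled out as a standalone helper for clarity (a sketch, not part of the original trainer):

def next_delta(delta, acc, acc_thresh, delta_min, delta_max):
    # Double the gap between full update passes when accuracy is good enough,
    # halve it (never below delta_min) when it drops under the threshold.
    if acc > acc_thresh:
        return min(delta_max, 2 * delta)
    return max(delta_min, delta // 2)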
Example #4
def start(global_args):
    preset(global_args)

    model_args, train_set, test_set = get_data_and_config(global_args)

    # ----------------------
    # Part.1 Build Model(s)
    # ----------------------
    ratio = global_args["occupy"]
    if ratio is None:
        gpu_config = get_proto_config()
        logging.info("Auto-growth GPU memory.")
    else:
        gpu_config = get_proto_config_with_occupy(ratio)
        logging.info("{:.1f}% GPU memory occupied.".format(ratio * 100))

    sess = tf.Session(config=gpu_config)

    with tf.variable_scope("policy_net"):
        policy_net = PolicyNetGumbelRL(model_args)
        policy_net.build_policy()

    with tf.variable_scope(tf.get_variable_scope()):
        model = NextItNetGumbelRL(model_args)
        model.build_train_graph(policy_action=policy_net.action_predict)
        model.build_test_graph(policy_action=policy_net.action_predict)

    # step-1, prepare parameters' name
    variables = tf.contrib.framework.get_variables_to_restore()
    model_variables = [
        v for v in variables if not v.name.startswith("policy_net")
    ]
    policy_variables = [
        v for v in variables if v.name.startswith("policy_net")
    ]

    # step-2, create optimizer
    with tf.variable_scope(tf.get_variable_scope()):
        optimizer_finetune = tf.train.AdamOptimizer(
            learning_rate=model_args["lr"], name="Adam_finetune")
        train_model = optimizer_finetune.minimize(model.loss,
                                                  var_list=model_variables)
        # train_model_rl = optimizer_finetune.minimize(
        #     policy_net.rl_loss, var_list=model_variables
        # )
    with tf.variable_scope("policy_net"):
        optimizer_policy = tf.train.AdamOptimizer(
            learning_rate=model_args["lr"], name="Adam_policy")
        train_policy = optimizer_policy.minimize(model.loss,
                                                 var_list=policy_variables)
        train_policy_rl = optimizer_policy.minimize(policy_net.rl_loss,
                                                    var_list=policy_variables)

    # step-4, restore parameters if needed
    if not global_args["resume"]:
        init = tf.global_variables_initializer()
        sess.run(init)
        start_at = 0
        if global_args["use_pre"]:
            # step-4.1 restore pre-trained parameters
            restore_op = tf.train.Saver(var_list=model_variables)
            restore_op.restore(sess, global_args["pre"])
            # step-4.2 copy embedding to policy-net
            sess.run(tf.assign(policy_net.item_embedding,
                               model.item_embedding))
            logging.info(">>>>> Parameters loaded from pre-trained model.")
        else:
            logging.info(">>>>> Training without pre-trained model.")
    else:
        resume_op = tf.train.Saver()
        resume_op.restore(sess, global_args["resume_path"])
        start_at = global_args["resume_at"]
        logging.info(
            ">>>>> Resume from checkpoint, start at epoch {}".format(start_at))

    # ----------------------
    # Part.2 Train
    # ----------------------
    logging.info("Start @ {}".format(strftime("%m.%d-%H:%M:%S", localtime())))

    saver = tf.train.Saver(max_to_keep=3)

    batch_size = model_args["batch_size"]
    log_meter = model_args["log_every"]
    total_iters = model_args["iter"]
    total_steps = int(train_set.shape[0] / batch_size)
    test_steps = int(test_set.shape[0] / batch_size)

    model_save_path = global_args["store_path"]
    model_name = global_args["name"]

    logging.info("Batch size = {}, Batches = {}".format(
        batch_size, total_steps))

    best_mrr_at5 = 0.0

    action_nums = len(model_args["dilations"])

    for idx in range(start_at, total_iters):
        logging.info("-" * 30)
        if idx < global_args["rl_iter"]:
            rl_str = "OFF"
        else:
            rl_str = " ON"
        logging.info("[RL-{}] Iter: {} / {}".format(rl_str, idx + 1,
                                                    total_iters))
        num_iter = 1
        tic = time.time()

        train_usage_sample = []
        for batch_step in range(total_steps):
            train_batch = train_set[batch_step * batch_size:(batch_step + 1) *
                                    batch_size, :]

            if idx >= global_args["rl_iter"]:
                # 1. soft_result
                # 2. map_result
                # 3. advantage -> reward -> optimize
                [soft_probs, soft_action] = sess.run(
                    [model.probs, policy_net.action_predict],
                    feed_dict={
                        model.input_test:
                        train_batch,
                        policy_net.input:
                        train_batch,
                        policy_net.method:
                        np.array(0),
                        policy_net.sample_action:
                        np.ones((batch_size, action_nums)),
                    },
                )
                [hard_probs, hard_action] = sess.run(
                    [model.probs, policy_net.action_predict],
                    feed_dict={
                        model.input_test:
                        train_batch,
                        policy_net.input:
                        train_batch,
                        policy_net.method:
                        np.array(1),
                        policy_net.sample_action:
                        np.ones((batch_size, action_nums)),
                    },
                )
                ground_truth = train_batch[:, -1]
                reward_soft = reward_fn(soft_probs, ground_truth, soft_action,
                                        global_args["gamma"])
                reward_hard = reward_fn(hard_probs, ground_truth, hard_action,
                                        global_args["gamma"])
                reward_train = reward_soft - reward_hard
                _, _, _, action, loss, rl_loss = sess.run(
                    [
                        train_policy_rl,
                        train_policy,
                        train_model,
                        policy_net.action_predict,
                        model.loss,
                        policy_net.rl_loss,
                    ],
                    feed_dict={
                        model.input_train: train_batch,
                        policy_net.input: train_batch,
                        policy_net.method: np.array(-1),
                        policy_net.sample_action: soft_action,
                        policy_net.reward: reward_train,
                    },
                )
                train_usage_sample.extend(np.array(action).tolist())
                if num_iter % log_meter == 0:
                    logging.info(
                        "\t{:5d} /{:5d} Loss: {:.3f}, RL-Loss: {:.3f}, Reward-Avg: {:.3f}"
                        .format(
                            batch_step + 1,
                            total_steps,
                            loss,
                            rl_loss,
                            np.mean(reward_train),
                        ))
                num_iter += 1
            else:
                [hard_action] = sess.run(
                    [policy_net.action_predict],
                    feed_dict={
                        policy_net.method:
                        np.array(1),
                        policy_net.input:
                        train_batch,
                        policy_net.sample_action:
                        np.ones((batch_size, action_nums)),
                    },
                )
                [_, _, action, loss] = sess.run(
                    [
                        train_model, train_policy, policy_net.action_predict,
                        model.loss
                    ],
                    feed_dict={
                        model.input_train: train_batch,
                        policy_net.input: train_batch,
                        policy_net.sample_action: hard_action,
                        policy_net.method: np.array(-1),
                    },
                )
                train_usage_sample.extend(np.array(action).tolist())

                if num_iter % log_meter == 0:
                    logging.info("\t{:5d} /{:5d} Loss: {:.3f}".format(
                        batch_step + 1, total_steps, loss))
                num_iter += 1

        summary_block(train_usage_sample, len(model_args["dilations"]),
                      "Train")

        # 1. eval model
        mrr_list = {5: [], 20: []}
        hr_list = {5: [], 20: []}
        ndcg_list = {5: [], 20: []}

        test_usage_sample = []
        for batch_step in range(test_steps):
            test_batch = test_set[batch_step * batch_size:(batch_step + 1) *
                                  batch_size, :]
            action, pred_probs = sess.run(
                [policy_net.action_predict, model.probs],
                feed_dict={
                    model.input_test: test_batch,
                    policy_net.input: test_batch,
                    policy_net.method: np.array(1),
                    policy_net.sample_action: np.ones(
                        (batch_size, action_nums)),
                },
            )
            test_usage_sample.extend(np.array(action).tolist())

            ground_truth = test_batch[:, -1]
            top_5_rank, top_20_rank = sample_top_ks(pred_probs, [5, 20])
            indices_5 = [
                np.argwhere(line == item)
                for line, item in zip(top_5_rank, ground_truth)
            ]
            indices_20 = [
                np.argwhere(line == item)
                for line, item in zip(top_20_rank, ground_truth)
            ]

            mrr5_sub, hr5_sub, ndcg5_sub = get_metric(indices_5)
            mrr20_sub, hr20_sub, ndcg20_sub = get_metric(indices_20)

            mrr_list[5].extend(mrr5_sub), mrr_list[20].extend(mrr20_sub)
            hr_list[5].extend(hr5_sub), hr_list[20].extend(hr20_sub)
            ndcg_list[5].extend(ndcg5_sub), ndcg_list[20].extend(ndcg20_sub)

        summary_block(test_usage_sample, len(model_args["dilations"]), "Test")

        ndcg_5, ndcg_20 = np.mean(ndcg_list[5]), np.mean(ndcg_list[20])
        mrr_5, mrr_20 = np.mean(mrr_list[5]), np.mean(mrr_list[20])
        hr_5, hr_20 = np.mean(hr_list[5]), np.mean(hr_list[20])

        logging.info("<Metric>::TestSet")
        logging.info(
            "\t MRR@5: {:.4f},  HIT@5: {:.4f},  NDCG@5: {:.4f}".format(
                mrr_5, hr_5, ndcg_5))
        logging.info(
            "\tMRR@20: {:.4f}, HIT@20: {:.4f}, NDCG@20: {:.4f}".format(
                mrr_20, hr_20, ndcg_20))

        mrr_at5 = mrr_5

        # 2. save model
        if mrr_at5 > best_mrr_at5:
            logging.info(
                ">>>>> Saving model due to better MRR@5: {:.4f} <<<<< ".format(
                    mrr_at5))
            saver.save(
                sess,
                os.path.join(model_save_path,
                             "{}_{}.tfkpt".format(model_name, num_iter)),
            )
            best_mrr_at5 = mrr_at5

        toc = time.time()
        logging.info("Iter: {} / {} finish. Time: {:.2f} min".format(
            idx + 1, total_iters, (toc - tic) / 60))

    sess.close()
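
`summary_block` is called after every train and test pass but is not shown here. A hypothetical sketch, assuming each action is a 0/1 vector over the dilated blocks marking which blocks were kept for that sequence:

import logging

import numpy as np

def summary_block(usage_sample, num_blocks, tag):
    # Assumed behavior: log how often each dilated block is kept, plus the
    # average number of blocks executed per sequence.
    usage = np.asarray(usage_sample, dtype=np.float32).reshape(-1, num_blocks)
    rates = ", ".join("{:.2f}".format(r) for r in usage.mean(axis=0))
    logging.info("[{}] Block usage rates: {}".format(tag, rates))
    logging.info("[{}] Avg blocks kept per sequence: {:.2f}".format(
        tag, usage.sum(axis=1).mean()))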
Example #5
def run_training(
        experiment_name,
        debug=False,
        only_ees=False,
        only_kinematics=False,
        use_neptune=False,
        epochs=2000,  # 20000
        train_batch=21330,  # 54726 // 2,
        val_batch=1123,  # 2000
        dtype=np.float32,
        val_split=0.05,
        shuffle_data=True,
        model_type="linear",  # "GRU",  # 'transformer',
        model_cfg=None,
        bptt=350,
        hidden_size=6,
        lr=1e-2,
        start_trim=700,
        log_interval=5,
        val_interval=20,
        clip_grad_norm=False,
        output_dir="results",
        normalize_input=True,
        optimizer="Adam",  # "AdamW",
        scheduler=None,  # "StepLR",
        train_weight=100.,
        batch_first=True,
        toss_allzero_mn=True,
        dumb_augment=False,
        score="pearson",
        metric="l2"):  # pearson
    """Run training and validation."""
    if use_neptune and NEPTUNE_IMPORTED:
        neptune.init("Serre-Lab/deepspine")
        if experiment_name is None:
            experiment_name = "synthetic_data"
        neptune.create_experiment(experiment_name)
    assert model_type is not None, "You must select a model."
    default_model_params = tools.get_model_defaults()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    timestamp = datetime.datetime.fromtimestamp(
        time.time()).strftime('%Y-%m-%d-%H_%M_%S')
    if model_cfg is None:
        print("Using default model cfg file.")
        model_cfg = model_type
    data = np.load(DATA_FILE)
    mn = data["mn"]
    ees = data["ees"]
    kinematics = data["kinematics"]
    X = torch.from_numpy(np.concatenate((ees, kinematics), 1).astype(dtype))
    Y = torch.from_numpy(mn.astype(dtype))
    X = X.permute(0, 2, 1)
    Y = Y.permute(0, 2, 1)
    if only_ees:
        X = X[..., 0][..., None]  # Only ees -- 0.73
    if only_kinematics:
        X = X[..., 1:]  # Only kinematics -- 0.89
    input_size = X.size(-1)
    output_size = Y.size(-1)
    meta = Meta(
        batch_first=batch_first,
        data_size=X.shape,
        train_batch=train_batch,
        val_batch=val_batch,
        val_split=val_split,
        model_type=model_type,
        model_cfg=model_cfg,
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        metric=metric,
        score=score,
        normalize_input=normalize_input,
        lr=lr,
        bptt=bptt,
        epochs=epochs,
        optimizer=optimizer,
        scheduler=scheduler,
        clip_grad_norm=clip_grad_norm,
        log_interval=log_interval,
        val_interval=val_interval,
        start_trim=start_trim,
        train_weight=train_weight,
        device=device)

    # Prepare data
    if toss_allzero_mn:
        # Restrict to nonzero mn fibers
        # mask = (Y.sum(1) > 127.5).sum(-1) == 2  # Ys where both are nonzero at some point
        mask = ((Y > 200).sum(1) > 0).sum(-1) == 2  # Ys where both are > 200 at some point
        # mask = ((Y > 127.5).sum(1) > 0).sum(-1) >= 1  # Ys where either is > 127.5 at some point
        print("Throwing out {} examples.".format((mask == False).sum()))
        X = X[mask]
        Y = Y[mask]
    if meta.start_trim:
        X = X.narrow(1, meta.start_trim, X.size(1) - meta.start_trim)
        Y = Y.narrow(1, meta.start_trim, Y.size(1) - meta.start_trim)

    if shuffle_data:
        idx = np.random.permutation(len(X))
        X = X[idx]
        Y = Y[idx]

    if meta.normalize_input:
        # X = (X - 127.5) / 127.5
        # Y = (Y - 127.5) / 127.5
        k_X = X[..., 1:]
        k_X = (k_X - k_X.mean(1, keepdim=True)) / (k_X.std(1, keepdim=True) + 1e-8)  # This is peeking, but acceptable here.
        e_X = X[..., 0][..., None]
        e_X = e_X / 255.
        X = torch.cat((k_X, e_X), -1)
        if meta.metric != "bce":
            Y = (Y - Y.mean(1, keepdim=True)) / (Y.std(1, keepdim=True) + 1e-8)
            # Y = Y / 255.
        else:
            # Quantize Y
            Y = (Y > 127.5).float()
    X = X.to(meta.device)
    Y = Y.to(meta.device)
    cv_idx = np.arange(len(X))
    cv_idx = cv_idx > np.round(float(len(X)) * val_split).astype(int)
    X_train = X[cv_idx]
    Y_train = Y[cv_idx]
    X_val = X[~cv_idx]
    Y_val = Y[~cv_idx]
    assert meta.train_batch < len(X_train), "Train batch size must be smaller than the training set ({}).".format(len(X_train))
    assert meta.val_batch < len(X_val), "Val batch size must be smaller than the validation set ({}).".format(len(X_val))

    if dumb_augment:
        X_train = torch.cat((X_train, X_train[:, torch.arange(X_train.size(1) - 1, -1, -1).long()]))
        Y_train = torch.cat((Y_train, Y_train[:, torch.arange(Y_train.size(1) - 1, -1, -1).long()]))

    if not meta.batch_first:
        X_train = X_train.permute(1, 0, 2)
        Y_train = Y_train.permute(1, 0, 2)
        X_val = X_val.permute(1, 0, 2)
        Y_val = Y_val.permute(1, 0, 2)

    # Create model
    model = modeling.create_model(
        batch_first=meta.batch_first,
        bptt=meta.bptt,
        model_type=meta.model_type,
        model_cfg=meta.model_cfg,
        input_size=meta.input_size,
        hidden_size=meta.hidden_size,
        output_size=meta.output_size,
        default_model_params=default_model_params,
        device=meta.device)
    num_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
    print('Total number of parameters: {}'.format(num_params))
    score, criterion = metrics.get_metric(metric, meta.batch_first)
    optimizer_fun = optimizers.get_optimizer(optimizer)
    assert lr < 1, "LR must be less than 1."
    if "adam" in optimizer.lower():
        optimizer = optimizer_fun(model.parameters(), lr=lr, amsgrad=True)
    else:
        optimizer = optimizer_fun(model.parameters(), lr=lr)
    if scheduler is not None:
        scheduler = optimizers.get_scheduler(scheduler) 
        scheduler = scheduler(optimizer)

    # Start training
    best_val_loss = float("inf")
    best_model = None
    X_val, _ = batchify(X_val, bsz=meta.val_batch, random=False, batch_first=meta.batch_first)
    Y_val, _ = batchify(Y_val, bsz=meta.val_batch, random=False, batch_first=meta.batch_first)
    for epoch in range(1, meta.epochs + 1):
        epoch_start_time = time.time()
        meta.epoch = epoch
        X_train_i, random_idx = batchify(
            X_train,
            bsz=meta.train_batch,
            random=True,
            batch_first=meta.batch_first)
        Y_train_i, _ = batchify(
            Y_train,
            bsz=meta.train_batch,
            random=random_idx,
            batch_first=meta.batch_first)
        min_train_loss, max_train_loss, train_output, train_gt = train(
            model=model,
            X=X_train_i,
            Y=Y_train_i,
            optimizer=optimizer,
            criterion=criterion,
            score=score,
            scheduler=scheduler,
            meta=meta)
        if epoch % meta.val_interval == 0:
            val_loss, val_score, val_output, val_gt = evaluate(
                model=model,
                X=X_val,
                Y=Y_val,
                criterion=criterion,
                score=score,
                meta=meta)
            meta.min_train_loss.append(min_train_loss)
            meta.max_train_loss.append(max_train_loss)
            meta.val_loss.append(val_loss)
            meta.val_score.append(val_score)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid score {:5.2f}'.format(
                  epoch,
                  (time.time() - epoch_start_time),
                  meta.val_loss[-1],
                  meta.val_score[-1]))
            print('-' * 89)
            if use_neptune and NEPTUNE_IMPORTED:
                neptune.log_metric('min_train_loss', min_train_loss)
                neptune.log_metric('max_train_loss', max_train_loss)
                neptune.log_metric('val_{}'.format(meta.metric), val_loss)
                neptune.log_metric('val_pearson', val_score)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model = model
            if val_loss < 0.65 and debug:
                from matplotlib import pyplot as plt
                fig = plt.figure()
                plt.title('val')
                plt.subplot(211)
                plt.plot(val_output[50].cpu())
                plt.subplot(212)
                plt.plot(val_gt[50].cpu())
                plt.show()
                plt.close(fig)
                fig = plt.figure()
                plt.title('train')
                plt.subplot(211)
                plt.plot(train_output[50].cpu().detach())
                plt.subplot(212)
                plt.plot(train_gt[50].cpu())
                plt.show()
                plt.close(fig)
            if scheduler is not None:
                scheduler.step()

    # Fix some type issues
    meta.val_loss = [x.cpu() for x in meta.val_loss]
    meta.val_score = [x.cpu() for x in meta.val_score]
    np.savez(os.path.join(output_dir, '{}results_{}'.format(experiment_name, timestamp)), **meta.__dict__)  # noqa
    np.savez(os.path.join(output_dir, '{}example_{}'.format(experiment_name, timestamp)), train_output=train_output.cpu().detach(), train_gt=train_gt.cpu(), val_output=val_output.cpu(), val_gt=val_gt.cpu())
    torch.save(best_model.state_dict(), os.path.join(output_dir, '{}model_{}.pth'.format(experiment_name, timestamp)))
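
With `score="pearson"`, validation quality above is a Pearson correlation between predicted and ground-truth motoneuron traces. A minimal sketch of such a score under the batch-first layout used here (the project's `metrics.get_metric` may differ in details):

import torch

def pearson_score(pred, target, eps=1e-8):
    # pred, target: (batch, time, channels); correlate along the time axis,
    # then average over batch and channels.
    pred = pred - pred.mean(dim=1, keepdim=True)
    target = target - target.mean(dim=1, keepdim=True)
    cov = (pred * target).sum(dim=1)
    denom = pred.norm(dim=1) * target.norm(dim=1) + eps
    return (cov / denom).mean()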