Exemplo n.º 1
0
    def finetune(index):
        """Train a fresh base model on the historical sequence ending at `index`.

        Collects the last `args.seq_length` timestamps from `env`, fits a new
        model with Adam + MultiStepLR, keeps the lowest-loss weights seen
        during training, and reports the train metric of those weights.

        Returns:
            (train_results, model): metric info dict and the fitted model.
        """
        # gather the historical data of the sequence ending at `index`
        seq_times = env.get_seq_times(index, args.seq_length)
        _, (allxs, allys) = env.seq_call(seq_times)
        # flatten the per-timestamp batches into one training set
        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
        if env.meta_info["task"] == "classification":
            allys = allys.view(-1)  # CrossEntropyLoss expects 1-D targets
        historical_x, historical_y = allxs.to(args.device), allys.to(
            args.device)
        model = get_model(**model_kwargs)
        model = model.to(args.device)

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.init_lr,
                                     amsgrad=True)
        # decay the LR by 0.3x at 25% / 50% / 75% of the training epochs
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )

        train_metric = metric_cls(True)
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # track the best (lowest-loss) weights seen so far
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        # model.analyze_weights()
        # BUG FIX: re-run the forward pass with the restored best weights;
        # previously the metric was computed on `preds` from the final epoch,
        # which does not correspond to the returned (best) model.
        with torch.no_grad():
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()
        return train_results, model
Exemplo n.º 2
0
def main(args):
    """For each requested timestamp, train a fresh MLP on all preceding
    timestamps' data and evaluate it on that timestamp's own data.

    Saves one checkpoint per timestamp and, at the end, every trained
    model's weight container in `final-ckp.pth`.
    """
    logger, env_info, model_kwargs = lfna_setup(args)

    # check indexes to be evaluated
    to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"],
                                            None)
    logger.log("Evaluate {:}, which has {:} timestamps in total.".format(
        args.srange, len(to_evaluate_indexes)))

    w_container_per_epoch = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for i, idx in enumerate(to_evaluate_indexes):

        need_time = "Time Left: {:}".format(
            convert_secs2time(
                per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True))
        logger.log("[{:}]".format(time_string()) +
                   " [{:04d}/{:04d}][{:04d}]".format(i, len(
                       to_evaluate_indexes), idx) + " " + need_time)
        # train on all data strictly before `idx`; index 0 has no history
        assert idx != 0
        historical_x, historical_y = [], []
        for past_i in range(idx):
            historical_x.append(env_info["{:}-x".format(past_i)])
            historical_y.append(env_info["{:}-y".format(past_i)])
        historical_x, historical_y = torch.cat(historical_x), torch.cat(
            historical_y)
        historical_x, historical_y = subsample(historical_x, historical_y)
        # build model
        model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.init_lr,
                                     amsgrad=True)
        criterion = torch.nn.MSELoss()
        # decay the LR by 0.3x at 25% / 50% / 75% of the training epochs
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # track the best (lowest-loss) weights seen so far
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        # BUG FIX: compute the train metric with the restored best weights
        # rather than the stale `preds` from the final epoch's forward pass.
        with torch.no_grad():
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        # evaluate the trained model on the data of timestamp `idx`
        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)])
        eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=0)
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = ("[{:}]".format(time_string()) +
                   " [{:04d}/{:04d}]".format(idx, env_info["total"]) +
                   " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                       train_results["mse"], results["mse"]))
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"])
        w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_container_per_epoch": w_container_per_epoch},
        logger.path(None) / "final-ckp.pth",
        logger,
    )
    logger.log("-" * 200 + "\n")
    logger.close()
Exemplo n.º 3
0
def main(args):
    """Meta-train a hyper-network on the trainval split, then run both
    online-adaptation evaluations (refine and easy) on the test split and
    checkpoint all results.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    # instantiate every split of the synthetic environment
    train_env = get_synthetic_env(mode="train", version=args.env_version)
    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
    test_env = get_synthetic_env(mode="test", version=args.env_version)
    all_env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))
    logger.log("The test enviornment: {:}".format(test_env))
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=all_env.meta_info["input_dim"],
        output_dim=all_env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )

    base_model = get_model(**model_kwargs).to(args.device)
    # pick the loss and metric according to the task type
    task = all_env.meta_info["task"]
    if task == "regression":
        criterion, metric = torch.nn.MSELoss(), MSEMetric(True)
    elif task == "classification":
        criterion, metric = torch.nn.CrossEntropyLoss(), Top1AccMetric(True)
    else:
        raise ValueError("This task ({:}) is not supported.".format(task))

    shape_container = base_model.get_w_container().to_shape_container()

    # choose the hyper-network variant and pre-train it
    timestamps = trainval_env.get_timestamp(None)
    if args.ablation is None:
        meta_cls = MetaModelV1
    elif args.ablation == "old":
        meta_cls = MetaModel_TraditionalAtt
    else:
        raise ValueError("Unknown ablation : {:}".format(args.ablation))
    meta_model = meta_cls(
        shape_container,
        args.layer_dim,
        args.time_dim,
        timestamps,
        seq_length=args.seq_length,
        interval=trainval_env.time_interval,
    ).to(args.device)

    logger.log("The base-model has {:} weights.".format(base_model.numel()))
    logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
    logger.log("The base-model is\n{:}".format(base_model))
    logger.log("The meta-model is\n{:}".format(meta_model))

    meta_train_procedure(base_model, meta_model, criterion, trainval_env, args,
                         logger)

    # evaluate on the test split: once with refinement, once in "easy" mode
    outcomes = []
    for easy_adapt in (False, True):
        outcomes.append(
            online_evaluate(test_env, meta_model, base_model, criterion,
                            metric, args, logger, True, easy_adapt))
    w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = outcomes[0]
    w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = outcomes[1]
    logger.log("[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
        loss_adapt_v1, metric_adapt_v1))
    logger.log("[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
        loss_adapt_v2, metric_adapt_v2))

    save_checkpoint(
        {
            "w_containers_care_adapt": w_containers_care_adapt,
            "w_containers_easy_adapt": w_containers_easy_adapt,
            "test_loss_adapt_v1": loss_adapt_v1,
            "test_loss_adapt_v2": loss_adapt_v2,
            "test_metric_adapt_v1": metric_adapt_v1,
            "test_metric_adapt_v2": metric_adapt_v2,
        },
        logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
Exemplo n.º 4
0
def main(args):
    """For every timestamp of the full synthetic environment, train a fresh
    model on that timestamp's own data and evaluate on the same data, saving
    a checkpoint per timestamp and all weight containers at the end.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    env = get_synthetic_env(mode=None, version=args.env_version)
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=env.meta_info["input_dim"],
        output_dim=env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()

    # pick the loss and metric class according to the task type
    if env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        # BUG FIX: this previously referenced the undefined name `all_env`
        # (NameError); the environment in this function is `env`.
        raise ValueError("This task ({:}) is not supported.".format(
            env.meta_info["task"]))

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True))
        logger.log("[{:}]".format(time_string()) +
                   " [{:04d}/{:04d}]".format(idx, len(env)) + " " + need_time)
        # train on this timestamp's own data
        historical_x = future_x.to(args.device)
        historical_y = future_y.to(args.device)
        # build model
        model = get_model(**model_kwargs)
        model = model.to(args.device)
        if idx == 0:
            print(model)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.init_lr,
                                     amsgrad=True)
        # decay the LR by 0.3x at 25% / 50% / 75% of the training epochs
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = metric_cls(True)
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # track the best (lowest-loss) weights seen so far
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        model.analyze_weights()
        # BUG FIX: compute the train metric with the restored best weights
        # rather than the stale `preds` from the final epoch's forward pass.
        with torch.no_grad():
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        # evaluate on the same timestamp's data through the eval pipeline
        xmetric = ComposeMetric(metric_cls(True), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(future_x.to(args.device),
                                                      future_y.to(args.device))
        eval_loader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  num_workers=0)
        results = basic_eval_fn(eval_loader, model, xmetric, logger)
        log_str = ("[{:}]".format(time_string()) +
                   " [{:04d}/{:04d}]".format(idx, len(env)) +
                   " train-score: {:.5f}, eval-score: {:.5f}".format(
                       train_results["score"], results["score"]))
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, len(env))
        w_containers[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": future_time.item(),
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
Exemplo n.º 5
0
def main(args):
    """For each timestamp from `args.prev_time` onward, train a fresh model
    on the data observed `prev_time` steps earlier and evaluate it on the
    current timestamp, checkpointing each model and all weight containers.
    """
    # BUG FIX: `env_info` is used throughout this function but was never
    # bound; `lfna_setup` returns (logger, env_info, model_kwargs).
    logger, env_info, model_kwargs = lfna_setup(args)

    w_containers = dict()

    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx in range(args.prev_time, env_info["total"]):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " "
            + need_time
        )
        # train on the data observed `prev_time` steps before `idx`
        historical_x = env_info["{:}-x".format(idx - args.prev_time)]
        historical_y = env_info["{:}-y".format(idx - args.prev_time)]
        # build model
        model = get_model(**model_kwargs)
        print(model)
        # build optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
        criterion = torch.nn.MSELoss()
        # decay the LR by 0.3x at 25% / 50% / 75% of the training epochs
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[
                int(args.epochs * 0.25),
                int(args.epochs * 0.5),
                int(args.epochs * 0.75),
            ],
            gamma=0.3,
        )
        train_metric = MSEMetric()
        best_loss, best_param = None, None
        for _iepoch in range(args.epochs):
            preds = model(historical_x)
            optimizer.zero_grad()
            loss = criterion(preds, historical_y)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # track the best (lowest-loss) weights seen so far
            if best_loss is None or best_loss > loss.item():
                best_loss = loss.item()
                best_param = copy.deepcopy(model.state_dict())
        model.load_state_dict(best_param)
        model.analyze_weights()
        # BUG FIX: compute the train metric with the restored best weights
        # rather than the stale `preds` from the final epoch's forward pass.
        with torch.no_grad():
            preds = model(historical_x)
            train_metric(preds, historical_y)
        train_results = train_metric.get_info()

        # evaluate the trained model on the data of timestamp `idx`
        metric = ComposeMetric(MSEMetric(), SaveMetric())
        eval_dataset = torch.utils.data.TensorDataset(
            env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
        )
        eval_loader = torch.utils.data.DataLoader(
            eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
        )
        results = basic_eval_fn(eval_loader, model, metric, logger)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, env_info["total"])
            + " train-mse: {:.5f}, eval-mse: {:.5f}".format(
                train_results["mse"], results["mse"]
            )
        )
        logger.log(log_str)

        save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
            idx, env_info["total"]
        )
        w_containers[idx] = model.get_w_container().no_grad_clone()
        save_checkpoint(
            {
                "model_state_dict": model.state_dict(),
                "model": model,
                "index": idx,
                "timestamp": env_info["{:}-timestamp".format(idx)],
            },
            save_path,
            logger,
        )
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
Exemplo n.º 6
0
def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"):
    """Render per-timestamp comparison frames (raw data vs. each algorithm's
    predictions, plus an MSE-over-time trajectory) and assemble the png
    frames into mp4/webm videos with ffmpeg.

    Args:
        save_dir: output directory; `pdf/` and `png/` subfolders are created.
        version: synthetic-environment version; only "v1" is supported.
        alg_dir: root directory holding each algorithm's `final-ckp.pth`.
    """
    save_dir = Path(str(save_dir))
    for substr in ("pdf", "png"):
        sub_save_dir = save_dir / substr
        sub_save_dir.mkdir(parents=True, exist_ok=True)

    # figure geometry: matplotlib takes figsize in inches = pixels / dpi
    dpi, width, height = 30, 3200, 2000
    figsize = width / float(dpi), height / float(dpi)
    LabelSize, LegendFontsize, font_gap = 80, 80, 5

    dynamic_env = get_synthetic_env(mode=None, version=version)
    # collect all x/y over every timestamp to fix shared axis limits later
    allxs, allys = [], []
    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env,
                                                         ncols=50)):
        allxs.append(allx)
        allys.append(ally)
    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)

    # map a display name to the checkpoint sub-directory of each algorithm
    alg_name2dir = OrderedDict()
    # alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
    # alg_name2dir["MAML"] = "use-maml-s1"
    # alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
    if version == "v1":
        # alg_name2dir["Optimal"] = "use-same-timestamp"
        alg_name2dir[
            "GMOA"] = "lfna-battle-bs128-d16_16_16-s16-lr0.002-wd1e-05-e10000-envv1"
    else:
        raise ValueError("Invalid version: {:}".format(version))
    # load each algorithm's per-timestamp weight containers from its ckpt
    alg_name2all_containers = OrderedDict()
    for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
        ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth"
        xdata = torch.load(ckp_path, map_location="cpu")
        alg_name2all_containers[alg] = xdata["w_containers"]
    # load the basic model
    model = get_model(
        dict(model_type="norm_mlp"),
        input_dim=1,
        output_dim=1,
        hidden_dims=[16] * 2,
        act_cls="gelu",
        norm_cls="layer_norm_1d",
    )

    # per-algorithm MSE trajectory accumulated across frames
    alg2xs, alg2ys = defaultdict(list), defaultdict(list)
    colors = ["r", "g", "b", "m", "y"]

    # skip the first few timestamps so every frame has trajectory history
    linewidths, skip = 10, 5
    for idx, (timestamp, (ori_allx,
                          ori_ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
        if idx <= skip:
            continue
        fig = plt.figure(figsize=figsize)
        cur_ax = fig.add_subplot(2, 1, 1)

        # the data
        allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
        plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")

        # overlay each algorithm's prediction at this timestamp
        for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
            with torch.no_grad():
                # run the base model with this algorithm's weight container
                predicts = model.forward_with_container(
                    ori_allx, alg_name2all_containers[alg][idx])
                predicts = predicts.cpu()
                # keep data
                metric = MSEMetric()
                metric(predicts, ori_ally)
                predicts = predicts.view(-1).numpy()
                alg2xs[alg].append(idx)
                alg2ys[alg].append(metric.get_info()["mse"])
            plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99,
                         linewidths, alg)

        cur_ax.set_xlabel("X", fontsize=LabelSize)
        cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
        # NOTE(review): `tick.label` is deprecated in newer matplotlib
        # (use `tick.label1`); assumes the version this was written against.
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        # shared axis limits so frames are comparable across timestamps
        cur_ax.set_xlim(round(allxs.min().item(), 1),
                        round(allxs.max().item(), 1))
        cur_ax.set_ylim(round(allys.min().item(), 1),
                        round(allys.max().item(), 1))
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        # the trajectory data
        cur_ax = fig.add_subplot(2, 1, 2)
        for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
            # plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg)
            cur_ax.plot(
                alg2xs[alg],
                alg2ys[alg],
                color=colors[idx_alg],
                linestyle="-",
                linewidth=5,
                label=alg,
            )
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
        cur_ax.set_ylabel("MSE", fontsize=LabelSize)
        for tick in cur_ax.xaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
            tick.label.set_rotation(10)
        for tick in cur_ax.yaxis.get_major_ticks():
            tick.label.set_fontsize(LabelSize - font_gap)
        cur_ax.set_xlim(1, len(dynamic_env))
        cur_ax.set_ylim(0, 10)
        cur_ax.legend(loc=1, fontsize=LegendFontsize)

        # save each frame twice: vector pdf for inspection, png for the video
        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(
            version, idx - skip)
        fig.savefig(str(pdf_save_path),
                    dpi=dpi,
                    bbox_inches="tight",
                    format="pdf")
        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(
            version, idx - skip)
        fig.savefig(str(png_save_path),
                    dpi=dpi,
                    bbox_inches="tight",
                    format="png")
        plt.close("all")
    save_dir = save_dir.resolve()
    # stitch the png frames into mp4 and webm videos with ffmpeg
    base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
        xdir=save_dir / "png", w=width, h=height, ver=version)
    os.system("{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd,
                                                    xdir=save_dir,
                                                    ver=version))
    os.system("{:} {xdir}/com-alg-{ver}.webm".format(base_cmd,
                                                     xdir=save_dir,
                                                     ver=version))
Exemplo n.º 7
0
def main(args):
    """MAML baseline: meta-train on the trainval environment with early
    stopping, then adapt and evaluate per-timestamp on the test environment.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    # instantiate every split of the synthetic environment
    train_env = get_synthetic_env(mode="train", version=args.env_version)
    valid_env = get_synthetic_env(mode="valid", version=args.env_version)
    trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
    test_env = get_synthetic_env(mode="test", version=args.env_version)
    all_env = get_synthetic_env(mode=None, version=args.env_version)
    logger.log("The training enviornment: {:}".format(train_env))
    logger.log("The validation enviornment: {:}".format(valid_env))
    logger.log("The trainval enviornment: {:}".format(trainval_env))
    logger.log("The total enviornment: {:}".format(all_env))
    logger.log("The test enviornment: {:}".format(test_env))
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=all_env.meta_info["input_dim"],
        output_dim=all_env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )

    model = get_model(**model_kwargs)
    model = model.to(args.device)
    # pick the loss and metric class according to the task type
    if all_env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif all_env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        raise ValueError(
            "This task ({:}) is not supported.".format(all_env.meta_info["task"])
        )

    maml = MAML(
        model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
    )

    # meta-training
    last_success_epoch = 0  # epoch of the last improvement, for early stop
    per_epoch_time, start_time = AverageMeter(), time.time()
    for iepoch in range(args.epochs):
        need_time = "Time Left: {:}".format(
            convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
        )
        head_str = (
            "[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
            + need_time
        )

        maml.zero_grad()
        meta_losses = []
        # sample `meta_batch` tasks: adapt on a history window, then score
        # the adapted weights on one future timestamp
        for ibatch in range(args.meta_batch):
            future_idx = random.randint(0, len(trainval_env) - 1)
            future_t, (future_x, future_y) = trainval_env[future_idx]
            # -->>
            seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
            _, (allxs, allys) = trainval_env.seq_call(seq_times)
            # flatten the per-timestamp batches into one adaptation set
            allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
            if trainval_env.meta_info["task"] == "classification":
                allys = allys.view(-1)  # CrossEntropyLoss expects 1-D targets
            historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
            future_container = maml.adapt(historical_x, historical_y)

            future_x, future_y = future_x.to(args.device), future_y.to(args.device)
            future_y_hat = maml.predict(future_x, future_container)
            future_loss = maml.criterion(future_y_hat, future_y)
            meta_losses.append(future_loss)
        # average the per-task losses and take one meta-gradient step
        meta_loss = torch.stack(meta_losses).mean()
        meta_loss.backward()
        maml.step()

        logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
        # checkpoint whenever the (negated) meta-loss improves
        success, best_score = maml.save_best(-meta_loss.item())
        if success:
            logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
            save_checkpoint(maml.state_dict(), logger.path("model"), logger)
            last_success_epoch = iepoch
        if iepoch - last_success_epoch >= args.early_stop_thresh:
            logger.log("Early stop at {:}".format(iepoch))
            break

        per_epoch_time.update(time.time() - start_time)
        start_time = time.time()

    # meta-test
    maml.load_best()

    def finetune(index):
        """Adapt the meta-learned weights on the history window ending at
        `index` and report the train metric of the adapted container.
        """
        seq_times = test_env.get_seq_times(index, args.seq_length)
        _, (allxs, allys) = test_env.seq_call(seq_times)
        allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
        if test_env.meta_info["task"] == "classification":
            allys = allys.view(-1)
        historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
        future_container = maml.adapt(historical_x, historical_y)

        historical_y_hat = maml.predict(historical_x, future_container)
        train_metric = metric_cls(True)
        # model.analyze_weights()
        with torch.no_grad():
            train_metric(historical_y_hat, historical_y)
        train_results = train_metric.get_info()
        return train_results, future_container

    # `metric` accumulates the eval score across all test timestamps
    metric = metric_cls(True)
    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(test_env):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
        )
        logger.log(
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(test_env))
            + " "
            + need_time
        )

        # build optimizer
        train_results, future_container = finetune(idx)

        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
        future_y_hat = maml.predict(future_x, future_container)
        future_loss = criterion(future_y_hat, future_y)
        metric(future_y_hat, future_y)
        log_str = (
            "[{:}]".format(time_string())
            + " [{:04d}/{:04d}]".format(idx, len(test_env))
            + " train-score: {:.5f}, eval-score: {:.5f}".format(
                train_results["score"], metric.get_info()["score"]
            )
        )
        logger.log(log_str)
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log("-" * 200 + "\n")
    logger.close()
Exemplo n.º 8
0
    def __init__(
        self,
        shape_container,
        layer_dim,
        time_dim,
        meta_timestamps,
        dropout: float = 0.1,
        seq_length: int = None,
        interval: float = None,
        thresh: float = None,
    ):
        """Build the traditional-attention meta-model.

        Args:
            shape_container: per-layer shape container of the base model.
            layer_dim: dimension of the learnable per-layer embedding.
            time_dim: dimension of the timestamp embedding.
            meta_timestamps: timestamps of the meta-training set.
            dropout: dropout rate of the attention output projection.
            seq_length: sequence length; used to derive `thresh` when
                `thresh` is not given.
            interval: time gap between consecutive timestamps (required).
            thresh: maximum time distance; defaults to interval * seq_length.
        """
        super(MetaModel_TraditionalAtt, self).__init__()
        self._shape_container = shape_container
        self._num_layers = len(shape_container)
        # number of scalar weights in each base-model layer
        self._numel_per_layer = []
        for ilayer in range(self._num_layers):
            self._numel_per_layer.append(shape_container[ilayer].numel())
        self._raw_meta_timestamps = meta_timestamps
        assert interval is not None
        self._interval = interval
        # BUG FIX: `interval * seq_length` raised a cryptic TypeError when
        # both `thresh` and `seq_length` were left at their None defaults;
        # fail with an explicit message instead.
        if thresh is not None:
            self._thresh = thresh
        elif seq_length is not None:
            self._thresh = interval * seq_length
        else:
            raise ValueError(
                "Either `thresh` or `seq_length` must be provided.")

        # learnable embeddings: one per base-model layer and one per
        # meta-training timestamp
        self.register_parameter(
            "_super_layer_embed",
            torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
        )
        self.register_parameter(
            "_super_meta_embed",
            torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
        )
        self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
        self._time_embed_dim = time_dim
        # extra embeddings/timestamps appended after construction
        self._append_meta_embed = dict(fixed=None, learnt=None)
        self._append_meta_timestamps = dict(fixed=None, learnt=None)

        self._tscalar_embed = super_core.SuperDynamicPositionE(
            time_dim, scale=1 / interval
        )

        # build transformer
        self._trans_att = super_core.SuperQKVAttention(
            in_q_dim=time_dim,
            in_k_dim=time_dim,
            in_v_dim=time_dim,
            num_heads=4,
            proj_dim=time_dim,
            qkv_bias=True,
            attn_drop=None,
            proj_drop=dropout,
        )

        # generator mapping (layer-embed | time-embed) -> flat layer weights
        model_kwargs = dict(
            config=dict(model_type="dual_norm_mlp"),
            input_dim=layer_dim + time_dim,
            output_dim=max(self._numel_per_layer),
            hidden_dims=[(layer_dim + time_dim) * 2] * 3,
            act_cls="gelu",
            norm_cls="layer_norm_1d",
            dropout=dropout,
        )
        self._generator = get_model(**model_kwargs)

        # initialization
        trunc_normal_(
            [self._super_layer_embed, self._super_meta_embed],
            std=0.02,
        )
Exemplo n.º 9
0
def main(args):
    """Train one model on the earliest historical sequence of the test
    environment, then evaluate that single fixed model on every timestamp.

    Returns:
        The final accumulated evaluation score over all test timestamps.
    """
    prepare_seed(args.rand_seed)
    logger = prepare_logger(args)
    env = get_synthetic_env(mode="test", version=args.env_version)
    model_kwargs = dict(
        config=dict(model_type="norm_mlp"),
        input_dim=env.meta_info["input_dim"],
        output_dim=env.meta_info["output_dim"],
        hidden_dims=[args.hidden_dim] * 2,
        act_cls="relu",
        norm_cls="layer_norm_1d",
    )
    logger.log("The total enviornment: {:}".format(env))
    w_containers = dict()

    # pick the loss and metric class according to the task type
    if env.meta_info["task"] == "regression":
        criterion = torch.nn.MSELoss()
        metric_cls = MSEMetric
    elif env.meta_info["task"] == "classification":
        criterion = torch.nn.CrossEntropyLoss()
        metric_cls = Top1AccMetric
    else:
        # BUG FIX: this previously referenced the undefined name `all_env`
        # (NameError); the environment in this function is `env`.
        raise ValueError("This task ({:}) is not supported.".format(
            env.meta_info["task"]))

    # train once on the sequence ending at timestamp 0
    seq_times = env.get_seq_times(0, args.seq_length)
    _, (allxs, allys) = env.seq_call(seq_times)
    allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
    if env.meta_info["task"] == "classification":
        allys = allys.view(-1)  # CrossEntropyLoss expects 1-D targets

    historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
    model = get_model(**model_kwargs)
    model = model.to(args.device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.init_lr,
                                 amsgrad=True)
    # decay the LR by 0.3x at 25% / 50% / 75% of the training epochs
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[
            int(args.epochs * 0.25),
            int(args.epochs * 0.5),
            int(args.epochs * 0.75),
        ],
        gamma=0.3,
    )

    train_metric = metric_cls(True)
    best_loss, best_param = None, None
    for _iepoch in range(args.epochs):
        preds = model(historical_x)
        optimizer.zero_grad()
        loss = criterion(preds, historical_y)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # track the best (lowest-loss) weights seen so far
        if best_loss is None or best_loss > loss.item():
            best_loss = loss.item()
            best_param = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_param)
    model.analyze_weights()
    # BUG FIX: compute the train metric with the restored best weights
    # rather than the stale `preds` from the final epoch's forward pass.
    with torch.no_grad():
        preds = model(historical_x)
        train_metric(preds, historical_y)
    train_results = train_metric.get_info()
    print(train_results)

    # `metric` accumulates the eval score across all timestamps
    metric = metric_cls(True)
    per_timestamp_time, start_time = AverageMeter(), time.time()
    for idx, (future_time, (future_x, future_y)) in enumerate(env):

        need_time = "Time Left: {:}".format(
            convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True))
        logger.log("[{:}]".format(time_string()) +
                   " [{:04d}/{:04d}]".format(idx, len(env)) + " " + need_time)
        # evaluate the single fixed model on this timestamp's data

        # build optimizer
        xmetric = ComposeMetric(metric_cls(True), SaveMetric())
        future_x, future_y = future_x.to(args.device), future_y.to(args.device)
        future_y_hat = model(future_x)
        future_loss = criterion(future_y_hat, future_y)
        metric(future_y_hat, future_y)
        log_str = ("[{:}]".format(time_string()) +
                   " [{:04d}/{:04d}]".format(idx, len(env)) +
                   " train-score: {:.5f}, eval-score: {:.5f}".format(
                       train_results["score"],
                       metric.get_info()["score"]))
        logger.log(log_str)
        logger.log("")
        per_timestamp_time.update(time.time() - start_time)
        start_time = time.time()

    # NOTE(review): `w_containers` is never filled in this function, so the
    # saved dict is empty — confirm whether per-timestamp containers were
    # intended to be recorded here.
    save_checkpoint(
        {"w_containers": w_containers},
        logger.path(None) / "final-ckp.pth",
        logger,
    )

    logger.log("-" * 200 + "\n")
    logger.close()
    return metric.get_info()["score"]