Beispiel #1
0
def main(args):
    """Run a closed-loop rollout of a pre-trained LSTMPB model and plot it.

    The rollout is seeded with the first time step of training sequence 0.
    On every later step the first 5 input dimensions are fed back from the
    model's own previous output, while the remaining dimensions come from the
    recorded data. At step 80 the parametric-bias (PB) vector and the
    reference sequence are switched to those of sequence 5.

    The prediction is written to ``<log_path>LSTMPB_closed_predict.txt`` and
    plotted against training sequences 0, 5, 10 and 15.

    Args:
        args: parsed options; this function reads ``data_path``, ``pb_dims``,
            ``model_load_path`` and ``log_path``.
    """
    ### Loading Dataset ###
    data_root = os.path.join(args.data_path, "train")
    dataset, _ = load_data(data_root)
    pb_list = torch.from_numpy(
        np.loadtxt(os.path.join(data_root, "pb_list.txt"),
                   dtype=np.float32,
                   delimiter=",").reshape(-1, args.pb_dims))
    rnn = LSTMPB(args, pb_unit=pb_list[None, 0])
    data = dataset[None, 0]

    pt_file = load_model(args.model_load_path, "*.pt")
    rnn.load_state_dict(torch.load(pt_file))
    rnn.eval()

    output_log = []
    # Pure inference: without no_grad the recurrent state would chain an
    # autograd graph across the whole rollout and memory would keep growing.
    with torch.no_grad():
        for i in range(data.shape[1] - 1):
            if i == 0:
                cur_input = data[:, 0, :]
                state = None
            else:
                if i == 80:
                    # Switch goal: PB vector and reference data of sequence 5.
                    rnn.pb_unit = pb_list[None, 5]
                    data = dataset[None, 5]

                # Closed loop on dims :5, taken from the data on dims 5:.
                cur_input = torch.cat([output[:, :5], data[:, i, 5:]], dim=-1)
                state = prev_state

            output, prev_state = rnn.step(cur_input, state)
            output_log.append(output.detach())

    output_log = torch.stack(output_log, dim=1)
    np.savetxt("{}LSTMPB_closed_predict.txt".format(args.log_path),
               output_log[0].numpy(),
               delimiter=",")

    # Plot the single rollout against several reference sequences.
    for goal_idx, seq_idx in enumerate((0, 5, 10, 15)):
        fig = make_fig([[output_log[0], dataset[seq_idx]]], figsize=(16, 16))
        fig.savefig("{}LSTMPB_closed_predict_goal{:03d}.png".format(
            args.log_path, goal_idx))
        plt.close(fig)
Beispiel #2
0
def main(args):
    """Run a closed-loop rollout of a pre-trained plain LSTM and plot it.

    The rollout is seeded with the first time step of training sequence 0.
    On every later step the first 8 input dimensions are fed back from the
    model's previous output and the rest come from the recorded data.

    Args:
        args: parsed options; this function reads ``data_path``,
            ``model_load_path`` and ``log_path``.
    """
    ### Loading Dataset ###
    data_root = os.path.join(args.data_path, "train")
    dataset, _ = load_data(data_root)

    rnn = LSTM(args)
    data = dataset[None, 0]

    pt_file = load_model(args.model_load_path, "*.pt")
    rnn.load_state_dict(torch.load(pt_file))
    rnn.eval()

    output_log = []
    # Inference only; avoids building a graph through the carried state.
    with torch.no_grad():
        for i in range(data.shape[1] - 1):
            if i == 0:
                cur_input = data[:, 0, :]
                state = None
            else:
                # The cat with data[:, i, 8:] requires output[0] to have a
                # leading dimension of 1, i.e. output is (A, 1, D).
                cur_input = torch.cat([output[0, :, :8], data[:, i, 8:]],
                                      dim=-1)
                state = prev_state

            output, prev_state = rnn(cur_input.view(1, 1, -1), state)
            output_log.append(output.detach())

    # Stacked along dim 1 -> (A, T, 1, D); select every time step of the
    # first sample. Indexing [0, 0] would keep only the first time step.
    output_log = torch.stack(output_log, dim=1)
    predicted = output_log[0, :, 0, :]
    np.savetxt("{}LSTMPB_closed_predict.txt".format(args.log_path),
               predicted.numpy(),
               delimiter=",")
    # NOTE(review): the rollout is driven by sequence 0 but is plotted
    # against dataset[1]; confirm this comparison is intended.
    fig = make_fig([[predicted, dataset[1]]], figsize=(16, 16))
    fig.savefig("{}LSTMPB_closed_predict.png".format(args.log_path))
    plt.close(fig)
Beispiel #3
0
    def train(self, train_loader, vali_loader=None, vali_pb_list=None,
              lr_reduce_epoch=2000):
        """Train ``self.rnn`` for ``self.epoch`` epochs.

        Starting at epoch ``self.closed_step`` the network is called with
        ``closed_flag=True``. The scheduler is stepped every
        ``lr_reduce_epoch`` epochs. Every ``self.log_iter`` epochs:
        - the validation set, if given, is evaluated;
        - the last batch's prediction is saved and plotted;
        - ``self.logger`` is called.
        Every ``self.model_save_iter`` epochs a checkpoint of ``self.rnn``
        is written.

        Args:
            train_loader: iterable of ``(batch_x, batch_y)`` tensor pairs.
            vali_loader: optional validation loader forwarded to
                ``self.test``.
            vali_pb_list: optional PB vectors forwarded to ``self.test``.
            lr_reduce_epoch: epoch interval between scheduler steps.

        Raises:
            ValueError: if ``train_loader`` yields no batches in an epoch.
        """
        closed_flag = False
        for epoch in range(1, self.epoch + 1):
            loss_sum = .0
            n_batches = 0
            if epoch == self.closed_step and not closed_flag:
                print("=" * 5, "Start closed loop", "=" * 5)
                closed_flag = True

            for batch_x, batch_y in train_loader:
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)

                self.optim.zero_grad()
                # Inputs drop the last `delay` steps and targets drop the
                # first `delay`, so prediction t is scored against step
                # t + delay (assumes delay >= 1).
                output_log, state_log = self.rnn(batch_x[:, :-self.delay, :],
                                                 closed_flag=closed_flag)
                loss = self.loss(batch_y[:, self.delay:, :], output_log)
                loss_sum += loss.item()
                n_batches += 1

                loss.backward()
                self.optim.step()

            if n_batches == 0:
                # Guard the average below (and the logging that uses the
                # last batch) against an empty loader.
                raise ValueError("train_loader yielded no batches")

            if epoch % lr_reduce_epoch == 0:
                self.scheduler.step()

            self.loss_list["loss"] = loss_sum / n_batches
            if epoch % self.log_iter == 0:
                if vali_loader is not None:
                    self.test(vali_loader, vali_pb_list)

                # Save and plot the first sample of the last training batch.
                output = output_log.detach().cpu().numpy()
                np.savetxt("{}{}_result_{}".format(self.log_path, self.name,
                                                   epoch),
                           output[0],
                           delimiter=",")
                data = batch_y.detach().cpu().numpy()
                self.result_img = make_fig([[output[0], data[0]]])

                self.logger(epoch)
                plt.close(self.result_img)

            if epoch % self.model_save_iter == 0:
                torch.save(
                    self.rnn.state_dict(),
                    "{}{}_{:0>6}.pt".format(self.model_save_path, self.name,
                                            epoch))
Beispiel #4
0
import task2
import utils
from matplotlib import pyplot as plt
import numpy as np

# Check task2.reconstruct_from_noisy_patches against every stored test case
# and show the expected and produced images side by side.
tcases = utils.load('testcases.pkl')
for tcase in tcases:
    img_patches, shape, reconstructed = tcases[tcase]
    try:
        # Both the student's reconstruction and the comparison can raise
        # (e.g. a shape mismatch in np.allclose); treat either as a failure.
        cleaned_img = task2.reconstruct_from_noisy_patches(img_patches, shape)
        passed = np.allclose(cleaned_img, reconstructed)
    except Exception as e:
        # Report the cause instead of silently swallowing it, then continue
        # with the remaining test cases.
        print('testcase# {} failed'.format(tcase))
        print('    {}: {}'.format(type(e).__name__, e))
        continue

    print('testcase# {} {}'.format(tcase, 'passed' if passed else 'failed'))
    utils.make_fig(reconstructed, cmap='gray', title='Correct')
    utils.make_fig(cleaned_img, cmap='gray', title='Yours')
    plt.show()
# NOTE(review): `parser` (an argparse.ArgumentParser) and the `os` import are
# expected to be defined earlier in this script; they are not visible here.
parser.add_argument('--input', type=str, required=True)
parser.add_argument('--k', type=int, required=True)
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()

# Validate with parser.error() rather than `assert`: asserts are stripped
# under `python -O`, and parser.error prints the usage and exits cleanly.
k = args.k
if not 0 < k < 51:
    parser.error('--k must be between 1 and 50, got {}'.format(k))
if not os.path.isfile(args.input):
    parser.error('input({}) is not a valid file'.format(args.input))


def transform_img(arr, k):
    """Quantize an RGB image to ``k`` colours with k-means.

    Every pixel is treated as a 3-vector. ``kmeans2`` clusters the pixels
    (k-means++ initialisation, 10 iterations), and each pixel is replaced
    by its cluster centroid.

    Args:
        arr: image array of shape (H, W, 3).
        k: number of colour clusters.

    Returns:
        The quantized image, reshaped to ``arr.shape`` and passed through
        ``utils.clip_both_sides``.

    Raises:
        ValueError: if ``arr`` is not a 3-D array with 3 channels.
    """
    original_shape = arr.shape
    # Explicit check instead of `assert`, which is stripped under -O.
    if arr.ndim != 3 or arr.shape[2] != 3:
        raise ValueError('invalid arr. shape = {}'.format(arr.shape))
    pixels = arr.astype('float').reshape((-1, 3))
    centroid, label = kmeans2(pixels,
                              k=k,
                              iter=10,
                              minit='++',
                              check_finite=False)
    kimg = centroid[label].reshape(original_shape)
    return utils.clip_both_sides(kimg)


# Quantize the input image to k colours and save the rendered figure.
img = utils.load_image(args.input, rescale=False, grayscale=False)
tri = transform_img(img, k=args.k)
utils.make_fig(tri, cmap=None)
# savefig does not accept `cmap`; recent matplotlib rejects unknown kwargs.
plt.savefig(args.output)
Beispiel #6
0
def main(args):
    """Closed-loop LSTMPB rollout with online re-estimation of the PB vector.

    All network weights are frozen and the PB vector is made the only
    trainable parameter. The rollout seeds from training sequence 0 and feeds
    back the first 5 output dims. At step 999 the contents of the history
    window are replayed 100 times, and the PB vector is optimised with Adam
    on two terms:
    - the reconstruction error of dims 5:;
    - a squared-distance penalty to the stored training PB vectors.

    The rollout, its plots against sequences 0/5/10/15, and the logged PB
    trajectory are written under ``args.log_path``.

    Args:
        args: parsed options; this function reads ``data_path``, ``pb_dims``,
            ``model_load_path`` and ``log_path``.
    """
    ### Loading Dataset ###
    data_root = os.path.join(args.data_path, "train")
    dataset, _ = load_data(data_root)
    pb_list = torch.from_numpy(
        np.loadtxt(os.path.join(data_root, "pb_list.txt"),
                   dtype=np.float32,
                   delimiter=",").reshape(-1, args.pb_dims))
    rnn = LSTMPB(args, pb_unit=pb_list[None, 0])
    data = dataset[None, 0]

    pt_file = load_model(args.model_load_path, "*.pt")
    state_dict = torch.load(pt_file)
    rnn.load_state_dict(state_dict)
    # Freeze the trained network; only the PB vector is optimised below.
    for param in rnn.parameters():
        param.requires_grad = False

    rnn.pb_unit = nn.Parameter(pb_list[None, 0], requires_grad=True)
    optim_param = [param for param in rnn.parameters() if param.requires_grad]
    print(optim_param)
    optim = torch.optim.Adam(optim_param, lr=0.01)
    mse_loss = nn.MSELoss()
    his_log = HistoryWindow(maxlen=WINDOW_SIZE)

    output_log = []
    print(rnn.pb_unit)
    pb_log = []
    for i in range(data.shape[1] - 1):
        if i == 0:
            cur_input = data[:, 0, :]
            state = None
        else:
            if i == 999:
                # Replay the history window from its first stored state and
                # fit the PB vector to the recorded dims 5:.
                pred_his, actual_his, state_his = his_log.get()
                for _ in range(100):
                    log = []
                    cur = torch.cat([pred_his[:, 0, :5], actual_his[:, 0, 5:]],
                                    dim=-1)
                    s = state_his[0]
                    for _step in range(1, len(state_his)):
                        o, prev_s = rnn.step(cur, s)
                        cur = o
                        s = prev_s
                        log.append(o)

                    log = torch.stack(log, dim=1)
                    loss = mse_loss(log[0, :, 5:], actual_his[0, 1:, 5:]) + (
                        rnn.pb_unit - pb_list).pow(2).mean()
                    # Clear the previous iteration's gradient; otherwise Adam
                    # would be fed the running sum of all past gradients.
                    optim.zero_grad()
                    loss.backward(retain_graph=True)
                    pb_log.append(rnn.pb_unit.data.clone())
                    optim.step()
                    print(loss.item())

                # Continue the rollout from the state reached by the replay.
                prev_state = s

            cur_input = torch.cat([output[:, :5], data[:, i, 5:]], dim=-1)
            state = prev_state

        output, prev_state = rnn.step(cur_input, state)
        his_log.put([output[:, :], data[:, i + 1, :], prev_state])
        output_log.append(output.detach())

    print(rnn.pb_unit)
    output_log = torch.stack(output_log, dim=1)
    np.savetxt("{}LSTMPB_closed_predict.txt".format(args.log_path),
               output_log[0].numpy(),
               delimiter=",")
    for goal_idx, seq_idx in enumerate((0, 5, 10, 15)):
        fig = make_fig([[output_log[0], dataset[seq_idx]]], figsize=(16, 16))
        fig.savefig("{}LSTMPB_closed_predict_goal{:03d}.png".format(
            args.log_path, goal_idx))
        plt.close(fig)

    # pb_log stays empty when the sequence is too short to reach step 999;
    # torch.stack would raise on an empty list.
    if pb_log:
        pb_log = torch.stack(pb_log, 1)
        np.savetxt("{}pb_log.txt".format(args.log_path),
                   pb_log[0].numpy(),
                   delimiter=",")
Beispiel #7
0
def main(args):
    """Build the network named by ``args.net``, then train, test or predict.

    Args:
        args: parsed options; this function reads ``data_path``, ``cuda``,
            ``net``, ``pb_dims``, ``validate``, ``model_load_path``,
            ``mode`` and ``log_path``.

    Raises:
        ValueError: if ``args.mode`` is not "train", "test" or "predict".
    """
    #########################################
    ###          Loading Dataset          ###
    #########################################
    train_root = os.path.join(args.data_path, "train")
    train_data, train_loader = load_data(train_root)
    # The test and predict modes evaluate on the training data.
    test_data, test_loader = train_data, train_loader

    model_args = {"args": args}
    device = torch.device("cuda:0" if args.cuda else "cpu")
    is_pb_net = "PB" in args.net
    if is_pb_net:
        train_pb_list = np.loadtxt(os.path.join(train_root, "pb_list.txt"),
                                   dtype=np.float32,
                                   delimiter=",").reshape([-1, args.pb_dims])
        train_pb_list = torch.from_numpy(train_pb_list).to(device)
        model_args["pb_unit"] = train_pb_list

    # Defaults so model.train() always receives bound names, including when
    # validating a network without PB units.
    vali_loader, vali_pb_list = None, None
    if args.validate:
        vali_root = os.path.join(args.data_path, "validation")
        _, vali_loader = load_data(vali_root)
        if is_pb_net:
            # Reshaped like the training list so a single-row file still
            # yields a 2-D (n, pb_dims) tensor.
            vali_pb_list = np.loadtxt(os.path.join(vali_root, "pb_list.txt"),
                                      dtype=np.float32,
                                      delimiter=",").reshape(
                                          [-1, args.pb_dims])
            vali_pb_list = torch.from_numpy(vali_pb_list).to(device)

    ### Building Model ###
    model = Model(args)
    # SECURITY NOTE(review): eval() executes arbitrary code taken from the
    # command line; an explicit {name: class} mapping would be safer.
    model.build(eval(args.net), **model_args)

    if args.model_load_path:
        model_name = args.net + "*.pt"
        pt_file = load_model(args.model_load_path, model_name)
        model.load(pt_file)

    if args.mode == "train":
        model.rnn.train()
        model.train(train_loader, vali_loader, vali_pb_list)
    elif args.mode == "test":
        model.rnn.eval()
        output_log, state_log, loss = model.test(test_loader)
        print("test loss:", loss)
        for i in range(output_log.shape[0]):
            fig = make_fig([[output_log[i], test_data[i].numpy()]])
            np.savetxt("{}{}_test_{:0>6d}.txt".format(args.log_path, args.net,
                                                      i),
                       output_log[i],
                       delimiter=" ")
            # Save the figure make_fig returned, not whatever is current.
            fig.savefig("{}{}_test_{:0>6d}.png".format(args.log_path, args.net,
                                                       i))
            plt.close(fig)
    elif args.mode == "predict":
        model.rnn.eval()
        inputs = test_data.to(device)
        output_log, state_log = model.predict(inputs, None,
                                              inputs.shape[1] - 1)
        for i in range(test_data.shape[0]):
            fig = make_fig([[output_log[i], test_data[i].numpy()]])
            np.savetxt("{}{}_predict_{:0>6d}.txt".format(
                args.log_path, args.net, i),
                       output_log[i],
                       delimiter=" ")
            fig.savefig("{}{}_predict_{:0>6d}.png".format(
                args.log_path, args.net, i))
            plt.close(fig)
    else:
        raise ValueError("unknown mode: {!r}".format(args.mode))