Example #1
import os

import numpy as np

# init_logs, CycleGANModel and Robotdata come from this project; their import
# statements are not included in the excerpt.
def eval(opt):
    img_logs, weight_logs, tensor_writer = init_logs(opt)
    model = CycleGANModel(opt)
    model.parallel_init(device_ids=opt.device_ids)
    print(weight_logs)
    dataset = Robotdata.get_loader(opt)
    # fdataset = RobotStackFdata.get_loader(opt)
    # model.train_forward_state(fdataset, True)
    # collect checkpoints named 'weights_<niter>' from the log directory and
    # order them by iteration number (a plain lexicographic sort would be wrong)
    weight_list = list(
        filter(lambda x: 'weights_' in x,
               os.listdir(weight_logs.replace('weights', ''))))
    weight_list = sorted(weight_list, key=lambda x: int(x.split('_')[1]))

    for weight_path in weight_list:
        weight_id = int(weight_path.split('_')[1])
        weight_path = weight_logs.replace('weights', weight_path)
        print(weight_path)
        model.load(weight_path)

        # model.img_policy.online_test(model.netG_B,5)

        dataset = Robotdata.get_loader(opt)
        ave_loss = {}
        count_n = 10
        gt, pred = [], []
        for batch_id, data in enumerate(dataset):
            model.set_input(data)
            model.test()

            errors = model.get_current_errors()
            display = '===> Batch({}/{})'.format(batch_id, len(dataset))
            for key, value in errors.items():
                display += '{}:{:.4f}  '.format(key, value)
                # keep a running sum of each loss term for averaging
                ave_loss[key] = ave_loss.get(key, 0.0) + value
            gt.append(model.gt0.cpu().data.numpy())
            pred.append(model.fake_At0.cpu().data.numpy())

            if batch_id >= count_n - 1:
                path = os.path.join(img_logs, 'imgA_{}.jpg'.format(weight_id))
                model.visual(path)
                break

        gt = np.vstack(gt)
        pred = np.vstack(pred)
        np.save(os.path.join(img_logs, 'gt_{}.npy'.format(weight_id)), gt)
        np.save(os.path.join(img_logs, 'pred_{}.npy'.format(weight_id)), pred)

        display = 'average loss: '
        for key, value in ave_loss.items():
            display += '{}:{:.4f}  '.format(key, value / (count_n))
        print(display)
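The checkpoint names produced by train() in Example #2 have the form weights_<niter>, so a plain lexicographic sort would put 'weights_1000' before 'weights_200'; the numeric sort key above avoids that. A minimal standalone illustration:

names = ['weights_1000', 'weights_200', 'weights_50']
print(sorted(names))                                      # ['weights_1000', 'weights_200', 'weights_50']
print(sorted(names, key=lambda x: int(x.split('_')[1])))  # ['weights_50', 'weights_200', 'weights_1000']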
Example #2
import os

# project imports (init_logs, CycleGANModel, Robotdata, RobotStackFdata) omitted as above
def train(opt):
    img_logs, weight_logs, tensor_writer = init_logs(opt)
    model = CycleGANModel(opt)
    dataset = Robotdata.get_loader(opt)
    fdataset = RobotStackFdata.get_loader(opt)
    model.train_forward_state(fdataset, opt.pretrain_f)
    model.parallel_init(device_ids=opt.device_ids)

    # model.img_policy.online_test(model.netG_B, 1)

    niter = 0
    for epoch_id in range(opt.epoch_size):
        for batch_id, data in enumerate(dataset):
            model.set_input(data)
            model.optimize_parameters()

            niter += 1
            # print the current losses and log them to TensorBoard every display_gap batches
            if batch_id % opt.display_gap == 0:
                errors = model.get_current_errors()
                display = '===> Epoch[{}]({}/{})'.format(
                    epoch_id, batch_id, len(dataset))
                for key, value in errors.items():
                    display += '{}:{:.4f}  '.format(key, value)
                    tensor_writer.add_scalar(key, value, niter)
                print(display)

            # periodically dump a visualization and save a checkpoint tagged with the iteration count
            if batch_id % opt.save_weight_gap == 0:
                path = os.path.join(
                    img_logs, 'imgA_{}_{}.jpg'.format(epoch_id, batch_id + 1))
                model.visual(path)
                model.save(
                    weight_logs.replace('weights', 'weights_{}'.format(niter)))
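The option object opt is built outside these excerpts; the attribute names it must expose can be read off the code (device_ids, epoch_size, display_gap, save_weight_gap, pretrain_f, data_root, test_id1). A hypothetical argparse entry point, with made-up flag spellings and defaults, and no claim to cover every field the models and loaders read:

import argparse

def get_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_ids', type=int, nargs='+', default=[0])  # GPUs for parallel_init
    parser.add_argument('--epoch_size', type=int, default=100)             # number of training epochs
    parser.add_argument('--display_gap', type=int, default=50)             # batches between log lines
    parser.add_argument('--save_weight_gap', type=int, default=500)        # batches between checkpoints
    parser.add_argument('--pretrain_f', action='store_true')               # passed to train_forward_state
    parser.add_argument('--data_root', type=str, default='./data')         # used by the img2state examples
    parser.add_argument('--test_id1', type=int, default=0)                 # dataset/checkpoint subfolder id
    return parser.parse_args()

if __name__ == '__main__':
    train(get_opt())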
Example #3
import os

# project imports (init_logs, CycleGANModel, Robotdata) omitted as above
def eval(opt):
    img_logs, weight_logs, tensor_writer = init_logs(opt)
    model = CycleGANModel(opt)
    model.parallel_init(device_ids=opt.device_ids)
    print(weight_logs)
    model.load(weight_logs)
    # model.img_policy.online_test(model.netG_B,5)

    dataset = Robotdata.get_loader(opt)
    ave_loss = {}
    for batch_id, data in enumerate(dataset):
        model.set_input(data)
        model.test()

        errors = model.get_current_errors()
        display = '===> Batch({}/{})'.format(batch_id, len(dataset))
        for key, value in errors.items():
            display += '{}:{:.4f}  '.format(key, value)
            # keep a running sum of each loss term
            ave_loss[key] = ave_loss.get(key, 0.0) + value
        print(display)

        if (batch_id + 1) % opt.display_gap == 0:
            path = os.path.join(img_logs, 'imgA_{}.jpg'.format(batch_id + 1))
            model.visual(path)
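Unlike Example #1, this eval() accumulates ave_loss but never reports the averages. If a summary is wanted, the same closing pattern would fit at the end of the function, after the batch loop (a sketch, assuming the loop ran over the whole loader):

    display = 'average loss: '
    for key, value in ave_loss.items():
        display += '{}:{:.4f}  '.format(key, value / len(dataset))
    print(display)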
Example #4
import os

import numpy as np
import torch
import torch.nn as nn

# PixelEncoder and RobotStackdata come from this project; their import
# statements are not included in the excerpt.
def eval_img2state(opt):
    model = PixelEncoder(opt).cuda()
    model = nn.DataParallel(model, device_ids=[0, 1])
    weight_path = os.path.join(
        opt.data_root, 'data_{}/img2state_large.pth'.format(opt.test_id1))
    model.load_state_dict(torch.load(weight_path))

    dataset = RobotStackdata.get_loader(opt)
    loss_fn = nn.L1Loss()

    epoch_loss = 0
    origin, recover = [], []
    for i, item in enumerate(dataset):
        # item[0] -> (input_Bt0, action, input_Bt1); item[1] -> (state, action, result);
        # item[2] -> (gt0, gt1)
        state, action, result = item[1]
        input_Bt0 = item[0][0]
        input_Bt1 = item[0][2]
        action = item[0][1]
        gt0 = item[2][0].float().cuda()
        gt1 = item[2][1].float().cuda()

        img = input_Bt0.float().cuda()
        gt = gt0.float().cuda()

        out = model(img)
        loss = loss_fn(out, gt)
        epoch_loss += loss.item()

        print(i, epoch_loss / (i + 1))

        origin.append(gt.cpu().data.numpy())
        recover.append(out.cpu().data.numpy())

        # evaluate only the first ~100 batches
        if i > 100:
            break

    # average over the batches actually processed, not the full loader length
    print('epoch:{} loss:{:.7f}'.format(0, epoch_loss / (i + 1)))

    origin = np.vstack(origin)
    recover = np.vstack(recover)

    np.save(
        os.path.join(opt.data_root, 'data_{}/origin.npy'.format(opt.test_id1)),
        origin)
    np.save(
        os.path.join(opt.data_root,
                     'data_{}/recover.npy'.format(opt.test_id1)), recover)
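The two arrays written above pair ground-truth states (origin) with predicted states (recover); a hypothetical helper that reloads them and reports the per-dimension error with the same L1 metric (the function name and printout are assumptions; the paths mirror the np.save calls):

import os
import numpy as np

def report_recovery_error(opt):
    base = os.path.join(opt.data_root, 'data_{}'.format(opt.test_id1))
    origin = np.load(os.path.join(base, 'origin.npy'))
    recover = np.load(os.path.join(base, 'recover.npy'))
    # mean absolute error per state dimension
    print('per-dimension MAE:', np.abs(origin - recover).mean(axis=0))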
Example #5
import os

import torch
import torch.nn as nn
from tqdm import tqdm

# project imports (PixelEncoder, RobotStackdata) omitted as above
def train_img2state(opt):
    # model = img2state(opt).cuda()
    # weight_path = os.path.join(opt.data_root, 'data_{}/img2state.pth'.format(opt.test_id1))
    # model.load_state_dict(torch.load(weight_path))

    model = PixelEncoder(opt).cuda()
    model = nn.DataParallel(model, device_ids=[0, 1, 2, 3])
    weight_path = os.path.join(
        opt.data_root, 'data_{}/img2state_large.pth'.format(opt.test_id1))
    # resume from an existing checkpoint when one can be loaded
    try:
        model.load_state_dict(torch.load(weight_path))
        print('continue training!')
    except Exception:
        print('training from scratch!')

    dataset = RobotStackdata.get_loader(opt)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.L1Loss()

    for epoch in range(opt.epoch_size):
        epoch_loss = 0
        for i, item in enumerate(tqdm(dataset)):
            state, action, result = item[1]
            input_Bt0 = item[0][0]
            input_Bt1 = item[0][2]
            action = item[0][1]
            gt0 = item[2][0].float().cuda()
            gt1 = item[2][1].float().cuda()

            img = input_Bt0.float().cuda()
            gt = gt0.float().cuda()

            out = model(img)
            loss = loss_fn(out, gt)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        print('epoch:{} loss:{:.7f}'.format(epoch, epoch_loss / len(dataset)))
        weight_path = os.path.join(
            opt.data_root, 'data_{}/img2state_large.pth'.format(opt.test_id1))
        torch.save(model.state_dict(), weight_path)
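A practical note on the checkpoint format: model.state_dict() is taken from the nn.DataParallel wrapper here, so every key in img2state_large.pth is prefixed with 'module.'. Loading works in these excerpts because the model is wrapped again before load_state_dict; loading into a bare PixelEncoder would need the prefix stripped, roughly like this (a sketch, reusing weight_path and opt from above):

plain = PixelEncoder(opt).cuda()
state = torch.load(weight_path)
# drop the 'module.' prefix that nn.DataParallel adds to parameter names
state = {k.replace('module.', '', 1): v for k, v in state.items()}
plain.load_state_dict(state)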