def iter_train(opt):
    """Run ``opt.iterations`` rounds of train -> test -> collect-data.

    Args:
        opt: Options object. This function reads ``opt.seed``,
            ``opt.iterations`` and ``opt.finetune`` and mutates
            ``opt.istrain``, ``opt.init_start`` and, when ``opt.finetune``
            is truthy, ``opt.pair_n``, ``opt.display_gap`` and
            ``opt.eval_gap``.
    """
    logs = init_logs(opt)
    # init_logs is called a second time with istrain switched off to obtain
    # a separate set of logs for evaluation; the flag is restored afterwards.
    opt.istrain = False
    eval_logs = init_logs(opt)
    opt.istrain = True

    data_agent = IterCycleData(opt)
    # Use the options object that was passed in rather than a module-level
    # ``args`` global, which is not guaranteed to exist.
    setup_seed(opt.seed)
    model = CycleGANModel(opt)
    model.update(opt)
    best_reward = -1000
    for iteration in range(opt.iterations):
        # Train
        opt.istrain = True
        best_reward = train(opt, data_agent, model, iteration, best_reward,
                            logs)
        if opt.finetune:
            # Set after the train call, so these values are only visible to
            # the steps that follow it (and to later rounds).
            opt.pair_n = 700
            opt.display_gap = 100
            opt.eval_gap = 100

        # Test
        opt.istrain = False
        opt.init_start = False
        with torch.no_grad():
            test(opt, iteration, eval_logs)

        # Collect data
        collect_data(opt, data_agent, model)
示例#2
0
def main(args):
    """Load saved weights and run ``model.cross_policy.eval_policy`` for 10 episodes.

    Args:
        args: Options object forwarded to ``init_logs`` and ``CycleGANModel``.
    """
    _, _, weight_logs = init_logs(args)
    model = CycleGANModel(args)
    model.load(weight_logs)
    # NOTE(review): the original passed the free names ``gxmodel`` and
    # ``axmodel``, which are not defined in this function. The model's own
    # networks are used instead, matching the other eval_policy call sites
    # in this file -- confirm this is the intended pair.
    model.cross_policy.eval_policy(gxmodel=model.netG_B,
                                   axmodel=model.net_action_G_A,
                                   eval_episodes=10)
示例#3
0
def train(opt):
    """Train a ``CycleGANModel`` on the ``BoltData`` loader, logging periodically.

    Every ``opt.display_gap`` batches the current errors are printed and
    written to ``tensor_writer``, the weights are saved whenever the
    ``'L_t0'`` error improves on the best value seen so far, and a
    visualisation image is written under ``img_logs``.

    Args:
        opt: Options object. This function reads ``opt.pretrain_f``,
            ``opt.epoch_size`` and ``opt.display_gap``.
    """
    img_logs, weight_logs, tensor_writer = init_logs(opt)
    dataset = BoltData.get_loader(opt)
    model = CycleGANModel(opt)
    model.train_forward_state(dataset, opt.pretrain_f)

    # One logging step per display event, shared by all scalars and the image
    # written for that event.
    niter = 0
    # Start from +inf so the first measured loss always produces a checkpoint,
    # whatever the scale of the loss.
    best_loss = float('inf')
    for epoch_id in range(opt.epoch_size):
        for batch_id, data in enumerate(dataset):
            model.set_input(data)
            model.optimize_parameters()

            if batch_id % opt.display_gap != 0:
                continue

            errors = model.get_current_errors()
            display = '\n===> Epoch[{}]({}/{})'.format(
                epoch_id, batch_id, len(dataset))
            for key, value in errors.items():
                display += '{}:{:.4f}  '.format(key, value)
                tensor_writer.add_scalar(key, value, niter)
            print(display)

            cur_loss = errors['L_t0']
            if cur_loss < best_loss:
                best_loss = cur_loss
                model.save(weight_logs)

            path = os.path.join(img_logs, 'img_batch_{}.jpg'.format(niter))
            model.visual(path)
            niter += 1
示例#4
0
def get_state(opt):
    """Run the saved model over the whole loader and dump predictions to disk.

    For every batch, ``model.fake_A`` and ``model.gt0`` are collected after
    ``model.test()``. The concatenated arrays are saved as ``pred_z.npy`` and
    ``gt_z.npy`` at the paths obtained by replacing ``'weights'`` in
    ``weight_logs``, and the per-column mean absolute difference is printed.

    Args:
        opt: Options object forwarded to ``init_logs``, the data loader and
            the model.
    """
    _, weight_logs, _ = init_logs(opt)
    dataset = BoltData.get_loader(opt)
    model = CycleGANModel(opt)
    # NOTE(review): the device list is hard-coded -- presumably GPU ids;
    # confirm the target machine has four devices.
    model.parallel_init([0, 1, 2, 3])
    model.load(weight_logs)

    pred, gt = [], []
    for batch_id, data in enumerate(dataset):
        model.set_input(data)
        model.test()

        # Detach and move to CPU right away so results for the whole dataset
        # are not kept on the model's device until the final concatenation.
        pred.append(model.fake_A.detach().cpu())
        gt.append(model.gt0.detach().cpu())

        print(batch_id)

    pred = torch.cat(pred, 0).numpy()
    gt = torch.cat(gt, 0).numpy()
    print(abs(pred - gt).mean(0))

    np.save(weight_logs.replace('weights', 'pred_z.npy'), pred)
    np.save(weight_logs.replace('weights', 'gt_z.npy'), gt)
示例#5
0
def test(args):
    """Load saved weights and run a 10-episode ``eval_policy`` pass.

    Args:
        args: Options object; ``istrain`` and ``init_start`` are set to
            ``False`` on it before anything else is constructed.
    """
    args.istrain = args.init_start = False

    _, _, weight_logs = init_logs(args)
    data_agent = CycleData(args)
    model = CycleGANModel(args)
    model.fengine.train_statef(data_agent.data1)

    # Show which weights are about to be loaded.
    print(weight_logs)
    model.load(weight_logs)
    model.update(args)

    policy = model.cross_policy
    policy.eval_policy(gxmodel=model.netG_B,
                       axmodel=model.net_action_G_A,
                       eval_episodes=10)
示例#6
0
def test(args):
    """Load saved weights, run a 100-episode ``eval_policy`` pass and log the result.

    The reward and success rate returned by ``eval_policy`` are written to
    ``txt_logs`` as a single line and flushed.

    Args:
        args: Options object; ``istrain`` and ``init_start`` are set to
            ``False`` on it before anything else is constructed.
    """
    args.istrain = args.init_start = False

    txt_logs, _, weight_logs = init_logs(args)
    data_agent = CycleData(args)
    model = CycleGANModel(args)
    model.fengine.train_statef(data_agent.data1)

    # Show which weights are about to be loaded.
    print(weight_logs)
    model.load(weight_logs)
    model.update(args)

    policy = model.cross_policy
    outcome = policy.eval_policy(gxmodel=model.netG_B,
                                 axmodel=model.net_action_G_A,
                                 eval_episodes=100)
    reward, success_rate = outcome

    summary = 'Final Evaluation: {}, Success Rate: {}\n'.format(
        reward, success_rate)
    txt_logs.write(summary)
    txt_logs.flush()
示例#7
0
def _run_training_phase(args, model, data_agent, txt_logs, img_logs,
                        weight_logs, start_id, best_reward):
    """Run ``args.pair_n`` optimisation steps starting at ``start_id``.

    Every ``args.display_gap`` steps the errors are logged and an image is
    written; every ``args.eval_gap`` steps the policy is evaluated and the
    weights are saved if the reward beats ``best_reward``.

    Returns:
        tuple: ``(end_id, best_reward)`` -- the first batch id after this
        phase and the possibly updated best reward.
    """
    end_id = start_id + args.pair_n
    for batch_id in range(start_id, end_id):
        step = batch_id + 1
        model.set_input(data_agent.sample())
        model.optimize_parameters()
        # The result is not used here; the call is kept in case fetch() has
        # side effects.
        model.fetch()

        if step % args.display_gap == 0:
            display = '\n===> Batch[{}/{}]'.format(step, args.pair_n)
            print(display)
            display = add_errors(model, display)
            txt_logs.write('{}\n'.format(display))
            txt_logs.flush()

            path = os.path.join(img_logs, 'imgA_{}.jpg'.format(step))
            model.visual(path)

        if step % args.eval_gap == 0:
            reward = model.cross_policy.eval_policy(
                gxmodel=model.netG_B,
                axmodel=model.net_action_G_A,
                eval_episodes=args.eval_n)
            if reward > best_reward:
                best_reward = reward
                model.save(weight_logs)
            print('best_reward:{:.1f}  cur_reward:{:.1f}'.format(
                best_reward, reward))

    return end_id, best_reward


def train(args):
    """Alternate two training phases for three rounds, checkpointing on best reward.

    Each round first runs a phase with ``lr_Gx = 1e-4`` and ``lr_Ax = 0``,
    then a phase with ``lr_Gx = 0`` and ``lr_Ax = 1e-4`` (presumably the
    learning rates of the two networks -- confirm against ``model.update``).
    ``args.init_start`` is set to ``False`` before the second phase. Batch
    ids keep counting up across phases and rounds.

    Args:
        args: Options object. This function reads ``pair_n``,
            ``display_gap``, ``eval_gap`` and ``eval_n`` and mutates
            ``lr_Gx``, ``lr_Ax`` and ``init_start``.
    """
    txt_logs, img_logs, weight_logs = init_logs(args)
    data_agent = CycleData(args)
    model = CycleGANModel(args)
    model.fengine.train_statef(data_agent.data1)
    # Evaluation before any training step; the result is not used.
    model.cross_policy.eval_policy(gxmodel=model.netG_B,
                                   axmodel=model.net_action_G_A,
                                   eval_episodes=10)

    # NOTE(review): with a starting value of 0, weights are never saved if
    # every evaluated reward is <= 0 -- confirm this threshold is intended.
    best_reward = 0
    end_id = 0
    for _ in range(3):
        args.lr_Gx = 1e-4
        args.lr_Ax = 0
        model.update(args)
        end_id, best_reward = _run_training_phase(
            args, model, data_agent, txt_logs, img_logs, weight_logs,
            end_id, best_reward)

        args.init_start = False
        args.lr_Gx = 0
        args.lr_Ax = 1e-4
        model.update(args)
        end_id, best_reward = _run_training_phase(
            args, model, data_agent, txt_logs, img_logs, weight_logs,
            end_id, best_reward)