Ejemplo n.º 1
0
def eval(opt):
    """Evaluate a CycleGANModel on the loader built from ``opt`` and print losses.

    Weights are loaded from the weight-log directory returned by
    ``init_logs(opt)``. Each batch is run through ``model.set_input`` and
    ``model.test()``, and the errors reported by
    ``model.get_current_errors()`` are printed. After the loop, each error is
    printed again averaged over ``len(dataset)`` batches.

    Every ``opt.display_gap`` batches, a visualisation is written to the
    image-log directory and ``model.save(weight_logs)`` is called.

    The function name shadows the builtin ``eval`` in this module. It is kept
    so that existing callers keep working.

    Args:
        opt: parsed options. ``opt.display_gap`` is read here. The whole object
            is also passed to ``CycleGANModel``, ``CDFdata.get_loader`` and
            ``init_logs``.
    """
    model = CycleGANModel(opt)
    dataset = CDFdata.get_loader(opt)
    img_logs, weight_logs = init_logs(opt)
    model.load(weight_logs)

    num_batches = len(dataset)
    ave_loss = {}  # error name -> running sum over all batches seen so far
    for batch_id, data in enumerate(dataset):
        model.set_input(data)
        model.test()

        errors = model.get_current_errors()
        display = '===> Batch({}/{})'.format(batch_id, num_batches)
        for key, value in errors.items():
            display += '{}:{:.4f}  '.format(key, value)
            # The first occurrence of a key seeds its sum. An explicit
            # membership test is used (not a catch-all try/except) so that
            # genuine arithmetic errors are not silently turned into a reset.
            if key in ave_loss:
                ave_loss[key] = ave_loss[key] + value
            else:
                ave_loss[key] = value
        print(display)

        if (batch_id + 1) % opt.display_gap == 0:
            path = os.path.join(img_logs, 'imgA_{}.jpg'.format(batch_id + 1))
            model.visual(path)
            # NOTE(review): saving weights during evaluation mirrors train();
            # confirm this is intended rather than a copy-paste leftover.
            model.save(weight_logs)

    # ave_loss is empty when the loader yields no batches, so this loop never
    # divides by a zero num_batches.
    display = 'average loss: '
    for key, value in ave_loss.items():
        display += '{}:{:.4f}  '.format(key, value / num_batches)
    print(display)
Ejemplo n.º 2
0
    def __init__(self, opt):
        """Build networks, replay pools, loss functions and optimizers.

        Construction order is:
          1. networks
          2. a call to ``self.train_forward(pretrained=True)``
          3. buffers, discriminator, image pools and losses
          4. optimizers

        ``train_forward`` is defined elsewhere. It runs before the attributes
        created after it exist.

        Args:
            opt: parsed options. Reads ``dynamic``, ``istrain``, ``loss``
                (``'l1'`` or ``'l2'``), ``F_lr`` and ``G_lr``. The whole
                object is also passed to ``CDFdata.get_loader``.

        Raises:
            ValueError: if ``opt.loss`` is neither ``'l1'`` nor ``'l2'``.
        """
        # Validate before allocating any CUDA networks. An unknown value would
        # otherwise leave criterionCycle/criterionIdt undefined, and the
        # failure would only show up later as a distant AttributeError.
        if opt.loss not in ('l1', 'l2'):
            raise ValueError(
                "opt.loss must be 'l1' or 'l2', got {!r}".format(opt.loss))

        self.opt = opt
        self.dynamic = opt.dynamic
        self.isTrain = opt.istrain
        self.Tensor = torch.cuda.FloatTensor

        # load/define networks
        # The naming convention is different from the one used in the paper.
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_B = img2state().cuda()
        self.netF_A = Fmodel().cuda()
        self.netAction = Amodel().cuda()
        self.dataF = CDFdata.get_loader(opt)
        self.train_forward(pretrained=True)

        self.gt_buffer = []
        self.pred_buffer = []

        # The discriminator and image pools are created regardless of
        # self.isTrain.
        self.netD_B = stateDmodel().cuda()
        self.fake_A_pool = ImagePool(pool_size=128)
        self.fake_B_pool = ImagePool(pool_size=128)

        # define loss functions
        self.criterionGAN = GANLoss(tensor=self.Tensor).cuda()
        loss_cls = {'l1': torch.nn.L1Loss, 'l2': torch.nn.MSELoss}[opt.loss]
        self.criterionCycle = loss_cls()
        self.criterionIdt = loss_cls()

        # Initialize optimizers. netF_A uses its own learning rate (F_lr);
        # netG_B and netAction share G_lr.
        self.optimizer_G = torch.optim.Adam([
            {'params': self.netF_A.parameters(), 'lr': self.opt.F_lr},
            {'params': self.netG_B.parameters(), 'lr': self.opt.G_lr},
            {'params': self.netAction.parameters(), 'lr': self.opt.G_lr},
        ])
        # The discriminator optimizer uses Adam's default hyper-parameters.
        self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters())

        print('---------- Networks initialized ---------------')
        print('-----------------------------------------------')
Ejemplo n.º 3
0
def train(opt):
    """Train a CycleGANModel for ``opt.epoch_size`` epochs, logging as it goes.

    Each batch goes through ``model.set_input`` and
    ``model.optimize_parameters``, and the model's current errors are printed.
    The batch counter restarts every epoch. Whenever the 1-based batch count
    within an epoch is a multiple of ``opt.display_gap``, two things happen:
      - a visualisation is written to the image-log directory;
      - the weights are saved to the weight-log directory.

    Training starts from whatever weights ``CycleGANModel(opt)`` builds; no
    checkpoint is loaded here.

    Args:
        opt: parsed options. ``epoch_size`` and ``display_gap`` are read here.
            The whole object is forwarded to the model, the loader and
            ``init_logs``.
    """
    model = CycleGANModel(opt)
    loader = CDFdata.get_loader(opt)
    img_dir, weight_dir = init_logs(opt)

    for epoch in range(opt.epoch_size):
        for step, batch in enumerate(loader):
            model.set_input(batch)
            model.optimize_parameters()

            current = model.get_current_errors()
            header = '===> Epoch[{}]({}/{})'.format(epoch, step, len(loader))
            metrics = ''.join(
                '{}:{:.4f}  '.format(name, val) for name, val in current.items()
            )
            print(header + metrics)

            completed = step + 1
            if completed % opt.display_gap:
                continue
            snapshot = 'imgA_{}_{}.jpg'.format(epoch, completed)
            model.visual(os.path.join(img_dir, snapshot))
            model.save(weight_dir)
Ejemplo n.º 4
0
                        type=int,
                        default=1,
                        help='datasetA from view1')
    parser.add_argument('--imgB_id',
                        type=int,
                        default=2,
                        help='datasetB from view2')
    parser.add_argument('--data_root',
                        type=str,
                        default='./tmp/data',
                        help='logs root path')

    opt = parser.parse_args()

    model = Fmodel().cuda()
    dataset = CDFdata.get_loader(opt)
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = nn.L1Loss()

    for epoch in range(100):
        for i, item in enumerate(dataset):
            state, action, result = item[1]
            state = state.float().cuda()
            action = action.float().cuda()
            result = result.float().cuda()

            out = model(state, action)
            loss = loss_fn(out, result)
            # loss = loss_fn(result*0.5,(state*0.5+action*0.05))
            print(epoch, i, loss.item())