Example #1
    def bp(self, fake_logit: Variable, prob):
        bs = fake_logit.size()[0]
        self.opt.zero_grad()
        # Detach so no gradient flows back through the reward source.
        reward = torch.tanh(fake_logit.detach())
        # Equivalent explicit loss: loss = -torch.mean(torch.log(prob) * reward); loss.backward()
        # Passing -reward / bs as the gradient argument backpropagates
        # sum(-reward * log(prob)) / bs without materialising a scalar loss.
        torch.log(prob).backward(-reward / bs)

        self.opt.step()
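
The gradient argument passed to backward() above makes the call equivalent to backpropagating an explicit REINFORCE-style scalar loss. A minimal sketch checking that equivalence on made-up tensors (shapes are assumptions, plain tensors replace the deprecated Variable, and prob is taken to hold one probability per batch element so that the mean equals the sum divided by bs):

import torch

bs = 4
prob = torch.rand(bs, requires_grad=True)      # stand-in for the sampled-action probabilities
reward = torch.tanh(torch.randn(bs))           # stand-in for torch.tanh(fake_logit.detach())

# (a) gradient-argument form, as in bp()
torch.log(prob).backward(-reward / bs)
grad_a = prob.grad.clone()

# (b) explicit scalar-loss form
prob.grad = None
loss = -(torch.log(prob) * reward).sum() / bs
loss.backward()
grad_b = prob.grad.clone()

print(torch.allclose(grad_a, grad_b))          # True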
Example #2
    def forward(self, x: List):
        """
        Forward pass for  all architecture
        :param x: Has different meaning with different mode of training
        :return:
        """

        if self.mode == 1:
            '''
            Variable-length training. This mode runs for one step more than
            the program length, in order to produce the stop symbol. Note
            that there is no padding of variable-length programs as is done
            in traditional RNN training. This is mainly for computational
            efficiency of the forward pass: each batch contains only programs
            of the same length, and losses from batches of different lengths
            are combined to compute the gradient and update the network. This
            ensures that every update has an equal contribution from programs
            of different lengths. Training is done using the script
            train_synthetic.py.
            '''
            data, input_op, program_len = x

            assert data.size()[0] == program_len + 1, "Incorrect stack size!!"
            batch_size = data.size()[1]
            h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda()
            x_f = self.encoder.encode(data[-1, :, 0:1, :, :])
            x_f = x_f.view(1, batch_size, self.in_sz)
            if not self.tf:
                outputs = []
                for timestep in range(0, program_len + 1):
                    # X_f is always input to the RNN at every time step
                    # along with previous predicted label
                    input_op_rnn = self.relu(
                        self.dense_input_op(input_op[:, timestep, :]))
                    input_op_rnn = input_op_rnn.view(1, batch_size,
                                                     self.input_op_sz)
                    #input_op_rnn = torch.zeros((1, batch_size, self.input_op_sz)).cuda()
                    input = torch.cat((self.drop(x_f), input_op_rnn), 2)
                    h, _ = self.rnn(input, h)
                    hd = self.relu(self.dense_fc_1(self.drop(h[0])))
                    output = self.dense_output(self.drop(hd))
                    outputs.append(output)
                return torch.stack(outputs)
            else:
                # remove stop token for input to decoder
                input_op_rnn = self.relu(
                    self.dense_input_op(input_op))[:, :-1, :].permute(1, 0, 2)
                # input_op_rnn = torch.zeros((program_len+1, batch_size, self.input_op_sz)).cuda()
                x_f = x_f.repeat(program_len + 1, 1, 1)
                input = torch.cat((self.drop(x_f), input_op_rnn), 2)
                output, h = self.rnn(input, h)
                output = self.relu(self.dense_fc_1(self.drop(output)))
                output = self.dense_output(self.drop(output))
                return output

        elif self.mode == 2:
            '''Variable-length RL training'''
            # program_len in this case is the maximum number of time steps the RNN runs for
            data, input_op, program_len = x
            batch_size = data.size()[1]
            h = Variable(torch.zeros(1, batch_size, self.hd_sz)).cuda()
            x_f = self.encoder.encode(data[-1, :, 0:1, :, :])
            x_f = x_f.view(1, batch_size, self.in_sz)
            outputs = []
            samples = []
            temp_input_op = input_op[:, 0, :]
            for timestep in range(0, program_len):
                # X_f is the input to the RNN at every time step along with previous
                # predicted label
                input_op_rnn = self.relu(self.dense_input_op(temp_input_op))
                input_op_rnn = input_op_rnn.view(1, batch_size,
                                                 self.input_op_sz)
                input = torch.cat((x_f, input_op_rnn), 2)
                h, _ = self.rnn(input, h)
                hd = self.relu(self.dense_fc_1(self.drop(h[0])))
                dense_output = self.dense_output(self.drop(hd))
                output = self.logsoftmax(dense_output)
                # outputs used for the loss; these are log-probabilities
                outputs.append(output)

                output_probs = self.softmax(dense_output)
                # Sample from the output probabilities in an epsilon-greedy way.
                # Epsilon is gradually annealed to 0 following some schedule.
                if np.random.rand() < self.epsilon:
                    # This is during training
                    sample = torch.multinomial(output_probs, 1)
                else:
                    # This is during testing
                    sample = torch.max(output_probs, 1)[1].view(batch_size, 1)

                # Stop gradients from flowing backward through the samples
                sample = sample.detach()
                samples.append(sample)

                # Create next input to the RNN from the sampled instructions
                arr = Variable(
                    torch.zeros(batch_size, self.num_draws + 1).scatter_(
                        1, sample.data.cpu(), 1.0)).cuda()
                arr = arr.detach()
                temp_input_op = arr
            return [outputs, samples]
        else:
            assert False, "Incorrect mode!!"
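
The epsilon-greedy sampling and one-hot re-encoding step of the RL mode can be exercised in isolation. A minimal sketch with dummy logits (batch size, number of draws and epsilon are made-up values, and plain tensors replace the deprecated Variable):

import numpy as np
import torch

batch_size, num_draws, epsilon = 4, 10, 0.5
dense_output = torch.randn(batch_size, num_draws + 1)   # dummy unnormalised scores
output_probs = torch.softmax(dense_output, dim=1)

if np.random.rand() < epsilon:
    sample = torch.multinomial(output_probs, 1)          # stochastic draw
else:
    sample = torch.max(output_probs, 1)[1].view(batch_size, 1)  # greedy pick

sample = sample.detach()                                 # no gradient through the samples
# One-hot encode the sampled instruction as the next input to the RNN.
next_input_op = torch.zeros(batch_size, num_draws + 1).scatter_(1, sample, 1.0)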
Example #3
def train_simple_trans():
    opt = TrainOptions().parse()
    data_root = 'data/processed'
    train_params = {'lr': 0.01, 'epoch_milestones': (100, 500)}
    # dataset = DynTexNNFTrainDataset(data_root, 'flame')
    dataset = DynTexFigureTrainDataset(data_root, 'flame')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batchsize,
                            num_workers=opt.num_workers,
                            shuffle=True)
    nnf_conf = 3
    syner = Synthesiser()
    nnfer = NNFPredictor(out_channel=nnf_conf)
    if torch.cuda.is_available():
        syner = syner.cuda()
        nnfer = nnfer.cuda()
    optimizer_nnfer = Adam(nnfer.parameters(), lr=train_params['lr'])
    table = Table()
    for epoch in range(opt.epoch):
        pbar = tqdm(total=len(dataloader), desc='epoch#{}'.format(epoch))
        pbar.set_postfix({'loss': 'N/A'})
        loss_tot = 0.0
        gamma = epoch / opt.epoch

        for i, (source_t, target_t, source_t1,
                target_t1) in enumerate(dataloader):

            if torch.cuda.is_available():
                source_t = Variable(source_t, requires_grad=True).cuda()
                target_t = Variable(target_t, requires_grad=True).cuda()
                source_t1 = Variable(source_t1, requires_grad=True).cuda()
                target_t1 = Variable(target_t1, requires_grad=True).cuda()
            nnf = nnfer(source_t, target_t)
            if nnf_conf == 3:
                nnf = nnf[:, :2, :, :] * nnf[:, 2:, :, :]  # mask via the confidence
            # --- synthesis ---
            target_predict = syner(source_t, nnf)
            target_t1_predict = syner(source_t1, nnf)
            loss_t = tnf.mse_loss(target_predict, target_t)
            loss_t1 = tnf.mse_loss(target_t1_predict, target_t1)
            loss = loss_t + loss_t1

            optimizer_nnfer.zero_grad()
            loss.backward()
            optimizer_nnfer.step()
            loss_tot += float(loss_t1)

            # ---   vis    ---
            name = os.path.join(data_root, '../result/', str(epoch),
                                '{}.png'.format(str(i)))
            index = str(epoch) + '({})'.format(i)
            if not os.path.exists('/'.join(name.split('/')[:-1])):
                os.makedirs('/'.join(name.split('/')[:-1]))

            for suffix, img in [('_s', source_t), ('_s1', source_t1),
                                ('_t', target_t), ('_p', target_predict),
                                ('_t1', target_t1), ('_p1', target_t1_predict)]:
                cv2.imwrite(
                    name.replace('.png', suffix + '.png'),
                    (img.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                     255).astype('int'))

            # vis in table
            for suffix in ('_s', '_s1', '_t', '_t1', '_p', '_p1'):
                table.add(
                    index,
                    os.path.abspath(name.replace('.png',
                                                 suffix + '.png')).replace(
                        '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            pbar.set_postfix({'loss': str(loss_tot / (i + 1))})
            pbar.update(1)
        table.build_html('data/')
        pbar.close()
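
The confidence masking applied when nnf_conf == 3 splits the 3-channel prediction into a 2-channel correspondence field and a 1-channel confidence map and multiplies them, with the confidence broadcasting over both field channels. A minimal sketch on a dummy tensor (the shapes are made up):

import torch

nnf = torch.randn(2, 3, 64, 64)      # dummy prediction: 2 field channels + 1 confidence channel
field = nnf[:, :2, :, :]             # correspondence / displacement field
conf = nnf[:, 2:, :, :]              # confidence, kept as (N, 1, H, W) so it broadcasts
masked = field * conf                # same operation as nnf[:, :2, :, :] * nnf[:, 2:, :, :]
print(masked.shape)                  # torch.Size([2, 2, 64, 64])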
Example #4
def train_complex_trans():
    opt = TrainOptions().parse()
    data_root = 'data/processed'
    train_params = {'lr': 0.001, 'epoch_milestones': (100, 500)}
    dataset = DynTexFigureTransTrainDataset(data_root, 'flame')
    dataloader = DataLoader(dataset=dataset,
                            batch_size=opt.batchsize,
                            num_workers=opt.num_workers,
                            shuffle=True)
    nnf_conf = 3
    syner = Synthesiser()
    nnfer = NNFPredictor(out_channel=nnf_conf)
    flownet = NNFPredictor(out_channel=nnf_conf)
    if torch.cuda.is_available():
        syner = syner.cuda()
        nnfer = nnfer.cuda()
        flownet = flownet.cuda()
    optimizer_nnfer = Adam(nnfer.parameters(), lr=train_params['lr'])
    optimizer_flow = Adam(flownet.parameters(), lr=train_params['lr'] * 0.1)
    scheduler_nnfer = lr_scheduler.MultiStepLR(
        optimizer_nnfer,
        gamma=0.1,
        last_epoch=-1,
        milestones=train_params['epoch_milestones'])
    scheduler_flow = lr_scheduler.MultiStepLR(
        optimizer_flow,
        gamma=0.1,
        last_epoch=-1,
        milestones=train_params['epoch_milestones'])
    table = Table()
    writer = SummaryWriter(log_dir=opt.log_dir)
    for epoch in range(opt.epoch):
        scheduler_flow.step()
        scheduler_nnfer.step()
        pbar = tqdm(total=len(dataloader), desc='epoch#{}'.format(epoch))
        pbar.set_postfix({'loss': 'N/A'})
        loss_tot = 0.0
        for i, (source_t, target_t, source_t1,
                target_t1) in enumerate(dataloader):
            if torch.cuda.is_available():
                source_t = Variable(source_t, requires_grad=True).cuda()
                target_t = Variable(target_t, requires_grad=True).cuda()
                source_t1 = Variable(source_t1, requires_grad=True).cuda()
                target_t1 = Variable(target_t1, requires_grad=True).cuda()

            nnf = nnfer(source_t, target_t)
            flow = flownet(source_t, source_t1)
            # Mask both predictions with their confidence channels.
            if nnf_conf == 3:
                nnf = nnf[:, :2, :, :] * nnf[:, 2:, :, :]  # mask via the confidence
                flow = flow[:, :2, :, :] * flow[:, 2:, :, :]
            # --- synthesis ---
            source_t1_predict = syner(source_t, flow)  # warp source_t with the flow (flow penalty)
            target_flow = syner(flow, nnf)  # carry the flow over to the target side via the NNF
            target_t1_predict = syner(target_t, target_flow)
            # target_t1_predict = syner(source_t1, nnf)

            loss_t1_f = tnf.mse_loss(source_t1,
                                     source_t1_predict)  # flow penalty
            loss_t1 = tnf.mse_loss(target_t1_predict,
                                   target_t1)  # total penalty
            loss = loss_t1_f + loss_t1 * 2

            optimizer_flow.zero_grad()
            optimizer_nnfer.zero_grad()
            loss.backward()
            optimizer_nnfer.step()
            optimizer_flow.step()
            loss_tot += float(loss)

            # ---   vis    ---
            if epoch % 10 == 0 and i % 2 == 0:
                name = os.path.join(data_root, '../result/', str(epoch),
                                    '{}.png'.format(str(i)))
                index = str(epoch) + '({})'.format(i)
                if not os.path.exists('/'.join(name.split('/')[:-1])):
                    os.makedirs('/'.join(name.split('/')[:-1]))

                for suffix, img in [('_s', source_t), ('_s1', source_t1),
                                    ('_t', target_t), ('_p', source_t1_predict),
                                    ('_t1', target_t1),
                                    ('_p1', target_t1_predict)]:
                    cv2.imwrite(
                        name.replace('.png', suffix + '.png'),
                        (img.detach().cpu().numpy()[0].transpose(1, 2, 0) *
                         255).astype('int'))

                # vis in table
                for suffix in ('_s', '_s1', '_t', '_t1', '_p', '_p1'):
                    table.add(
                        index,
                        os.path.abspath(name.replace('.png',
                                                     suffix + '.png')).replace(
                            '/mnt/cephfs_hl/lab_ad_idea/maoyiming', ''))
            pbar.set_postfix({'loss': str(loss_tot / (i + 1))})
            writer.add_scalar('scalars/{}/loss_train'.format(opt.time),
                              float(loss), i + int(epoch * len(dataloader)))
            writer.add_scalar('scalars/{}/lr'.format(opt.time),
                              float(scheduler_nnfer.get_lr()[0]),
                              i + int(epoch * len(dataloader)))
            pbar.update(1)
        table.build_html('data/')
        pbar.close()
    writer.close()
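
The two MultiStepLR schedulers simply multiply the learning rate by gamma each time a milestone epoch is passed. A minimal sketch of that behaviour on a throwaway parameter (note that recent PyTorch versions expect scheduler.step() to be called after optimizer.step(), and expose get_last_lr() rather than get_lr()):

import torch
from torch.optim import Adam, lr_scheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = Adam([param], lr=0.001)
sched = lr_scheduler.MultiStepLR(opt, milestones=(100, 500), gamma=0.1)

for epoch in range(600):
    opt.step()      # epochs 0-99 run at lr 1e-3, 100-499 at 1e-4, 500+ at 1e-5
    sched.step()

print(sched.get_last_lr())   # roughly 1e-5 after both milestones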
Example #5
for epoch in range(opt.n_epochs):
    print('Epoch {}'.format(epoch))
    for i, (batch_A, batch_B) in enumerate(dataloader_train):

        real = Variable(Tensor(batch_A.size(0), 1).fill_(1),
                        requires_grad=False)
        fake = Variable(Tensor(batch_A.size(0), 1).fill_(0),
                        requires_grad=False)

        imgs_real_A = Variable(batch_A.type(Tensor))
        imgs_real_B = Variable(batch_B.type(Tensor))

        # == Discriminator update == #
        optimizer_D.zero_grad()

        imgs_fake = Variable(generator(imgs_real_A.detach()))

        d_loss = gan_loss(discriminator(imgs_real_A, imgs_real_B),
                          real) + gan_loss(
                              discriminator(imgs_real_A, imgs_fake), fake)

        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        imgs_fake = generator(imgs_real_A)

        optimizer_G.zero_grad()

        g_loss = gan_loss(
            discriminator(imgs_real_A, imgs_fake), real)

        # (assumed continuation of the truncated listing: standard generator update)
        g_loss.backward()
        optimizer_G.step()
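
In the discriminator update above, the fake images are produced from a detached input and re-wrapped in Variable; in current PyTorch the same gradient cut is usually expressed by detaching the generator output itself. A minimal sketch with hypothetical stand-in modules (the real generator, discriminator and gan_loss are not shown in this listing):

import torch
import torch.nn as nn

# Hypothetical stand-ins for the networks used above.
generator = nn.Conv2d(3, 3, 3, padding=1)
disc_net = nn.Sequential(nn.Conv2d(6, 1, 3, padding=1),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten())

def discriminator(condition, image):          # conditional discriminator D(A, B)
    return disc_net(torch.cat([condition, image], 1))

gan_loss = nn.BCEWithLogitsLoss()             # assumption: the discriminator returns raw logits

imgs_real_A = torch.randn(2, 3, 32, 32)
imgs_real_B = torch.randn(2, 3, 32, 32)
real, fake = torch.ones(2, 1), torch.zeros(2, 1)

# Discriminator update: detach the fake images so no gradient reaches the generator.
imgs_fake = generator(imgs_real_A).detach()
d_loss = gan_loss(discriminator(imgs_real_A, imgs_real_B), real) + \
         gan_loss(discriminator(imgs_real_A, imgs_fake), fake)

# Generator update: regenerate without detaching so gradients flow into the generator.
imgs_fake = generator(imgs_real_A)
g_loss = gan_loss(discriminator(imgs_real_A, imgs_fake), real)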
Example #6
for epoch in range(opt.n_epochs):
    print('Epoch {}'.format(epoch))
    for i, (batch_lr, batch_hr) in enumerate(dataloader_train):

        real = Variable(Tensor(batch_lr.size(0), 1).fill_(1),
                        requires_grad=False)
        fake = Variable(Tensor(batch_lr.size(0), 1).fill_(0),
                        requires_grad=False)

        imgs_real_lr = Variable(batch_lr.type(Tensor))
        imgs_real_hr = Variable(batch_hr.type(Tensor))

        # == Discriminator update == #
        optimizer_D.zero_grad()

        imgs_fake_hr = Variable(generator(imgs_real_lr.detach()))

        d_loss = gan_loss(discriminator(imgs_real_hr), real) + gan_loss(
            discriminator(imgs_fake_hr), fake)

        d_loss.backward()
        optimizer_D.step()

        # == Generator update == #
        imgs_fake_hr = generator(imgs_real_lr)

        optimizer_G.zero_grad()

        # SRGAN-style objective: VGG-feature content loss rescaled by 1 / 12.75,
        # plus the adversarial loss down-weighted by 1e-3.
        g_loss = (1 / 12.75) * content_loss(
            vgg(imgs_fake_hr), vgg(imgs_real_hr.detach())) + 1e-3 * gan_loss(
                discriminator(imgs_fake_hr), real)