Пример #1
0
    def train(self, data_loader):
        """Run one adversarial training epoch over ``data_loader``.

        For every ``(trace, image)`` batch, the discriminator ``self.D`` is
        first updated on real images and on refined fakes. Then the generator
        ``self.G`` is updated so that ``self.D`` labels its refined fakes as
        real. A fake is the first output of ``self.trace2image(trace)``,
        refined by ``self.G``.

        Increments ``self.epoch`` and prints the mean G/D losses and the last
        batch's D(x) / D(G(z)). It then saves the last batch's target, coarse
        (trace-to-image) and refined (image-to-image) images under
        ``<gan_dir>/image/train/``.

        Args:
            data_loader: sized iterable of ``(trace, image)`` tensor pairs;
                both tensors are moved to CUDA.
        """
        print('Training...')
        # Anomaly detection makes autograd report the forward op behind a
        # failing backward; it is a debugging aid and slows training down.
        with torch.autograd.set_detect_anomaly(True):
            self.epoch += 1
            self.G.train()
            self.D.train()
            record_G = utils.Record()
            record_D = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()

                # ---- discriminator step ----
                self.D.zero_grad()
                # real images should be classified as real
                real_output = self.D(image)
                err_D_real = self.loss(real_output, self.real_label)
                D_x = real_output.data.mean()
                # refined fakes should be classified as fake; detach so this
                # loss does not backpropagate into G
                fake_input, *_ = self.trace2image(trace)
                fake_refine = self.G(fake_input)
                fake_output = self.D(fake_refine.detach())
                err_D_fake = self.loss(fake_output, self.fake_label)
                D_G_z = fake_output.data.mean()

                err_D = err_D_fake + err_D_real
                err_D.backward()
                self.optimizerD.step()

                # ---- generator step ----
                self.G.zero_grad()
                # G is rewarded when the just-updated D labels its output real.
                # NOTE(review): if self.trace2image has trainable parameters,
                # err_G also backpropagates into them and nothing here zeroes
                # those grads -- confirm it is frozen or part of self.G.
                fake_output = self.D(fake_refine)
                err_G = self.loss(fake_output, self.real_label)

                err_G.backward()
                self.optimizerG.step()

                record_D.add(err_D.item())
                record_G.add(err_G.item())
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Epoch: %d' % self.epoch)
            print('Costs time: %.2f s' % (time.time() - start_time))
            print('Loss of G: %f' % (record_G.mean()))
            print('Loss of D: %f' % (record_D.mean()))
            print('D(x): %f, D(G(z)): %f' % (D_x, D_G_z))
            print('----------------------------------------')
            # Save the last batch. Written under image/train/ so that test()
            # (which uses image/test/ with the same names) does not overwrite
            # these files.
            utils.save_image(image.data, ('%s/image/train/target_%03d.jpg' %
                                          (self.args['gan_dir'], self.epoch)))
            # coarse output of the trace-to-image model
            utils.save_image(fake_input.data,
                             ('%s/image/train/tr2im_%03d.jpg' %
                              (self.args['gan_dir'], self.epoch)))
            # G's image-to-image refinement of that output
            utils.save_image(fake_refine.data,
                             ('%s/image/train/im2im_%03d.jpg' %
                              (self.args['gan_dir'], self.epoch)))
Пример #2
0
    def train(self, data_loader):
        """Run one training epoch of the trace/image VAE over ``data_loader``.

        Each batch is processed as follows:

        - ``trace`` is embedded by ``self.TraceEncoder`` and ``image`` by
          ``self.ImageEncoder``.
        - A latent is drawn from each embedding with ``utils.reparameterize``.
        - Both latents are decoded by the shared ``self.Decoder``.

        The optimised loss is the sum of:

        - the L1 reconstruction error of ``image`` from the trace latent;
        - the L1 reconstruction error of ``image`` from the image latent;
        - ``self.args['beta']`` times ``self.kld`` between the image and
          trace ``(mu, logvar)`` pairs.

        Increments ``self.epoch`` and prints the mean losses. It then saves
        the last batch's target and both reconstructions under
        ``<vae_dir>/image/train/``.

        Args:
            data_loader: sized iterable of ``(trace, image)`` tensor pairs;
                both tensors are moved to CUDA.
        """
        print('Training...')
        # Debugging aid: autograd reports the forward op behind a failing
        # backward, at a noticeable speed cost.
        with torch.autograd.set_detect_anomaly(True):
            self.epoch += 1
            self.set_train()
            record_trace = utils.Record()
            record_image = utils.Record()
            record_kld = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()
                self.zero_grad()
                trace_embed = self.TraceEncoder(trace)
                image_embed = self.ImageEncoder(image)
                # NOTE(review): mu and logvar alias the same tensor, so each
                # embedding is used as both mean and log-variance. If the
                # encoders return (mu, logvar) or a double-width embedding,
                # this should unpack/split instead -- confirm intent.
                trace_mu, trace_logvar = trace_embed, trace_embed
                image_mu, image_logvar = image_embed, image_embed
                trace_z = utils.reparameterize(trace_mu, trace_logvar)
                image_z = utils.reparameterize(image_mu, image_logvar)
                # The decoder's second output (intermediate features) would
                # feed the currently disabled self.l2(trace_inter, image_inter)
                # term; it is unused for now.
                trace2image, trace_inter = self.Decoder(trace_z)
                image2image, image_inter = self.Decoder(image_z)

                err_trace = self.l1(trace2image, image)
                err_image = self.l1(image2image, image)
                err_kld = self.kld(image_mu, image_logvar, trace_mu,
                                   trace_logvar)

                (err_trace + err_image +
                 self.args['beta'] * err_kld).backward()

                self.optimizer.step()

                # Record plain Python numbers (as the GAN trainer does with
                # .item()), not the per-batch loss tensors.
                record_trace.add(float(err_trace))
                record_image.add(float(err_image))
                record_kld.add(float(err_kld))
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Epoch: %d' % self.epoch)
            print('Costs time: %.2fs' % (time.time() - start_time))
            print('Loss of Trace to Image: %f' % (record_trace.mean()))
            print('Loss of Image to Image: %f' % (record_image.mean()))
            print('Loss of KL-Divergence: %f' % (record_kld.mean()))
            print('----------------------------------------')
            # Images from the last batch of the epoch.
            utils.save_image(image.data, ('%s/image/train/target_%03d.jpg' %
                                          (self.args['vae_dir'], self.epoch)))
            utils.save_image(trace2image.data,
                             ('%s/image/train/tr2im_%03d.jpg' %
                              (self.args['vae_dir'], self.epoch)))
            utils.save_image(image2image.data,
                             ('%s/image/train/im2im_%03d.jpg' %
                              (self.args['vae_dir'], self.epoch)))
Пример #3
0
    def test(self, data_loader):
        """Evaluate the GAN on ``data_loader`` without updating any weights.

        Runs under ``torch.no_grad()`` with ``self.G`` and ``self.D`` in eval
        mode. For each batch it computes the discriminator loss (real vs.
        refined fake) and the generator loss (refined fake labelled real).

        Prints the mean losses and the last batch's D(x) / D(G(z)). It then
        saves the last batch's target, coarse (trace-to-image) and refined
        (image-to-image) images under ``<gan_dir>/image/test/``.

        Args:
            data_loader: sized iterable of ``(trace, image)`` tensor pairs;
                both tensors are moved to CUDA.
        """
        print('Testing...')
        with torch.no_grad():
            self.G.eval()
            self.D.eval()
            record_G = utils.Record()
            record_D = utils.Record()
            start_time = time.time()
            progress = progressbar.ProgressBar(maxval=len(data_loader)).start()
            for i, (trace, image) in enumerate(data_loader):
                progress.update(i + 1)
                trace = trace.cuda()
                image = image.cuda()

                real_output = self.D(image)
                err_D_real = self.loss(real_output, self.real_label)
                D_x = real_output.data.mean()

                fake_input, *_ = self.trace2image(trace)
                fake_refine = self.G(fake_input)
                # No weights change between the D and G losses here (no
                # optimizer step, eval mode), so a single D forward on the
                # fakes serves both; .detach() is unnecessary under no_grad.
                fake_output = self.D(fake_refine)
                err_D_fake = self.loss(fake_output, self.fake_label)
                D_G_z = fake_output.data.mean()

                err_D = err_D_fake + err_D_real
                err_G = self.loss(fake_output, self.real_label)

                record_D.add(err_D.item())
                record_G.add(err_G.item())
            progress.finish()
            utils.clear_progressbar()
            print('----------------------------------------')
            print('Test at Epoch %d' % self.epoch)
            print('Costs time: %.2f s' % (time.time() - start_time))
            print('Loss of G: %f' % (record_G.mean()))
            print('Loss of D: %f' % (record_D.mean()))
            print('D(x): %f, D(G(z)): %f' % (D_x, D_G_z))
            print('----------------------------------------')
            utils.save_image(image.data, ('%s/image/test/target_%03d.jpg' %
                                          (self.args['gan_dir'], self.epoch)))
            # coarse output of the trace-to-image model
            utils.save_image(fake_input.data,
                             ('%s/image/test/tr2im_%03d.jpg' %
                              (self.args['gan_dir'], self.epoch)))
            # G's image-to-image refinement of that output
            utils.save_image(fake_refine.data,
                             ('%s/image/test/im2im_%03d.jpg' %
                              (self.args['gan_dir'], self.epoch)))
Пример #4
0
    def init(self):
        """Initialise server state, logging and plugins, then start the server.

        Steps, in order:

        - Store the absolute directory of this module in ``self.root_dir``.
        - Reset the server bookkeeping attributes.
        - Build one stdout log profile at ``self.config.log_level``, shared
          by ``self.logger`` and every plugin logger.
        - For each module name in ``self.config.plugin_names``, import it and
          store the return value of its ``load(logger, self)`` in
          ``self._plugins``.
        - Call ``self.build_builtin_callback()`` and then
          ``self.start_server()``.
        """
        module_dir = os.path.dirname(__file__)
        self.root_dir = os.path.abspath(module_dir)
        self.server_running = False
        self.server = None
        self.server_logs = []

        # A single stdout sink shared by the core logger and all plugins.
        stdout_profile = utils.Record()
        stdout_profile.level = self.config.log_level
        stdout_profile.output = sys.stdout
        log_profiles = [stdout_profile]
        self.logger = log.Logger('mana9er', log_profiles)

        # Import each configured plugin and keep whatever its load() hook
        # returns, keyed by module name.
        self._plugins = {}
        for name in self.config.plugin_names:
            named_logger = log.Logger(name, log_profiles)
            plugin_module = importlib.import_module(name)
            self._plugins[name] = plugin_module.load(named_logger, self)
        self.build_builtin_callback()
        self.start_server()
Пример #5
0
def main():
    """Group scored rows from stdin by query id and print ranking metrics.

    Each stdin line is split on single spaces:

    - field 0 is the score (float);
    - field 2 is the query id;
    - fields 3 onwards are passed to ``utils.to_dense``;
    - the last field has the form ``<name>=<rating>``, and its integer
      rating is kept.

    Rows are collected per query id as ``utils.Record`` lists, and each list
    is sorted by descending score. The ``click_num``, ``view_deep`` and
    ``map`` metrics of the grouped dict are then printed.
    """
    qid2list = {}
    for line in sys.stdin:
        parts = line.split(' ')
        score = float(parts[0])
        rating = int(parts[-1].split('=')[1])
        qid = parts[2]
        # NOTE(review): parts[3:] also contains the trailing "<name>=<rating>"
        # field -- confirm utils.to_dense tolerates/ignores it.
        vec = utils.to_dense(parts[3:])
        record = utils.Record(qid, rating, vec, score=score)
        qid2list.setdefault(qid, []).append(record)

    for seq in qid2list.values():
        seq.sort(key=lambda rec: rec.score, reverse=True)

    # Each line is formatted into one string, which prints identically under
    # Python 2 and 3 (the original Python-2-only print statements are a
    # SyntaxError on Python 3).
    print('click_num %s' % (click_num(qid2list),))
    print('view_deep %s' % (view_deep(qid2list),))
    # NOTE(review): `map` must be a metric function defined elsewhere in this
    # module (it shadows the builtin); the builtin map(dict) would fail.
    print('map %s' % (map(qid2list),))
Пример #6
0
# Group "user,movie,rating,timestamp" lines from stdin by user. Every user
# seen gets a list in user_2_behaviors, but a rating is appended (as a
# utils.Record scored by its timestamp) only when movie_2_vec has a vector
# for that movie.
for line in sys.stdin:
    user_id, movie_id, rating, time_stamp = line.split(',')
    user_id, movie_id = int(user_id), int(movie_id)
    rating, time_stamp = float(rating), int(time_stamp)

    behaviors = user_2_behaviors.get(user_id)
    if behaviors is None:
        user_2_behaviors[user_id] = behaviors = []

    vec = movie_2_vec.get(movie_id)
    if vec is None:
        continue
    record = utils.Record(movie_id, rating, vec, score=time_stamp)
    behaviors.append(record)


def vec_to_string(vec):
    """Return the ``str()`` of each element of ``vec.tolist()``, joined by ':'.

    ``vec`` must provide ``tolist()`` (for example a NumPy array).
    """
    return ':'.join(str(item) for item in vec.tolist())


# Only users with at least 50 recorded behaviors are processed; their
# behaviors are ordered by ascending score (the timestamp, in the reader
# above) and the first 30 are taken.
for user_id, behaviors in user_2_behaviors.items():
    if len(behaviors) >= 50:
        behaviors = sorted(behaviors, key=lambda rec: rec.score)
        last_behaviors = behaviors[:30]
Пример #7
0
import sys
import utils

# Query id -> list of utils.Record rows read from stdin for that query.
qid2list = {}

# Positional command-line arguments; the sequence length is an int.
ARG_type = sys.argv[1]
ARG_do_bounce = sys.argv[2]
ARG_seq_len = int(sys.argv[3])

# Each stdin line is space-separated: score, integer rating, query id, then
# the fields handed to utils.to_dense.
for line in sys.stdin:
    parts = line.split(' ')
    score, rating, qid = float(parts[0]), int(parts[1]), parts[2]
    vec = utils.to_dense(parts[3:])
    record = utils.Record(qid, rating, vec, score=score)
    seq = qid2list.get(qid)
    if seq is None:
        qid2list[qid] = seq = []
    seq.append(record)

# For each query with at least ARG_seq_len records, keep the ARG_seq_len
# highest-scoring records and build a scan sequence from them with
# utils.gen_scan_seq.
# NOTE(review): this loop appears truncated here -- fb_seq, ARG_type and
# ARG_do_bounce are not used in the visible code.
for qid, seq in qid2list.items():
    if len(seq) < ARG_seq_len:
        continue
    # Sorts the stored list in place, highest score first.
    seq.sort(key=lambda x: x.score, reverse=True)
    if len(seq) > ARG_seq_len:
        # Rebinds only the loop variable to a truncated copy; qid2list keeps
        # the full sorted list.
        seq = seq[0:ARG_seq_len]

    fb_seq = utils.gen_scan_seq(seq)