Example #1
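    # Write all buffered examples to the TFRecord file and save a random
    # sample of the decoded images to disk for visual inspection.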
    def flush(sess):
        tf.logging.info('saving images')
        np.random.shuffle(image_list)
        for image_info in image_list:
            image_bytes = image_info['image_bytes']
            prob = image_info['prob']
            label = image_info['label']
            example = image_info['example']
            cnt = image_info['cnt']
            record_writer.write(example.SerializeToString())
            if np.random.random() < sample_prob:
                uid = uid_list[label]
                filename = os.path.join(
                    sample_dir, uid, 'image_{:d}_{:d}_{:d}_{:.2f}.jpeg'.format(
                        FLAGS.shard_id, FLAGS.task, cnt, prob))
                tf.logging.info('saving {:s}'.format(filename))
                image = sess.run(
                    decoded_image,
                    feed_dict={image_bytes_placeholder: image_bytes})
                utils.save_pic(image, filename)

                tf.logging.info(
                    '{:d}/{:d} images saved, elapsed time: {:.2f} h'.format(
                        num_picked_images, total_keep_example,
                        (time.time() - start_time) / 3600))
Example #2
def train(data_loader, model_index, x_eval_train, loaded_model):
    ### Model Initialization
    ave = VAE()
    if loaded_model:
        ave.load_state_dict(tor.load(loaded_model))
    ave.cuda()

    loss_func = tor.nn.MSELoss().cuda()

    #optim = tor.optim.SGD(fcn.parameters(), lr=LR, momentum=MOMENTUM)
    optim = tor.optim.Adam(ave.parameters(), lr=LR)

    lr_step = StepLR(optim, step_size=LR_STEPSIZE, gamma=LR_GAMMA)

    x = Variable(tor.FloatTensor(BATCHSIZE, 3, 64, 64)).cuda()
    y = Variable(tor.FloatTensor(BATCHSIZE, 3, 64, 64)).cuda()
    ### Training
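    # Each step copies a batch into the pre-allocated CUDA tensors, runs the VAE,
    # and minimizes the MSE reconstruction loss plus the weighted KL-divergence term.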
    for epoch in range(EPOCH):
        print("|Epoch: {:>4} |".format(epoch + 1))

        ### Training
        for step, (x_batch, y_batch) in enumerate(data_loader):
            print("Process: {}/{}".format(step,
                                          int(AVAILABLE_SIZE[0] / BATCHSIZE)),
                  end="\r")
            x.data.copy_(x_batch)
            y.data.copy_(y_batch)
            out, KLD = ave(x)
            recon_loss = loss_func(out, y)
            loss = (recon_loss + KLD_LAMBDA * KLD)

            loss.backward()
            optim.step()
            lr_step.step()
            optim.zero_grad()

            if step % RECORD_JSON_PERIOD == 0:
                save_record(model_index, epoch, optim, recon_loss, KLD)
            if step % RECORD_PIC_PERIOD == 0:
                save_pic("output_{}".format(model_index), ave, 3)
            if step % RECORD_MODEL_PERIOD == 0:
                tor.save(
                    ave.state_dict(),
                    os.path.join(MODEL_ROOT,
                                 "ave_model_{}.pkl".format(model_index)))
Example #3
def train(data_loader, model_index, x_eval_train, gn_fp, dn_fp, gan_gn_fp,
          gan_dn_fp):
    ### Model Initialization
    gn = GN().cuda()
    dn = DN().cuda()

    if gn_fp:
        gn_state_dict = tor.load(gn_fp)
        gn.load_state_dict(gn_state_dict)
    if dn_fp:
        dn_state_dict = tor.load(dn_fp)
        dn.load_state_dict(dn_state_dict)
    if gan_dn_fp:
        dn.load_dn_state(tor.load(gan_dn_fp))
    if gan_gn_fp:
        gn.load_gn_state(tor.load(gan_gn_fp))

    loss_func = tor.nn.BCELoss().cuda()

    optim_gn = tor.optim.Adam(gn.parameters(), lr=LR)
    optim_dn = tor.optim.Adam(dn.parameters(), lr=LR)

    lr_step_gn = StepLR(optim_gn, step_size=LR_STEPSIZE, gamma=LR_GAMMA)
    lr_step_dn = StepLR(optim_dn, step_size=LR_STEPSIZE, gamma=LR_GAMMA)

    x = Variable(tor.FloatTensor(BATCHSIZE, LATENT_SPACE)).cuda()
    img = Variable(tor.FloatTensor(BATCHSIZE, 3, 64, 64)).cuda()

    dis_true = Variable(tor.ones(BATCHSIZE, 1)).cuda()
    dis_false = Variable(tor.zeros(BATCHSIZE, 1)).cuda()
    x_eval_train = Variable(x_eval_train).cuda()

    loss_real, loss_fake = None, None

    ### Training
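    # For 4 out of every 5 PIVOT_STEPS blocks the discriminator is trained on
    # alternating real/generated batches; on the remaining block the generator
    # is trained to fool it.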
    for epoch in range(EPOCH):
        print("|Epoch: {:>4} |".format(epoch + 1))

        for step, (x_batch, cls_batch) in enumerate(data_loader):
            print("Process: {}/{}".format(step,
                                          int(AVAILABLE_SIZE[0] / BATCHSIZE)),
                  end="\r")

            ### Discriminator phase: train on real / generated images
            if (step // PIVOT_STEPS) % 5 != 0:
                dn.training = True
                if step % 2 == 0:
                    img.data.copy_(x_batch)
                else:
                    rand_v = tor.randn(BATCHSIZE, LATENT_SPACE)
                    rand_v[:, 0] = tor.FloatTensor(BATCHSIZE).random_(
                        0, 2)  # set attribute dim
                    x.data.copy_(rand_v)
                    out = gn(x)
                    img.data.copy_(out.data)

                dis = dis_true if step % 2 == 0 else dis_false
                cls = Variable(cls_batch).cuda()
                dis_pred, cls_pred = dn(img)
                optim = optim_dn

                loss_dis = loss_func(dis_pred, dis)
                loss_cls = loss_func(cls_pred, cls)
                loss = loss_dis + loss_cls if step % 2 == 0 else loss_dis

                if step % 2 == 0:
                    loss_real = loss_cls
                else:
                    loss_fake = loss_cls

            else:
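                ### Generator phase: generate images and try to fool the discriminator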
                dn.training = False
                rand_v = tor.randn(BATCHSIZE, LATENT_SPACE)
                cls = tor.FloatTensor(BATCHSIZE, 1).random_(0, 2)
                rand_v[:, 0] = cls.view(-1)  # set the attribute dim
                x.data.copy_(rand_v)
                out = gn(x)
                dis = dis_true
                cls = Variable(cls).cuda()
                dis_pred, cls_pred = dn(out)

                optim = optim_gn

                loss_dis = loss_func(dis_pred, dis)
                loss_cls = loss_func(cls_pred, cls)
                loss = (loss_dis + loss_cls)
                loss_fake = loss_cls
            loss.backward()

            optim.step()

            optim_dn.zero_grad()
            optim_gn.zero_grad()
            lr_step_dn.step()
            lr_step_gn.step()

            if step % RECORD_JSON_PERIOD == 0 and step != 0:
                x_true = x_eval_train
                dis, cls = dn(x_true)
                acc_true = round(int((dis > 0.5).sum().data) / EVAL_SIZE, 5)
                x_noise = tor.randn((EVAL_SIZE, 512))
                x_noise[:, 0] = tor.FloatTensor(EVAL_SIZE).random_(0, 2)
                x_noise = Variable(x_noise).cuda()
                x_false = gn(x_noise)
                dis, cls = dn(x_false)
                acc_false = round(int((dis <= 0.5).sum().data) / EVAL_SIZE, 5)

                print("|Acc True: {}   |Acc False: {}".format(
                    acc_true, acc_false))

                save_record(model_index, epoch, optim, loss_real, loss_fake,
                            acc_true, acc_false)

            if step % RECORD_PIC_PERIOD == 0:
                loss = float(loss.data)
                print("|Loss: {:<8}".format(loss))
                save_pic("output_{}".format(model_index), gn, 4, epoch, step)

        ### Save model
        if epoch != 0:
            tor.save(
                gn.state_dict(),
                os.path.join(
                    MODEL_ROOT,
                    "gan_gn_{}_{}.pkl".format(model_index, epoch // 5)))
            tor.save(
                dn.state_dict(),
                os.path.join(MODEL_ROOT,
                             "gan_dn_{}_{}.pkl".format(model_index, epoch)))
Example #4
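# Run the estimator over every requested shard, write the predictions as
# TFRecords and, when reassigning labels, save a random sample of images.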
def run_prediction(estimator):
    global shard_id
    shard_id_list = FLAGS.shard_id.split(',')
    for cur_shard_id in shard_id_list:
        shard_id = int(cur_shard_id)

        worker_image_num = get_num_image()
        cnt, predict_result_list = predict_on_dataset(estimator,
                                                      worker_image_num)
        tf.logging.info('predicted on %d images', cnt)
        assert cnt == worker_image_num, (cnt, worker_image_num)

        tf.gfile.MakeDirs(FLAGS.output_dir)
        if FLAGS.reassign_label:
            sample_dir = os.path.join(FLAGS.output_dir, 'samples')
            uid_list = utils.get_uid_list()
            for uid in uid_list:
                tf.gfile.MakeDirs(os.path.join(sample_dir, uid))

            image_bytes_placeholder = tf.placeholder(dtype=tf.string)
            decoded_image = utils.decode_raw_image(image_bytes_placeholder)

            raw_dst = get_input_fn({'batch_size': 1}, raw_data=True)
            raw_iter = raw_dst.make_initializable_iterator()
            raw_elem = raw_iter.get_next()

            filename = utils.get_reassign_filename(FLAGS.label_data_dir,
                                                   FLAGS.file_prefix, shard_id,
                                                   FLAGS.num_shards,
                                                   FLAGS.worker_id)
            record_writer = tf.python_io.TFRecordWriter(
                os.path.join(FLAGS.output_dir, os.path.basename(filename)))
            sample_prob = 30000. / (worker_image_num * FLAGS.num_shards)
            with tf.Session() as sess:
                sess.run(raw_iter.initializer)
                for i in range(worker_image_num):
                    features = sess.run(raw_elem)
                    encoded_image = features['image/encoded']
                    features = {}
                    label = predict_result_list[i]['label']
                    prob = predict_result_list[i]['prob']
                    features['image/encoded'] = utils.bytes_feature(
                        encoded_image)
                    features['prob'] = utils.float_feature(prob)
                    features['label'] = utils.int64_feature(label)
                    features['probabilities'] = utils.float_feature(
                        predict_result_list[i]['probabilities'])
                    example = tf.train.Example(features=tf.train.Features(
                        feature=features))
                    record_writer.write(example.SerializeToString())
                    if np.random.random() < sample_prob:
                        uid = uid_list[label]
                        filename = os.path.join(
                            sample_dir, uid,
                            'image_{:d}_{:d}_{:.2f}.jpeg'.format(
                                shard_id, i, prob))
                        tf.logging.info('saving {:s}'.format(filename))
                        image = sess.run(
                            decoded_image,
                            feed_dict={image_bytes_placeholder: encoded_image})
                        utils.save_pic(image, filename)

            record_writer.close()
        else:
            filename = 'train-info-%.5d-of-%.5d-%.5d' % (
                shard_id, FLAGS.num_shards, FLAGS.worker_id)
            writer = tf.python_io.TFRecordWriter(
                os.path.join(FLAGS.output_dir, filename))
            for result in predict_result_list:
                features = {}
                features['probabilities'] = utils.float_feature(
                    result['probabilities'])
                features['classes'] = utils.int64_feature(result['label'])

                example = tf.train.Example(features=tf.train.Features(
                    feature=features))
                writer.write(example.SerializeToString())
            writer.close()
Example #5
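        # Optimizer closure (the pattern used by closure-based optimizers such as
        # torch.optim.LBFGS): recompute the losses, backpropagate, return the total.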
        def closure():
            nonlocal best_loss
            nonlocal input_img
            nonlocal best_input

            input_img.data.clamp_(0, 1)
            optimizer.zero_grad()

            model(input_img)

            style_score = 0
            content_score = 0
            tv_score = 0

            for sl in style_losses:
                style_score += sl.loss

            for cl in content_losses:
                content_score += cl.loss

            for tl in tv_losses:
                tv_score += tl.loss

            style_score *= style_weight
            content_score *= content_weight
            tv_score *= tv_weight

            # Two-stage optimization pipeline
            if run[0] > num_steps // 2:
                # The realistic loss involves sparse-matrix computation, which
                # autograd in PyTorch does not support, so it is computed separately.
                rl_score, part_grid = realistic_loss_grad(
                    input_img, laplacian_m)
                rl_score *= rl_weight

                loss = style_score + content_score + tv_score + rl_score

                # Store the best result for output
                if loss < best_loss:
                    # print(best_loss)
                    best_loss = loss
                    best_input = input_img.data.clone()
            else:
                loss = style_score + content_score + tv_score

                rl_score = torch.zeros(1)  # Just to print

                if loss < best_loss and run[0] > 0:
                    # print(best_loss)
                    best_loss = loss
                    best_input = input_img.data.clone()  # clone so later updates don't overwrite the stored best

                if run[0] == num_steps // 2:
                    # Store the best temp result to initialize second stage input
                    input_img.data = best_input
                    best_loss = 1e10

            loss.backward()

            # Gradient clipping to deal with exploding gradients
            clip_grad_norm(model.parameters(), 15.0)

            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}/{}:".format(run, num_steps))

                print(
                    'Style Loss: {:4f} Content Loss: {:4f} TV Loss: {:4f} real loss: {:4f}'
                    .format(style_score.item(), content_score.item(),
                            tv_score.item(), rl_score.item()))

                print('Total Loss: ', loss.item())

                saved_img = input_img.clone()
                saved_img.data.clamp_(0, 1)
                utils.save_pic(saved_img, run[0])
            return loss
Example #6
            merged_content_mask[count, :, :] = content_mask_origin[
                i, :, :].numpy()
            count += 1
        else:
            pass
    print('Total semantic classes in style image: {}'.format(count))
    style_mask_tensor = torch.from_numpy(
        merged_style_mask[:count, :, :]).float().to(config.device0)
    content_mask_tensor = torch.from_numpy(
        merged_content_mask[:count, :, :]).float().to(config.device0)
    #--------------------------
    print('Save each mask as an image for debugging')
    for i in range(count):
        utils.save_pic(
            torch.stack([style_mask_tensor[i, :, :]] * 3, dim=0),
            'style_mask_' + str(i))
        utils.save_pic(
            torch.stack([content_mask_tensor[i, :, :]] * 3, dim=0),
            'content_mask_' + str(i))

    # Using GPU or CPU
    device = torch.device(config.device0)

    style_img = utils.load_image(style_image_path, None)
    content_img = utils.load_image(content_image_path, None)
    width_s, height_s = style_img.size
Example #7
def train(data_loader, model_index, x_eval_train, gn_fp, dn_fp, ave_fp):
    ### Model Initialization
    gn = GN().cuda()
    dn = DN().cuda()

    ave_state_dict = tor.load(ave_fp)
    gn.load_ave_state(ave_state_dict)
    dn.load_ave_state(ave_state_dict)

    if gn_fp:
        gn.load_state_dict(tor.load(gn_fp))
    if dn_fp:
        dn.load_state_dict(tor.load(dn_fp))
    gn.cuda()
    dn.cuda()


    loss_func = tor.nn.BCELoss().cuda()

    #optim = tor.optim.SGD(fcn.parameters(), lr=LR, momentum=MOMENTUM)
    optim_gn = tor.optim.Adam(gn.parameters(), lr=LR)
    optim_dn = tor.optim.Adam(dn.parameters(), lr=LR)

    lr_step_gn = StepLR(optim_gn, step_size=LR_STEPSIZE, gamma=LR_GAMMA)
    lr_step_dn = StepLR(optim_dn, step_size=LR_STEPSIZE, gamma=LR_GAMMA)


    ### Training
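    # The discriminator is updated for 2 of every 3 PIVOT_STEPS blocks (alternating
    # real and generated batches); the generator is updated on the remaining block.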
    for epoch in range(EPOCH):
        print("|Epoch: {:>4} |".format(epoch + 1))

        for step, (x_batch, y_batch) in enumerate(data_loader):
            print("Process: {}/{}".format(step, int(AVAILABLE_SIZE[0] / BATCHSIZE)), end="\r")

            ### Discriminator phase: train on alternating real / generated batches
            if (step // PIVOT_STEPS) % 3 != 2:
                if step % 2 == 0:
                    out = Variable(x_batch).cuda()
                    ans = Variable(tor.ones(BATCHSIZE, 1)).cuda()
                else:
                    out = gn(Variable(tor.randn(BATCHSIZE, 512)).cuda())
                    ans = Variable(tor.zeros(BATCHSIZE, 1)).cuda()
                dis = dn(out)
                optim = optim_dn

            else:
                ### Generator phase: try to fool the discriminator
                out = gn(Variable(tor.randn(BATCHSIZE, 512)).cuda())
                ans = Variable(tor.ones(BATCHSIZE, 1)).cuda()
                dis = dn(out)
                optim = optim_gn

            loss = loss_func(dis, ans)
            print(loss.data)
            loss.backward()
            if (step // PIVOT_STEPS) % 3 != 2:
                optim_dn.step()
            else:
                optim_gn.step()

            optim_dn.zero_grad()
            optim_gn.zero_grad()
            lr_step_dn.step()
            lr_step_gn.step()


            if step % RECORD_JSON_PERIOD == 0:
                x_true = Variable(x_eval_train).cuda()
                out = dn(x_true)
                acc_true = round(int((out > 0.5).sum().data) / EVAL_SIZE, 5)
                x_false = gn(Variable(tor.randn((EVAL_SIZE, 512))).cuda())
                out = dn(x_false)
                acc_false = round(int((out <= 0.5).sum().data) / EVAL_SIZE, 5)

                print ("|Acc True: {}   |Acc False: {}".format(acc_true, acc_false))

                save_record(model_index, epoch, optim, loss, acc_true, acc_false)

            if step % RECORD_PIC_PERIOD == 0:
                loss = float(loss.data)
                print("|Loss: {:<8}".format(loss))
                save_pic("output_{}".format(model_index), gn, 3)

            ### Save model
            if step % RECORD_MODEL_PERIOD == 0:
                tor.save(
                    gn.state_dict(),
                    os.path.join(MODEL_ROOT,
                                 "gan_gn_{}_{}.pkl".format(model_index, epoch)))