# Example 1
def main(argv=None):
    """Run the training loop: feed batches from the input pipeline to the
    classifier, printing metrics and periodically writing TensorBoard
    summaries and checkpoints.

    Args:
        argv: Unused; kept for compatibility with ``tf.app.run``-style
            entry points.
    """
    print("Parameters:")
    # NOTE(review): FLAGS.__flags reaches into absl/TF private state; newer
    # releases expose FLAGS.flag_values_dict() instead — confirm TF version
    # before switching.
    for k, v in FLAGS.__flags.items():
        print(k, "=", v)
    print()

    #sess = tf.InteractiveSession()
    input_generator = inputs.train_pipeline(batch_size=FLAGS.batch_size,
                                            num_epochs=FLAGS.num_epochs)

    classifier_model = ImageClassifier(NUM_CLASSES,
                                       IMAGE_SIZE,
                                       batch_size=FLAGS.batch_size)

    sess = tf.Session()

    summary_dir = FLAGS.output_dir + FLAGS.summary_train_dir
    train_writer = tf.summary.FileWriter(summary_dir, sess.graph)

    # NOTE(review): the coordinator is created but never started or joined;
    # if the input pipeline relies on queue runners they will not run —
    # verify against inputs.train_pipeline.
    coord = tf.train.Coordinator()

    with sess.as_default():
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        step = 0
        start = datetime.now()
        for batch in input_generator:

            accuracy, loss, summary, run_metadata, embeddings = classifier_model.train(
                sess, batch)

            # FIX: use `==` rather than `is` for integer comparison — `is`
            # tests object identity and only works here by the accident of
            # CPython's small-int caching.
            if step % FLAGS.report_every == 0:
                now = datetime.now()
                elapsed = now - start
                # Average time per step is undefined on the very first step.
                average = elapsed / step if step != 0 else 0
                print(
                    "Step %08d, Accuracy %.6f, Loss %.6f, Average Time %s/step, Elapsed Time %s%s"
                    % (step, accuracy * 100, loss, average, elapsed,
                       ", Created Summary"
                       if step % FLAGS.summary_every == 0 else ""))
                sys.stdout.flush()

            if step % FLAGS.summary_every == 0:
                train_writer.add_run_metadata(run_metadata, 'step%d' % step)
                train_writer.add_summary(summary, step)

            if step % FLAGS.checkpoint_every == 0:
                classifier_model.save(sess, global_step=step)

            step += 1
# Example 2
# Script-level dispatch: pick one of 'train' / 'eval' / 'classify' by editing
# MODE below. Only the selected branch runs.
# MODE = 'eval'
MODE = 'classify'

if MODE == 'train':

    # Class labels for the binary classifier; alternative label sets are kept
    # commented out for quick experiment switching.
    classes = ['diverse_normal_0.06', 'sensity']
    # classes = ['1-normal', 'sensity']
    # classes = ['2-bikini', 'all_naked']
    # classes = ['2-bikini', '4-naked', '6-sex']

    pclassifier = ImageClassifier(
        model_name='diversed_init_stage_porn_classifier',
        classes=classes,
        to_log_file=True)

    pclassifier.train(save_gen_imgs=False, use_tensorboard=True)

elif MODE == 'eval':

    # classes = ['1-normal', 'sensity']
    classes = ['2-bikini', 'all_naked']
    # classes = ['2-bikini', '4-naked', '6-sex']

    # Candidate checkpoints; the trailing "(x, y)" comments presumably record
    # (loss, accuracy) pairs for each — TODO confirm against training logs.
    # model_id = 'soft_porn_classifier.20190622035912.01-1.0523'  # (0.075, 0.9826)
    # model_id = 'soft_porn_classifier.20190622035912.02-0.5248'  # (0.1830, 0.9304)
    # model_id = 'soft_porn_classifier.20190622035912.04-0.4564'  # (0.2270, 0.9130)
    model_id = 'soft_porn_classifier.20190622035912.05-0.3722'  # (0.3144, 0.8696)
    # model_id = 'soft_porn_classifier.20190622035912.06-0.3613'  # (0.4999, 0.8348)

    # NOTE(review): this snippet is truncated — the ImageClassifier(...) call
    # below is cut off mid-argument-list (unclosed paren); the rest of the
    # 'eval' branch and the 'classify' branch are missing from this paste.
    pclassifier = ImageClassifier(model_id=model_id,
                                  classes=classes,
def train(train_source_iter: ForeverDataIterator,
          train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, esem, optimizer: SGD,
          lr_scheduler: StepwiseLR, epoch: int, source_class_weight,
          args: argparse.Namespace):
    """Run one epoch of weighted adversarial domain-adaptation training.

    For ``args.iters_per_epoch`` iterations, draws a labeled source batch and
    an unlabeled target batch, computes per-sample weights (target weights
    from an ensemble's confidence/entropy/consistency, source weights from
    ``source_class_weight``), and optimizes classification loss plus a
    weighted domain-adversarial transfer loss.

    Args:
        train_source_iter: endless iterator over labeled source batches.
        train_target_iter: endless iterator over unlabeled target batches.
        model: classifier returning (logits, features) per input batch.
        domain_adv: adversarial loss module; also exposes
            ``domain_discriminator_accuracy`` after each call.
        esem: ensemble scoring module used (in eval mode, no grad) to weight
            target samples — presumably 5 classifier heads; verify upstream.
        optimizer: SGD optimizer over the trainable parameters.
        lr_scheduler: stepped once per iteration, before the forward pass.
        epoch: current epoch index (display only).
        source_class_weight: indexable per-class weight for source samples.
        args: needs ``iters_per_epoch``, ``trade_off``, ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()
    esem.eval()  # ensemble only scores target samples; kept frozen this epoch

    end = time.time()
    for i in range(args.iters_per_epoch):
        lr_scheduler.step()

        # measure data loading time
        data_time.update(time.time() - end)

        x_s, labels_s = next(train_source_iter)
        x_t, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # compute output
        # x = torch.cat((x_s, x_t), dim=0)
        # y, f = model(x)
        # y_s, y_t = y.chunk(2, dim=0)
        # f_s, f_t = f.chunk(2, dim=0)
        # Source and target are forwarded separately (the concatenated
        # variant above is kept for reference).
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # Per-sample weights are fixed targets for the adversarial loss, so
        # they are computed without building a graph.
        with torch.no_grad():
            yt_1, yt_2, yt_3, yt_4, yt_5 = esem(f_t)
            # NOTE(review): "confidece" is a typo for "confidence" — local
            # and used consistently, so behavior is unaffected.
            confidece = get_confidence(yt_1, yt_2, yt_3, yt_4, yt_5)
            entropy = get_entropy(yt_1, yt_2, yt_3, yt_4, yt_5)
            consistency = get_consistency(yt_1, yt_2, yt_3, yt_4, yt_5)
            # Target weight: mean of (1 - entropy), (1 - consistency), and
            # confidence — high for samples the ensemble agrees on.
            w_t = (1 - entropy + 1 - consistency + confidece) / 3
            # NOTE(review): the comprehension variable `i` shadows the outer
            # loop index; harmless (comprehension scope) but worth renaming.
            w_s = torch.tensor([source_class_weight[i]
                                for i in labels_s]).to(device)

        cls_loss = F.cross_entropy(y_s, labels_s)
        # Weights are detached defensively; they already carry no grad.
        transfer_loss = domain_adv(f_s, f_t, w_s.detach(),
                                   w_t.to(device).detach())
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)