Example #1
    isTrain = tf.placeholder(dtype=tf.bool)

    # networks : generator
    G_z = generator5(z, y_label, isTrain)
    # G=attention('attention',G_z)

    # networks : discriminator
    layer_out_r, D_real, D_real_logits, D_pre_labels = discriminator2(x, y_fill, isTrain)
    layer_out_f, D_fake, D_fake_logits, _ = discriminator2(G_z, y_fill, isTrain, reuse=True)

    # loss for each network
    if train_mode == 1:
        # MMD
        image = tf.reshape(x, [batch_size, -1])
        G = tf.reshape(G_z, [batch_size, -1])
        kernel_loss = mix_rbf_mmd2(G, image)
        ada_loss = tf.sqrt(kernel_loss)
    else:
        # adaptation
        f_match = tf.constant(0., dtype=tf.float32)
        for i in range(4):
            f_match += tf.reduce_mean(
                tf.multiply(layer_out_f[i] - layer_out_r[i],
                            layer_out_f[i] - layer_out_r[i]))
        ada_loss = f_match
    D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones([batch_size, 1, 1])))
    D_loss_dis = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_pre_labels, labels=y_label))
    D_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.zeros([batch_size, 1, 1])))
    D_loss = D_loss_real + D_loss_fake + D_loss_dis + ada_loss

    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=D_fake_logits, labels=tf.ones([batch_size, 1, 1])))
Example #2
def mmd_loss(x_src, x_tar):
    return mmd.mix_rbf_mmd2(x_src, x_tar, [GAMMA])
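For reference, the wrapper above (and the similar one in Example #12 below) simply forwards two feature batches and a bandwidth list to mix_rbf_mmd2. Below is a minimal NumPy sketch of the quantity being estimated, the biased squared MMD under a mixture of RBF kernels; the function name, bandwidths, and toy inputs are illustrative, not taken from the library.

import numpy as np

def rbf_mmd2_sketch(X, Y, sigmas=(1.0, 5.0, 10.0)):
    """Biased MMD^2 estimate between row-sample matrices X and Y under a
    sum of RBF kernels k(a, b) = sum_s exp(-||a - b||^2 / (2 * s^2))."""
    def mix_kernel(A, B):
        # pairwise squared Euclidean distances between rows of A and B
        d2 = np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :] - 2.0 * A @ B.T
        return sum(np.exp(-d2 / (2.0 * s**2)) for s in sigmas)
    return (mix_kernel(X, X).mean() + mix_kernel(Y, Y).mean()
            - 2.0 * mix_kernel(X, Y).mean())

# toy usage: two batches of 128 feature vectors of dimension 64
mmd2 = rbf_mmd2_sketch(np.random.randn(128, 64), np.random.randn(128, 64))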
Example #3
def main(_):
    # Load data.
    dataset_tools = import_module('tools.' + FLAGS.dataset)
    train_images, train_labels = dataset_tools.get_data('train')
    if FLAGS.target_dataset is not None:
        target_dataset_tools = import_module('tools.' + FLAGS.target_dataset)
        train_images_unlabeled, _ = target_dataset_tools.get_data(
            FLAGS.target_dataset_split)
    else:
        train_images_unlabeled, _ = dataset_tools.get_data('unlabeled')

    architecture = getattr(semisup.architectures, FLAGS.architecture)

    num_labels = dataset_tools.NUM_LABELS
    image_shape = dataset_tools.IMAGE_SHAPE

    # Sample labeled training subset.
    seed = FLAGS.sup_seed if FLAGS.sup_seed != -1 else None
    sup_by_label = semisup.sample_by_label(train_images, train_labels,
                                           FLAGS.sup_per_class, num_labels,
                                           seed)

    # Sample unlabeled training subset.
    if FLAGS.unsup_samples > -1:
        num_unlabeled = len(train_images_unlabeled)
        assert FLAGS.unsup_samples <= num_unlabeled, (
            'Chose more unlabeled samples ({})'
            ' than there are in the '
            'unlabeled batch ({}).'.format(FLAGS.unsup_samples, num_unlabeled))

        rng = np.random.RandomState(seed=seed)
        train_images_unlabeled = train_images_unlabeled[rng.choice(
            num_unlabeled, FLAGS.unsup_samples, False)]

    graph = tf.Graph()
    with graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(FLAGS.ps_tasks,
                                               merge_devices=True)):

            # Set up inputs.
            t_unsup_images = semisup.create_input(train_images_unlabeled, None,
                                                  FLAGS.unsup_batch_size)
            t_sup_images, t_sup_labels = semisup.create_per_class_inputs(
                sup_by_label, FLAGS.sup_per_batch)

            if FLAGS.remove_classes:
                t_sup_images = tf.slice(t_sup_images, [
                    0, 0, 0, 0
                ], [FLAGS.sup_per_batch *
                    (num_labels - FLAGS.remove_classes)] + image_shape)

            # Resize if necessary.
            if FLAGS.new_size > 0:
                new_shape = [FLAGS.new_size, FLAGS.new_size, image_shape[-1]]
            else:
                new_shape = None

            # Apply augmentation
            if FLAGS.augmentation:
                # TODO(haeusser) revert this to the general case
                def _random_invert(inputs, _):
                    randu = tf.random_uniform(
                        shape=[FLAGS.sup_per_batch * num_labels],
                        minval=0.,
                        maxval=1.,
                        dtype=tf.float32)
                    randu = tf.cast(tf.less(randu, 0.5), tf.float32)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    randu = tf.expand_dims(randu, 1)
                    inputs = tf.cast(inputs, tf.float32)
                    return tf.abs(inputs - 255 * randu)

                augmentation_function = _random_invert

                # if hasattr(dataset_tools, 'augmentation_params'):
                #    augmentation_function = partial(
                #        apply_augmentation, params=dataset_tools.augmentation_params)
                # else:
                #    augmentation_function = apply_affine_augmentation
            else:
                augmentation_function = None

            # Create function that defines the network.
            model_function = partial(
                architecture,
                new_shape=new_shape,
                img_shape=image_shape,
                augmentation_function=augmentation_function,
                batch_norm_decay=FLAGS.batch_norm_decay,
                emb_size=FLAGS.emb_size)

            # Set up semisup model.
            model = semisup.SemisupModel(model_function, num_labels,
                                         image_shape)

            # Compute embeddings and logits.
            t_sup_emb = model.image_to_embedding(t_sup_images)
            t_unsup_emb = model.image_to_embedding(t_unsup_images)

            # Add virtual embeddings.
            if FLAGS.virtual_embeddings:
                t_sup_emb = tf.concat(0, [
                    t_sup_emb,
                    semisup.create_virt_emb(FLAGS.virtual_embeddings,
                                            FLAGS.emb_size)
                ])

                if not FLAGS.remove_classes:
                    # need to add additional labels for virtual embeddings
                    t_sup_labels = tf.concat(0, [
                        t_sup_labels,
                        (num_labels +
                         tf.range(1, FLAGS.virtual_embeddings + 1,
                                  dtype=tf.int64)) *
                        tf.ones([FLAGS.virtual_embeddings], tf.int64)
                    ])

            t_sup_logit = model.embedding_to_logit(t_sup_emb)

            # Add losses.
            if FLAGS.mmd:
                sys.path.insert(0, '/usr/wiss/haeusser/libs/opt-mmd/gan')
                from mmd import mix_rbf_mmd2

                bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]  # original

                t_sup_flat = tf.reshape(t_sup_emb,
                                        [FLAGS.sup_per_batch * num_labels, -1])
                t_unsup_flat = tf.reshape(t_unsup_emb,
                                          [FLAGS.unsup_batch_size, -1])
                mmd_loss = mix_rbf_mmd2(t_sup_flat,
                                        t_unsup_flat,
                                        sigmas=bandwidths)
                tf.losses.add_loss(mmd_loss)
                tf.summary.scalar('MMD_loss', mmd_loss)
            else:
                visit_weight_envelope_steps = (
                    FLAGS.walker_weight_envelope_steps
                    if FLAGS.visit_weight_envelope_steps == -1
                    else FLAGS.visit_weight_envelope_steps)
                visit_weight_envelope_delay = (
                    FLAGS.walker_weight_envelope_delay
                    if FLAGS.visit_weight_envelope_delay == -1
                    else FLAGS.visit_weight_envelope_delay)
                visit_weight = apply_envelope(
                    type=FLAGS.visit_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.visit_weight,
                    growing_steps=visit_weight_envelope_steps,
                    delay=visit_weight_envelope_delay)
                walker_weight = apply_envelope(
                    type=FLAGS.walker_weight_envelope,
                    step=model.step,
                    final_weight=FLAGS.walker_weight,
                    growing_steps=FLAGS.walker_weight_envelope_steps,
                    delay=FLAGS.walker_weight_envelope_delay)
                tf.summary.scalar('Weights_Visit', visit_weight)
                tf.summary.scalar('Weights_Walker', walker_weight)

                if FLAGS.unsup_samples != 0:
                    model.add_semisup_loss(t_sup_emb,
                                           t_unsup_emb,
                                           t_sup_labels,
                                           visit_weight=visit_weight,
                                           walker_weight=walker_weight)

            model.add_logit_loss(t_sup_logit,
                                 t_sup_labels,
                                 weight=FLAGS.logit_weight)

            # Set up learning rate schedule if necessary.
            if FLAGS.custom_lr_vals is not None and FLAGS.custom_lr_steps is not None:
                boundaries = [
                    tf.convert_to_tensor(x, tf.int64)
                    for x in FLAGS.custom_lr_steps
                ]

                t_learning_rate = piecewise_constant(model.step, boundaries,
                                                     FLAGS.custom_lr_vals)
            else:
                t_learning_rate = tf.maximum(
                    tf.train.exponential_decay(FLAGS.learning_rate,
                                               model.step,
                                               FLAGS.decay_steps,
                                               FLAGS.decay_factor,
                                               staircase=True),
                    FLAGS.minimum_learning_rate)

            # Create training operation and start the actual training loop.
            train_op = model.create_train_op(t_learning_rate)

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            # config.log_device_placement = True

            saver = tf_saver.Saver(
                max_to_keep=FLAGS.max_checkpoints,
                keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours)

            ''' BEGIN EVIL STUFF !!!
            checkpoint_path = '/usr/wiss/haeusser/experiments/inception_v4/model.ckpt'
            mapping = dict()
            for x in slim.get_model_variables():
                name = x.name[:-2]
                ok = True
                for banned in ['Logits', 'fc1', 'fully_connected', 'ExponentialMovingAverage']:
                    if banned in name:
                        ok = False
                if ok:
                    mapping[name[4:]] = x

            init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
                checkpoint_path, mapping)

            # Create an initial assignment function.
            def InitAssignFn(sess):
                sess.run(init_assign_op, init_feed_dict)
                print("#################################### Checkpoint loaded.")
             '''  # END EVIL STUFF !!!

            slim.learning.train(
                train_op,
                # init_fn=InitAssignFn,  # EVIL, too
                logdir=FLAGS.logdir + '/train',
                save_summaries_secs=FLAGS.save_summaries_secs,
                save_interval_secs=FLAGS.save_interval_secs,
                master=FLAGS.master,
                is_chief=(FLAGS.task == 0),
                startup_delay_steps=(FLAGS.task * 20),
                log_every_n_steps=FLAGS.log_every_n_steps,
                session_config=config,
                trace_every_n_steps=1000,
                saver=saver,
                number_of_steps=FLAGS.max_steps,
            )
Example #4
def main():
    print('Start Training\nInitializing\n')
    print('src:', args.source)
    print('tar:', args.target)

    # Data loading

    data_func = {
        'modelnet': Modelnet40_data,
        'scannet': Scannet_data_h5,
        'shapenet': Shapenet_data,
        "sapples": AppleTreeData,
        "rapples": AppleTreeData
    }

    source_train_dataset = data_func[args.source](pc_input_num=1024,
                                                  status='train',
                                                  aug=True,
                                                  pc_root=dir_root +
                                                  args.source)
    target_train_dataset1 = data_func[args.target](pc_input_num=1024,
                                                   status='train',
                                                   aug=True,
                                                   pc_root=dir_root +
                                                   args.target)
    source_test_dataset = data_func[args.source](pc_input_num=1024,
                                                 status='test',
                                                 aug=False,
                                                 pc_root=dir_root + args.source)
    target_test_dataset1 = data_func[args.target](pc_input_num=1024,
                                                  status='test',
                                                  aug=False,
                                                  pc_root=dir_root + args.target)

    num_source_train = len(source_train_dataset)
    num_source_test = len(source_test_dataset)
    num_target_train1 = len(target_train_dataset1)
    num_target_test1 = len(target_test_dataset1)

    source_train_dataloader = DataLoader(source_train_dataset,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         num_workers=2,
                                         drop_last=True)
    source_test_dataloader = DataLoader(source_test_dataset,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True,
                                        num_workers=2,
                                        drop_last=True)
    target_train_dataloader1 = DataLoader(target_train_dataset1,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True,
                                          num_workers=2,
                                          drop_last=True)
    target_test_dataloader1 = DataLoader(target_test_dataset1,
                                         batch_size=BATCH_SIZE,
                                         shuffle=True,
                                         num_workers=2,
                                         drop_last=True)

    print(
        'num_source_train: {:d}, num_source_test: {:d}, num_target_test1: {:d} '
        .format(num_source_train, num_source_test, num_target_test1))
    print('batch_size:', BATCH_SIZE)

    # Model

    model = Model.Net_MDA()
    model = model.to(device=device)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=device)

    remain_epoch = 50

    # Optimizer

    params = [{
        'params': v
    } for k, v in model.g.named_parameters() if 'pred_offset' not in k]

    optimizer_g = optim.Adam(params, lr=LR, weight_decay=weight_decay)
    lr_schedule_g = optim.lr_scheduler.CosineAnnealingLR(optimizer_g,
                                                         T_max=args.epochs +
                                                         remain_epoch)

    optimizer_c = optim.Adam([{
        'params': model.c1.parameters()
    }, {
        'params': model.c2.parameters()
    }],
                             lr=LR * 2,
                             weight_decay=weight_decay)
    lr_schedule_c = optim.lr_scheduler.CosineAnnealingLR(optimizer_c,
                                                         T_max=args.epochs +
                                                         remain_epoch)

    optimizer_dis = optim.Adam([{
        'params': model.g.parameters()
    }, {
        'params': model.attention_s.parameters()
    }, {
        'params': model.attention_t.parameters()
    }],
                               lr=LR * args.scaler,
                               weight_decay=weight_decay)
    lr_schedule_dis = optim.lr_scheduler.CosineAnnealingLR(optimizer_dis,
                                                           T_max=args.epochs +
                                                           remain_epoch)

    def adjust_learning_rate(optimizer, epoch):
        """Sets the learning rate to the initial LR decayed by half by every 5 or 10 epochs"""
        if epoch > 0:
            if epoch <= 30:
                lr = args.lr * args.scaler * (0.5**(epoch // 5))
            else:
                lr = args.lr * args.scaler * (0.5**(epoch // 10))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            writer.add_scalar('lr_dis', lr, epoch)

    def discrepancy(out1, out2):
        """discrepancy loss"""
        out = torch.mean(
            torch.abs(F.softmax(out1, dim=-1) - F.softmax(out2, dim=-1)))
        return out

    def make_variable(tensor, volatile=False):
        """Convert Tensor to Variable."""
        if torch.cuda.is_available():
            tensor = tensor.cuda()
        return Variable(tensor, volatile=volatile)

    best_target_test_acc = 0

    for epoch in range(max_epoch):
        since_e = time.time()

        lr_schedule_g.step(epoch=epoch)
        lr_schedule_c.step(epoch=epoch)
        adjust_learning_rate(optimizer_dis, epoch)

        writer.add_scalar('lr_g', lr_schedule_g.get_lr()[0], epoch)
        writer.add_scalar('lr_c', lr_schedule_c.get_lr()[0], epoch)

        model.train()

        loss_total = 0
        loss_adv_total = 0
        loss_node_total = 0
        correct_total = 0
        data_total = 0
        data_t_total = 0
        cons = math.sin((epoch + 1) / max_epoch * math.pi / 2)

        # Training

        for batch_idx, (batch_s, batch_t) in enumerate(
                zip(source_train_dataloader, target_train_dataloader1)):

            data, label = batch_s
            data_t, label_t = batch_t

            data = data.to(device=device)
            label = label.to(device=device).long()
            data_t = data_t.to(device=device)
            label_t = label_t.to(device=device).long()

            pred_s1, pred_s2 = model(data)
            pred_t1, pred_t2 = model(data_t, constant=cons, adaptation=True)

            # Classification loss

            loss_s1 = criterion(pred_s1, label)
            loss_s2 = criterion(pred_s2, label)

            # Adversarial loss

            loss_adv = -1 * discrepancy(pred_t1, pred_t2)

            loss_s = loss_s1 + loss_s2
            loss = args.weight * loss_s + loss_adv

            loss.backward()
            optimizer_g.step()
            optimizer_c.step()
            optimizer_g.zero_grad()
            optimizer_c.zero_grad()

            # Local Alignment

            feat_node_s = model(data, node_adaptation_s=True)
            feat_node_t = model(data_t, node_adaptation_t=True)
            sigma_list = [0.01, 0.1, 1, 10, 100]
            loss_node_adv = 1 * mmd.mix_rbf_mmd2(feat_node_s, feat_node_t,
                                                 sigma_list)
            loss = loss_node_adv

            loss.backward()
            optimizer_dis.step()
            optimizer_dis.zero_grad()

            loss_total += loss_s.item() * data.size(0)
            loss_adv_total += loss_adv.item() * data.size(0)
            loss_node_total += loss_node_adv.item() * data.size(0)
            data_total += data.size(0)
            data_t_total += data_t.size(0)

            if (batch_idx + 1) % 10 == 0:
                print(
                    'Train:{} [{} {}/{}  loss_s: {:.4f} \t loss_adv: {:.4f} \t loss_node_adv: {:.4f} \t cons: {:.4f}]'
                    .format(epoch, data_total, data_t_total, num_source_train,
                            loss_total / data_total,
                            loss_adv_total / data_total,
                            loss_node_total / data_total, cons))

        # Testing

        with torch.no_grad():
            model.eval()
            loss_total = 0
            correct_total = 0
            data_total = 0
            acc_class = torch.zeros(10, 1)
            acc_to_class = torch.zeros(10, 1)
            acc_to_all_class = torch.zeros(10, 10)

            for batch_idx, (data, label) in enumerate(target_test_dataloader1):
                print(data.size(0))
                data = data.to(device=device)
                label = label.to(device=device).long()
                pred1, pred2 = model(data)
                output = (pred1 + pred2) / 2
                loss = criterion(output, label)
                _, pred = torch.max(output, 1)

                loss_total += loss.item() * data.size(0)
                correct_total += torch.sum(pred == label)
                data_total += data.size(0)

            pred_loss = loss_total / data_total
            pred_acc = correct_total.double() / data_total

            if pred_acc > best_target_test_acc:
                best_target_test_acc = pred_acc
            print(
                'Target 1:{} [overall_acc: {:.4f} \t loss: {:.4f} \t Best Target Acc: {:.4f}]'
                .format(epoch, pred_acc, pred_loss, best_target_test_acc))
            writer.add_scalar('accs/target_test_acc', pred_acc, epoch)

        time_pass_e = time.time() - since_e
        print('Epoch {} took {:.0f}m {:.0f}s'.format(
            epoch, time_pass_e // 60, time_pass_e % 60))
        print(args)
        print(' ')
Example #5
                        [batch_size, max_sentence_length, embedding_size])

logits_data, logits_generated, encoding_data, encoding_generated, features_data, features_generated = build_discriminator(
    x_data, x_generated, batch_size, max_sentence_length, embedding_size)
y_data, y_generated = tf.nn.softmax(logits_data), tf.nn.softmax(
    logits_generated)

# TODO classifications come out as pairs of numbers; could instead come out as
# single numbers representing the probability that the sentence is real.
y_data, y_generated = y_data[:, 0], y_generated[:, 0]

# Loss, as described in Zhang 2017
# Lambda values meant to weight gan ~= recon > mmd
lambda_r, lambda_m = 1.0e-1, 1.0e-1
mmd_val = mmd.mix_rbf_mmd2(features_data,
                           features_generated,
                           sigmas=args.mmd_sigmas)
gan_val = tf.reduce_mean(tf.log(y_data)) + tf.reduce_mean(
    tf.log(1 - y_generated))
recon_val = tf.reduce_mean(tf.norm(z_prior - encoding_generated, axis=1))
d_loss = -gan_val + lambda_r * recon_val - lambda_m * mmd_val
g_loss = mmd_val

tf.summary.scalar("mmd", mmd_val)
tf.summary.scalar("gan", gan_val)
tf.summary.scalar("recon", recon_val)
tf.summary.scalar("d_loss", d_loss)
tf.summary.scalar("g_loss", g_loss)

# Clipping gradients.
# Not only is this described in Zhang, but it's necessary due to the presence
Example #6
    def mmd_loss(self, x_src, x_tar):
        return mmd.mix_rbf_mmd2(x_src, x_tar, [self.GAMMA])
Example #7
def train(model,
          cadset,
          realset,
          optimizer,
          hot=False,
          summarywriter=None,
          savefilename=None,
          **kwargs):
    """Train the network using our method

    Arguments:
        model {nn.Module or nn.DataParallel} -- The network to be trained
        cadset {MyDataset} -- Dataset of synthetic images
        realset {MyDataset} -- Dataset of real images
        optimizer {torch.optim.Optimizer} -- Optimizer

    Keyword Arguments:
        hot {bool} -- Whether training in hot stage (default: {False})
        summarywriter {tensorboardX.SummaryWriter} -- Tensorboard writer (default: {None})
        savefilename {str} -- Filename of the saved model (default: {None})
        epoch {int} -- Number of epochs to train
        batch_size {int} -- Batch size
        n_classes {int} -- Number of classes of your dataset
        test_steps {int} -- Test intervals while training
        GPUs {None/int/(int)} -- CUDA device IDs

    Returns:
        max_recall {float} -- Maximum recall during the training process
        max_ap {[float]} -- Maximum AP during the training process
        max_mean_ap {float} -- Maximum mean AP during the training process
    """

    if not hot:
        print('Training in cold stage!')
    else:
        print('Training in hot stage!')
    if isinstance(model, nn.DataParallel):
        print(
            'Warning: You are using DataParallel. We will only save the state dict of the module, instead of the whole DataParallel object.'
        )
    if summarywriter is None:
        print(
            'Warning: summarywriter is None. The result will not be displayed on Tensorboard!'
        )
    if savefilename is None:
        print(
            'Warning: savefilename is None. The trained model will not be saved!'
        )
    if 'epoch' not in kwargs:
        raise ValueError(
            'Please specify the number of epochs by passing "epoch=YOUR_EPOCHS"!'
        )
    if 'batch_size' not in kwargs:
        raise ValueError(
            'Please specify the batch size by passing "batch_size=YOUR_BATCH_SIZE"!'
        )
    if 'n_classes' not in kwargs:
        raise ValueError(
            'Please specify the number of classes in your dataset by passing "n_classes=YOUR_CLASSES"!'
        )
    if 'test_steps' not in kwargs:
        kwargs['test_steps'] = 50
        print(
            'Warning: test_steps is not specified, we will use 50 by default.')

    max_recall, max_ap, max_mean_ap = 0, None, 0
    for epoch in range(kwargs['epoch']):
        cadloader = DataLoader(cadset,
                               batch_size=kwargs['batch_size'],
                               shuffle=True,
                               num_workers=cpu_count(),
                               drop_last=True)
        for batch, (images_cad, labels_cad) in enumerate(cadloader):
            # Test accuracies
            model.eval()
            if (epoch * len(cadloader) + batch) % kwargs['test_steps'] == 0:
                _, all_output, all_pred, all_label = predict(
                    model, realset, **kwargs)
                recall = np.sum(all_pred == all_label) / float(len(realset))
                ap = AP(all_output, all_label)
                mean_ap = meanAP(all_output, all_label)
                print('Mean Recall: ', recall)
                print('AP: ', ap)
                print('Mean AP: ', mean_ap)
                print('Previous Maximum Mean AP: ', max_mean_ap)
                print('Previous Maximum Accuracy: ', max_recall)
                if mean_ap >= max_mean_ap:
                    max_ap, max_mean_ap = ap, mean_ap
                if recall >= max_recall:
                    max_recall = recall
                    if hot:
                        print('Update pseudo labels!')
                        realset.update_pseudo_labels(all_pred)
                    if savefilename is not None:
                        if isinstance(model, nn.DataParallel):
                            torch.save(model.module.state_dict(), savefilename)
                        else:
                            torch.save(model.state_dict(), savefilename)

            # Read training samples
            if hot:
                images_real, labels_real = realset.random_choice(
                    labels_cad, use_pseudo=True)
            else:
                images_real, labels_real = realset.random_choice([
                    random.randint(0, kwargs['n_classes'] - 1)
                    for _ in range(kwargs['batch_size'])
                ])

            # Convert torch.Tensor to torch.autograd.Variable
            model.train()
            images_cad = Variable(images_cad)
            labels_cad = Variable(labels_cad)
            images_real = Variable(images_real)
            labels_real = Variable(labels_real)

            if kwargs['GPUs']:
                images_cad = images_cad.cuda(kwargs['GPUs'][0])
                labels_cad = labels_cad.cuda(kwargs['GPUs'][0])
                images_real = images_real.cuda(kwargs['GPUs'][0])
                labels_real = labels_real.cuda(kwargs['GPUs'][0])

            # Feed to our network
            mmd_cad, mmd_real, out_cad, out_real = model(
                images_cad, images_real)

            # Calculate the loss
            loss_class = F.cross_entropy(out_cad, labels_cad)
            loss_mmd = mix_rbf_mmd2(mmd_cad, mmd_real, [1, 2, 4, 8, 16])
            loss = loss_class + loss_mmd

            # Calculate the accuracy within this batch
            accuracy_cad = torch.sum(
                labels_cad == torch.max(out_cad, 1)[1]).item() / float(
                    kwargs['batch_size'])
            accuracy_pseudo = torch.sum(
                labels_cad == torch.max(out_real, 1)[1]).item() / float(
                    kwargs['batch_size'])
            accuracy_real = torch.sum(
                labels_real == torch.max(out_real, 1)[1]).item() / float(
                    kwargs['batch_size'])

            # Print the loss and the accuracy
            if hot:
                print(
                    'epoch:%d, batch:%d, loss:%0.5f, loss_class:%0.5f, loss_mmd:%0.5f, accuracy of CAD:%0.5f, accuracy of pseudo:%0.5f, accuracy of real:%0.5f'
                    % (epoch, batch, loss.item(), loss_class.item(),
                       loss_mmd.item(), accuracy_cad, accuracy_pseudo,
                       accuracy_real))
            else:
                print(
                    'epoch:%d, batch:%d, loss:%0.5f, loss_class:%0.5f, loss_mmd:%0.5f, accuracy of CAD:%0.5f, accuracy of real:%0.5f'
                    % (epoch, batch, loss.item(), loss_class.item(),
                       loss_mmd.item(), accuracy_cad, accuracy_real))

            # Print to Tensorboard
            if summarywriter:
                summarywriter.add_scalar('accuracy_of_cad', accuracy_cad,
                                         epoch * len(cadloader) + batch)
                summarywriter.add_scalar('accuracy_of_real', accuracy_real,
                                         epoch * len(cadloader) + batch)
                summarywriter.add_scalar('loss_of_classification',
                                         loss_class.item(),
                                         epoch * len(cadloader) + batch)
                summarywriter.add_scalar('loss_of_mmd', loss_mmd.item(),
                                         epoch * len(cadloader) + batch)
                summarywriter.add_scalar('loss', loss.item(),
                                         epoch * len(cadloader) + batch)

            # Optimize the network
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return max_recall, max_ap, max_mean_ap
Example #8
                        torch.FloatTensor)).cuda()
                y_eval_dec_1s = Variable(
                    torch.from_numpy(y_eval_dec_1s_np).type(
                        torch.FloatTensor)).cuda()
                y_eval_dec_0s = Variable(
                    torch.from_numpy(y_eval_dec_0s_np).type(
                        torch.FloatTensor)).cuda()
                # Get logistic regr error on x_eval, and class distr on y_eval.
                x_eval_error = np.mean(abs(x_eval_enc_p1_np -
                                           x_eval_labels_np))
                x_eval_labels = clf.predict(x_eval_enc_np)
                y_eval_labels = clf.predict(y_eval_enc_np)

            # compute biased MMD2 and use ReLU to prevent negative value
            if not weighted:
                mmd2_D = mix_rbf_mmd2(f_enc_X_D, f_enc_Y_D, sigma_list)
            else:
                try:
                    if args.thin_type == 'kernel':
                        mmd2_D = mix_rbf_mmd2_weighted(f_enc_X_D,
                                                       f_enc_Y_D,
                                                       sigma_list,
                                                       args.exp_const,
                                                       args.thinning_scale,
                                                       t_mean=t_mean,
                                                       t_cov_inv=t_cov_inv)
                    elif args.thin_type == 'logistic':
                        mmd2_D = mix_rbf_mmd2_weighted(f_enc_X_D,
                                                       f_enc_Y_D,
                                                       sigma_list,
                                                       args.exp_const,
Example #9
z = tf.random_normal(shape=(batch_size, z_dim),
                     mean=0,
                     stddev=1,
                     dtype=tf.float32)
# z =  tf.placeholder(tf.float32, shape=[None, z_dim])
fake_set = generator(z, reuse=False)
fake = tf.concat(fake_set, 0)

# ========encoder==============
z_mean_real, _ = encoder(real, reuse=False)
z_mean_fake, _ = encoder(fake)

#======MMD loss===============
bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]
from mmd import mix_rbf_mmd2
kernel_loss = mix_rbf_mmd2(z_mean_fake, z_mean_real, sigmas=bandwidths)
kernel_loss = tf.sqrt(kernel_loss)

#=======================
# trainable variables for each network
T_vars = tf.trainable_variables()
en_var = [var for var in T_vars if var.name.startswith('encoder')]
g_var = [var for var in T_vars if var.name.startswith('generator')]
""" train """
''' init '''
# session
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.6)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

# saver
saver = tf.train.Saver(max_to_keep=5)
Example #10
    def build_model(self):
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        self.lr = tf.placeholder(tf.float32, shape=[])
        self.images = tf.placeholder(tf.float32, [self.batch_size] + [self.output_size, self.output_size, self.c_dim],
                                    name='real_images')
        self.sample_images= tf.placeholder(tf.float32, [self.sample_size] + [self.output_size, self.output_size, self.c_dim],
                                        name='sample_images')
        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')

        tf.summary.histogram("z", self.z)

        self.G = self.generator_mnist(self.z)
        images = tf.reshape(self.images, [self.batch_size, -1])
        G = tf.reshape(self.G, [self.batch_size, -1])

        phi_images = self.discriminator_k(images)
        bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]

        if self.config.use_weighted_layer_kernel:
            n_layer = len(phi_images)
            for layer_id in range(n_layer):
                with tf.variable_scope("dk_att_" + str(layer_id)):
                    dk_att = tf.get_variable(
                        "weight", [len(bandwidths)], tf.float32,
                        tf.constant_initializer(1 / len(bandwidths)))
        phi_G = self.discriminator_k(G, reuse=True)
        self.kernel_loss = tf.Variable(0.0, trainable=False, name="kernel_loss")

        if self.config.use_weighted_layer_kernel:
            n_layer = len(phi_images)
            if self.config.use_gan:
                n_layer -= 1
            tf.summary.histogram("phi_images", tf.concat(1, phi_images[1:]))
            tf.summary.histogram("phi_G", tf.concat(1, phi_G[1:]))

            for layer_id in range(n_layer):
                self.kernel_loss += mix_rbf_mmd2(
                    phi_G[layer_id], phi_images[layer_id],
                    sigmas=bandwidths,
                    wts=tf.exp(dk_att) / tf.reduce_sum(tf.exp(dk_att)))
        elif self.config.use_layer_kernel:
            n_layer = len(phi_images)
            for layer_id in range(n_layer):
                kernel_loss = mix_rbf_mmd2(
                    phi_G[layer_id], phi_images[layer_id], sigmas=bandwidths)
                tf.summary.scalar("kernel_loss_" + str(layer_id), kernel_loss)
                self.kernel_loss += kernel_loss #pow(2, n_layer) * kernel_loss
            if self.config.use_gan:
                self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(phi_images[-1], tf.ones_like(phi_images[-1])))
                self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(phi_G[-1], tf.zeros_like(phi_G[-1])))
                self.d_loss = self.d_loss_real + self.d_loss_fake
                tf.summary.scalar("d_loss", self.d_loss)
        else:
            if self.config.use_gan:
                self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(phi_images[-1], tf.ones_like(phi_images[-1])))
                self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(phi_G[-1], tf.zeros_like(phi_images[-1])))
                self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(phi_G[-1], tf.ones_like(phi_images[-1])))
                self.d_loss = self.d_loss_real + self.d_loss_fake
                tf.summary.scalar("d_loss", self.d_loss)
                self.phiG = phi_G[-1]
                phi_G = tf.concat(1, phi_G[0:-1])
                phi_images = tf.concat(1, phi_images[0:-1])

                print("use_gan")
            phi_G = [phi_G]
            phi_images = [phi_images]
            n_layer = 1
            self.kernel_loss = mix_rbf_mmd2(phi_G[0], phi_images[0], sigmas=bandwidths)
        tf.summary.scalar("kernel_loss", self.kernel_loss)
        self.kernel_loss = tf.sqrt(self.kernel_loss)

        tf.summary.image("train/input image", self.imageRearrange(tf.clip_by_value(self.images, 0, 1), 8))
        tf.summary.image("train/gen image", self.imageRearrange(tf.clip_by_value(self.G, 0, 1), 8))

        self.sampler = self.generator_mnist(self.z, is_train=False, reuse=True)
        t_vars = tf.trainable_variables()

        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]
        self.dk_vars = [var for var in t_vars if 'dk_' in var.name]

        self.saver = tf.train.Saver()
Example #11
    def train_MMD(self, epoch, record_file=None):
        criterion = nn.CrossEntropyLoss().cuda()
        self.G.train()
        self.C1.train()

        torch.cuda.manual_seed(1)

        for batch_idx, data in enumerate(self.datasets):
            img_t = Variable(data['T'].cuda())
            img_s1 = Variable(data['S1'].cuda())
            img_s2 = Variable(data['S2'].cuda())
            img_s3 = Variable(data['S3'].cuda())
            img_s4 = Variable(data['S4'].cuda())

            label_s1 = Variable(data['S1_label'].long().cuda())
            label_s2 = Variable(data['S2_label'].long().cuda())
            label_s3 = Variable(data['S3_label'].long().cuda())
            label_s4 = Variable(data['S4_label'].long().cuda())

            if (img_s1.size(0) < self.batch_size
                    or img_s2.size(0) < self.batch_size
                    or img_s3.size(0) < self.batch_size
                    or img_s4.size(0) < self.batch_size
                    or img_t.size(0) < self.batch_size):
                break
            self.reset_grad()
            feat_s1 = self.G(img_s1)
            output_s1 = self.C1(feat_s1)

            feat_s2 = self.G(img_s2)
            output_s2 = self.C1(feat_s2)

            feat_s3 = self.G(img_s3)
            output_s3 = self.C1(feat_s3)

            feat_s4 = self.G(img_s4)
            output_s4 = self.C1(feat_s4)

            feat_t = self.G(img_t)
            output_t = self.C1(feat_t)

            print('->shape', output_s1.shape, label_s1.shape)
            loss_s1 = criterion(output_s1, label_s1)
            loss_s2 = criterion(output_s2, label_s2)
            loss_s3 = criterion(output_s3, label_s3)
            loss_s4 = criterion(output_s4, label_s4)

            loss_s = (loss_s1 + loss_s2 + loss_s3 + loss_s4) / 4

            sigma = [1, 2, 5, 10]
            loss_msda =  mmd.mix_rbf_mmd2(feat_s1, feat_s2, sigma) + mmd.mix_rbf_mmd2(feat_s1, feat_s3, sigma) + mmd.mix_rbf_mmd2(feat_s1,feat_s4, sigma) +\
                mmd.mix_rbf_mmd2(feat_s1, feat_t, sigma) + mmd.mix_rbf_mmd2(feat_s2, feat_s3, sigma) + mmd.mix_rbf_mmd2(feat_s2, feat_t, sigma) +\
                mmd.mix_rbf_mmd2(feat_s2, feat_s4, sigma) + mmd.mix_rbf_mmd2(feat_s3, feat_s4, sigma) + mmd.mix_rbf_mmd2(feat_s3, feat_t, sigma) +\
                mmd.mix_rbf_mmd2(feat_s4, feat_t, sigma)
            loss = 10 * loss_msda + loss_s
            loss.backward()

            self.opt_c1.step()
            self.opt_g.step()
            self.reset_grad()
            if batch_idx > 500:
                return batch_idx

            if batch_idx % self.interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss1: {:.6f}\t Loss2: {:.6f}\t Loss3: {:.6f}\t Loss4: {:.6f}\t Discrepancy: {:.6f}'
                    .format(epoch, batch_idx, 100, 100. * batch_idx / 70000,
                            loss_s1.data[0], loss_s2.data[0], loss_s3.data[0],
                            loss_s4.data[0], loss_msda.data[0]))
                if record_file:
                    record = open(record_file, 'a')
                    record.write(
                        '%s %s %s %s %s\n' %
                        (loss_msda.data[0], loss_s1.data[0], loss_s2.data[0],
                         loss_s3.data[0], loss_s4.data[0]))
                    record.close()
        return batch_idx
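The hand-written loss_msda in the example above sums mmd.mix_rbf_mmd2 over every pair drawn from the four source feature batches and the target batch. A compact, equivalent sketch is shown below; pairwise_mmd_sketch is a hypothetical helper name, and the commented call assumes the feat_s*/feat_t tensors and the mmd module from the snippet.

import itertools

def pairwise_mmd_sketch(feats, sigma, mmd2_fn):
    """Sum an MMD^2 function over all unordered pairs of feature batches;
    with mmd2_fn=mmd.mix_rbf_mmd2 this matches the hand-written loss_msda sum."""
    return sum(mmd2_fn(a, b, sigma)
               for a, b in itertools.combinations(feats, 2))

# e.g. loss_msda = pairwise_mmd_sketch(
#     [feat_s1, feat_s2, feat_s3, feat_s4, feat_t], sigma, mmd.mix_rbf_mmd2)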
Example #12
def mmd_loss(x_src, x_tar, gamma):
    return mmd.mix_rbf_mmd2(x_src, x_tar, [gamma])
Example #13
def training_epoch(epoch, model, cluster_loader_dict, cluster_pairs, nn_paras):
    """ Training an epoch
        Args:
            epoch: number of the current epoch
            model: autoencoder
            cluster_loader_dict: dict of DataLoaders indexed by clusters
            cluster_pairs: pairs of similar clusters with weights
            nn_paras: parameters for neural network training
        Returns:
            avg_total_loss: average total loss of mini-batches
            avg_reco_loss: average reconstruction loss of mini-batches
            avg_tran_loss: average transfer loss of mini-batches
        """
    log_interval = nn_paras['log_interval']
    # load nn parameters
    base_lr = nn_paras['base_lr']
    lr_step = nn_paras['lr_step']
    num_epochs = nn_paras['num_epochs']
    l2_decay = nn_paras['l2_decay']
    gamma = nn_paras['gamma']
    cuda = nn_paras['cuda']

    # step decay of learning rate
    learning_rate = base_lr / math.pow(2, math.floor(epoch / lr_step))
    # regularization parameter between the two losses
    gamma_rate = 2 / (1 + math.exp(-10 * (epoch) / num_epochs)) - 1
    gamma = gamma_rate * gamma

    if epoch % log_interval == 0:
        print('{:}, Epoch {}, learning rate {:.3E}, gamma {:.3E}'.format(
            time.asctime(time.localtime()), epoch, learning_rate, gamma))

    optimizer = torch.optim.Adam([
        {
            'params': model.encoder.parameters()
        },
        {
            'params': model.decoder.parameters()
        },
    ],
                                 lr=learning_rate,
                                 weight_decay=l2_decay)

    model.train()

    iter_data_dict = {}
    for cls in cluster_loader_dict:
        iter_data = iter(cluster_loader_dict[cls])
        iter_data_dict[cls] = iter_data
    # use the largest dataset to define an epoch
    num_iter = 0
    for cls in cluster_loader_dict:
        num_iter = max(num_iter, len(cluster_loader_dict[cls]))

    total_loss = 0
    total_reco_loss = 0
    total_tran_loss = 0
    num_batches = 0

    for it in range(0, num_iter):
        data_dict = {}
        label_dict = {}
        code_dict = {}
        reconstruct_dict = {}
        for cls in iter_data_dict:
            data, labels = next(iter_data_dict[cls])
            data_dict[cls] = data
            label_dict[cls] = labels
            if it % len(cluster_loader_dict[cls]) == 0:
                iter_data_dict[cls] = iter(cluster_loader_dict[cls])
            data_dict[cls] = Variable(data_dict[cls])
            label_dict[cls] = Variable(label_dict[cls])

        for cls in data_dict:
            code, reconstruct = model(data_dict[cls])
            code_dict[cls] = code
            reconstruct_dict[cls] = reconstruct

        optimizer.zero_grad()

        # transfer loss for cluster pairs in cluster_pairs matrix
        loss_transfer = torch.FloatTensor([0])
        if cuda:
            loss_transfer = loss_transfer.cuda()
        for i in range(cluster_pairs.shape[0]):
            cls_1 = int(cluster_pairs[i, 0])
            cls_2 = int(cluster_pairs[i, 1])
            if cls_1 not in code_dict or cls_2 not in code_dict:
                continue
            mmd2_D = mix_rbf_mmd2(code_dict[cls_1], code_dict[cls_2],
                                  sigma_list)
            loss_transfer += mmd2_D * cluster_pairs[i, 2]

        # reconstruction loss for all clusters
        loss_reconstruct = torch.FloatTensor([0])
        if cuda:
            loss_reconstruct = loss_reconstruct.cuda()
        for cls in data_dict:
            loss_reconstruct += F.mse_loss(reconstruct_dict[cls],
                                           data_dict[cls])

        loss = loss_reconstruct + gamma * loss_transfer

        loss.backward()
        optimizer.step()

        # update total loss
        num_batches += 1
        total_loss += loss.data.item()
        total_reco_loss += loss_reconstruct.data.item()
        total_tran_loss += loss_transfer.data.item()

    avg_total_loss = total_loss / num_batches
    avg_reco_loss = total_reco_loss / num_batches
    avg_tran_loss = total_tran_loss / num_batches

    if epoch % log_interval == 0:
        print(
            'Avg_loss {:.3E}\t Avg_reconstruct_loss {:.3E}\t Avg_transfer_loss {:.3E}'
            .format(avg_total_loss, avg_reco_loss, avg_tran_loss))
    return avg_total_loss, avg_reco_loss, avg_tran_loss
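The transfer-loss weight gamma in training_epoch above is annealed with gamma_rate = 2 / (1 + exp(-10 * epoch / num_epochs)) - 1. Below is a small standalone sketch of how that schedule ramps up; num_epochs here is an illustrative value, not one taken from nn_paras.

import math

num_epochs = 100  # illustrative value
for epoch in (0, 10, 50, 100):
    gamma_rate = 2 / (1 + math.exp(-10 * epoch / num_epochs)) - 1
    print(epoch, round(gamma_rate, 3))  # prints 0.0, 0.462, 0.987, 1.0: ramps toward 1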
Example #14
    def _build_model(self):
        
        self.X = tf.placeholder(tf.uint8, [None, 28, 28, 3])
        self.y = tf.placeholder(tf.float32, [None, 10])
        self.domain = tf.placeholder(tf.float32, [None, 2])
        self.l = tf.placeholder(tf.float32, [])
        self.train = tf.placeholder(tf.bool, [])
        
        X_input = (tf.cast(self.X, tf.float32) - pixel_mean) / 255.
        
        # CNN model for feature extraction
        with tf.variable_scope('feature_extractor'):

            W_conv0 = weight_variable([5, 5, 3, 32])
            b_conv0 = bias_variable([32])
            h_conv0 = tf.nn.relu(conv2d(X_input, W_conv0) + b_conv0)
            h_pool0 = max_pool_2x2(h_conv0)
            
            W_conv1 = weight_variable([5, 5, 32, 48])
            b_conv1 = bias_variable([48])
            h_conv1 = tf.nn.relu(conv2d(h_pool0, W_conv1) + b_conv1)
            h_pool1 = max_pool_2x2(h_conv1)
            
            # The domain-invariant feature
            self.feature = tf.reshape(h_pool1, [-1, 7*7*48])
            
        # MLP for class prediction
        with tf.variable_scope('label_predictor'):
            
            # Switches to route target examples (second half of batch) differently
            # depending on train or test mode.
            all_features = lambda: self.feature
            source_features = lambda: tf.slice(self.feature, [0, 0], [batch_size // 2, -1])
            target_features = lambda: tf.slice(self.feature, [batch_size // 2, 0], [batch_size // 2, -1])
            classify_feats = tf.cond(self.train, source_features, all_features)
            
            all_labels = lambda: self.y
            source_labels = lambda: tf.slice(self.y, [0, 0], [batch_size // 2, -1])
            self.classify_labels = tf.cond(self.train, source_labels, all_labels)
            
            W_fc0 = weight_variable([7 * 7 * 48, 100])
            b_fc0 = bias_variable([100])
            h_fc0 = tf.nn.relu(tf.matmul(classify_feats, W_fc0) + b_fc0)

            W_fc1 = weight_variable([100, 100])
            b_fc1 = bias_variable([100])
            h_fc1 = tf.nn.relu(tf.matmul(h_fc0, W_fc1) + b_fc1)

            W_fc2 = weight_variable([100, 10])
            b_fc2 = bias_variable([10])
            logits = tf.matmul(h_fc1, W_fc2) + b_fc2
            
            self.pred = tf.nn.softmax(logits)
            self.pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.classify_labels)

            # mmd loss
            bandwidths = [2.0, 5.0, 10.0, 20.0, 40.0, 80.0]
            self.mmd_loss = mix_rbf_mmd2(source_features(), target_features(),
                                         sigmas=bandwidths)

        # Small MLP for domain prediction with adversarial loss
        with tf.variable_scope('domain_predictor'):
            
            # Flip the gradient when backpropagating through this operation
            feat = flip_gradient(self.feature, self.l)
            
            d_W_fc0 = weight_variable([7 * 7 * 48, 100])
            d_b_fc0 = bias_variable([100])
            d_h_fc0 = tf.nn.relu(tf.matmul(feat, d_W_fc0) + d_b_fc0)
            
            d_W_fc1 = weight_variable([100, 2])
            d_b_fc1 = bias_variable([2])
            d_logits = tf.matmul(d_h_fc0, d_W_fc1) + d_b_fc1
            
            self.domain_pred = tf.nn.softmax(d_logits)
            self.domain_loss = tf.nn.softmax_cross_entropy_with_logits(logits=d_logits, labels=self.domain)
Example #15
    def update_params(batch, i_iter):
        dataSize = min(args.min_batch_size, len(batch.state))
        states = torch.from_numpy(np.stack(
            batch.state)[:dataSize, ]).to(dtype).to(device)
        actions = torch.from_numpy(np.stack(
            batch.action)[:dataSize, ]).to(dtype).to(device)
        rewards = torch.from_numpy(np.stack(
            batch.reward)[:dataSize, ]).to(dtype).to(device)
        masks = torch.from_numpy(np.stack(
            batch.mask)[:dataSize, ]).to(dtype).to(device)
        with torch.no_grad():
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)
        """estimate reward"""
        """get advantage estimation from the trajectories"""
        advantages, returns = estimate_advantages(rewards, masks, values,
                                                  args.gamma, args.tau, device)
        """update discriminator"""
        for _ in range(args.discriminator_epochs):
            #dataSize = states.size()[0]
            # expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(device)
            exp_idx = random.sample(range(expert_traj.shape[0]), dataSize)
            expert_state_actions = torch.from_numpy(
                expert_traj[exp_idx, :]).to(dtype).to(device)

            dis_input_real = expert_state_actions
            if len(actions.shape) == 1:
                actions.unsqueeze_(-1)
                dis_input_fake = torch.cat([states, actions], 1)
                actions.squeeze_(-1)
            else:
                dis_input_fake = torch.cat([states, actions], 1)

            if args.EBGAN or args.GMMIL or args.GEOMGAN:
                # tbd, no discriminator learning
                pass
            else:
                g_o = discrim_net(dis_input_fake)
                e_o = discrim_net(dis_input_real)

            optimizer_discrim.zero_grad()
            if args.GEOMGAN:
                optimizer_kernel.zero_grad()

            if args.WGAN:
                if args.LSGAN:
                    pdist = l1dist(dis_input_real,
                                   dis_input_fake).mul(args.lamb)
                    discrim_loss = LeakyReLU(e_o - g_o + pdist).mean()
                else:
                    discrim_loss = torch.mean(e_o) - torch.mean(g_o)
            elif args.EBGAN:
                e_recon = elementwise_loss(e_o, dis_input_real)
                g_recon = elementwise_loss(g_o, dis_input_fake)
                discrim_loss = e_recon
                if (args.margin - g_recon).item() > 0:
                    discrim_loss += (args.margin - g_recon)
            elif args.GMMIL:
                #mmd2_D,K = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                mmd2_D, K = mix_rbf_mmd2(dis_input_real, dis_input_fake,
                                         args.sigma_list)
                #tbd
                #rewards = K[0]+K[1]-2*K[2]
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()  # exp - gen, maximize (gen label negative)
                errD = mmd2_D
                discrim_loss = -errD  # maximize errD

                # prep for generator
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            elif args.GEOMGAN:
                # larger, better, but slower
                noise_num = 100
                mmd2_D, K = mix_imp_mmd2(e_o_enc, g_o_enc, noise_num,
                                         noise_dim, kernel_net, cuda)
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards.detach()
                errD = mmd2_D  #+ args.lambda_rg * one_side_errD
                discrim_loss = -errD  # maximize errD

                # prep for generator
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            else:
                discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                               discrim_criterion(e_o, zeros((e_o.shape[0], 1), device=device))
            if args.GEOMGAN:
                optimizer_kernel.step()
        """perform mini-batch PPO update"""
        optim_iter_num = int(math.ceil(states.shape[0] / args.ppo_batch_size))
        for _ in range(args.generator_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
                fixed_log_probs[perm].clone()

            for i in range(optim_iter_num):
                ind = slice(
                    i * args.ppo_batch_size,
                    min((i + 1) * args.ppo_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg)

        return rewards
    def update_params(batch, i_iter):
        dataSize = min(args.min_batch_size, len(batch.state))
        states = torch.from_numpy(np.stack(
            batch.state)[:dataSize, ]).to(dtype).to(device)
        actions = torch.from_numpy(np.stack(
            batch.action)[:dataSize, ]).to(dtype).to(device)
        rewards = torch.from_numpy(np.stack(
            batch.reward)[:dataSize, ]).to(dtype).to(device)
        masks = torch.from_numpy(np.stack(
            batch.mask)[:dataSize, ]).to(dtype).to(device)
        with torch.no_grad():
            values = value_net(states)
            fixed_log_probs = policy_net.get_log_prob(states, actions)
        """estimate reward"""
        """get advantage estimation from the trajectories"""
        advantages, returns = estimate_advantages(rewards, masks, values,
                                                  args.gamma, args.tau, device)
        """update discriminator"""
        for _ in range(args.discriminator_epochs):
            exp_idx = random.sample(range(expert_traj.shape[0]), dataSize)
            expert_state_actions = torch.from_numpy(
                expert_traj[exp_idx, :]).to(dtype).to(device)

            dis_input_real = expert_state_actions
            if len(actions.shape) == 1:
                actions.unsqueeze_(-1)
                dis_input_fake = torch.cat([states, actions], 1)
                actions.squeeze_(-1)
            else:
                dis_input_fake = torch.cat([states, actions], 1)

            if args.EBGAN or args.GMMIL or args.VAKLIL:
                g_o_enc, g_mu, g_sigma = discrim_net(dis_input_fake,
                                                     mean_mode=False)
                e_o_enc, e_mu, e_sigma = discrim_net(dis_input_real,
                                                     mean_mode=False)
            else:
                g_o = discrim_net(dis_input_fake)
                e_o = discrim_net(dis_input_real)

            optimizer_discrim.zero_grad()
            if args.VAKLIL:
                optimizer_kernel.zero_grad()

            if args.AL:
                if args.LSGAN:
                    pdist = l1dist(dis_input_real,
                                   dis_input_fake).mul(args.lamb)
                    discrim_loss = LeakyReLU(e_o - g_o + pdist).mean()
                else:
                    discrim_loss = torch.mean(e_o) - torch.mean(g_o)
            elif args.EBGAN:
                e_recon = elementwise_loss(e_o, dis_input_real)
                g_recon = elementwise_loss(g_o, dis_input_fake)
                discrim_loss = e_recon
                if (args.margin - g_recon).item() > 0:
                    discrim_loss += (args.margin - g_recon)
            elif args.GMMIL:
                mmd2_D, K = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
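                # K presumably holds the kernel block means (k_xx, k_xy, k_yy),
                # so K[1] - K[2] = k_xy - k_yy in the reward below.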
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy - kyy) = kyy - kxy
                rewards = -rewards.detach()  # exp - gen, maximize (gen label negative)
                errD = mmd2_D
                discrim_loss = -errD  # maximize errD

                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
            elif args.VAKLIL:
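                # Variational bottleneck on the discriminator encoding: the batch-mean
                # KL to a standard normal prior (computed below) is held near a target
                # capacity i_c and weighted by args.beta in the loss.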
                noise_num = 20000
                mmd2_D_net, _, penalty = mix_imp_with_bw_mmd2(
                    e_o_enc, g_o_enc, noise_num, noise_dim, kernel_net, cuda,
                    args.sigma_list)
                mmd2_D_rbf, _ = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                errD = (mmd2_D_net + mmd2_D_rbf) / 2
                i_c = 0.2
                mu = torch.cat((e_mu, g_mu), dim=0)
                sigma = torch.cat((e_sigma, g_sigma), dim=0)
                # 1e-8: small number for numerical stability
                kl = 0.5 * torch.sum(
                    mu**2 + sigma**2 - torch.log(sigma**2 + 1e-8) - 1, dim=1)
                bottleneck_loss = torch.mean(kl) - i_c
                discrim_loss = -errD + (args.beta * bottleneck_loss) + (
                    args.lambda_h * penalty)
            else:
                discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \
                               discrim_criterion(e_o, zeros((e_o.shape[0], 1), device=device))

            discrim_loss.backward()
            optimizer_discrim.step()
            if args.VAKLIL:
                optimizer_kernel.step()

        if args.VAKLIL:
            with torch.no_grad():
                noise_num = 20000
                g_o_enc, _, _ = discrim_net(dis_input_fake)
                e_o_enc, _, _ = discrim_net(dis_input_real)
                _, K_net, _ = mix_imp_with_bw_mmd2(e_o_enc, g_o_enc, noise_num,
                                                   noise_dim, kernel_net, cuda,
                                                   args.sigma_list)
                _, K_rbf = mix_rbf_mmd2(e_o_enc, g_o_enc, args.sigma_list)
                K = [sum(x) / 2 for x in zip(K_net, K_rbf)]
                rewards = K[1] - K[2]  # -(exp - gen): -(kxy-kyy)=kyy-kxy
                rewards = -rewards  #.detach()
                advantages, returns = estimate_advantages(
                    rewards, masks, values, args.gamma, args.tau, device)
        """perform mini-batch PPO update"""
        optim_iter_num = int(math.ceil(states.shape[0] / args.ppo_batch_size))
        for _ in range(args.generator_epochs):
            perm = np.arange(states.shape[0])
            np.random.shuffle(perm)
            perm = LongTensor(perm).to(device)

            states, actions, returns, advantages, fixed_log_probs = \
                states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), \
                fixed_log_probs[perm].clone()

            for i in range(optim_iter_num):
                ind = slice(
                    i * args.ppo_batch_size,
                    min((i + 1) * args.ppo_batch_size, states.shape[0]))
                states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
                    states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]

                ppo_step(policy_net, value_net, optimizer_policy,
                         optimizer_value, 1, states_b, actions_b, returns_b,
                         advantages_b, fixed_log_probs_b, args.clip_epsilon,
                         args.l2_reg)

        return rewards
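All of the snippets above call mix_rbf_mmd2. For reference, here is a minimal sketch of what a mixed-bandwidth RBF MMD^2 estimator typically computes; it is illustrative only (the function name and shapes here are assumptions), and the actual mix_rbf_mmd2 in these repositories may use an unbiased estimator and return extra kernel statistics, as the K tuple used above suggests.

import torch

def rbf_mmd2_sketch(X, Y, sigma_list):
    # Biased MMD^2 estimate under a mixture of RBF kernels with bandwidths sigma_list.
    # X: (m, d) samples from one distribution, Y: (n, d) samples from the other.
    Z = torch.cat([X, Y], dim=0)              # (m + n, d)
    d2 = torch.cdist(Z, Z, p=2) ** 2          # pairwise squared distances
    K = sum(torch.exp(-d2 / (2 * s ** 2)) for s in sigma_list)
    m = X.shape[0]
    k_xx = K[:m, :m].mean()
    k_xy = K[:m, m:].mean()
    k_yy = K[m:, m:].mean()
    return k_xx - 2 * k_xy + k_yy             # >= 0 up to estimation noise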
Example #17
            x_cpu, _ = data
            x = Variable(x_cpu.cuda())
            batch_size = x.size(0)

            f_enc_X_D, f_dec_X_D = netD(x)

            noise = torch.cuda.FloatTensor(batch_size, args.nz, 1,
                                           1).normal_(0, 1)
            noise = Variable(noise, volatile=True)  # total freeze netG
            y = Variable(netG(noise).data)

            f_enc_Y_D, f_dec_Y_D = netD(y)

            # compute biased MMD2 and use ReLU to prevent negative value
            mmd2_D = mix_rbf_mmd2(f_enc_X_D, f_enc_Y_D, sigma_list)
            mmd2_D = F.relu(mmd2_D)

            # compute rank hinge loss
            #print('f_enc_X_D:', f_enc_X_D.size())
            #print('f_enc_Y_D:', f_enc_Y_D.size())
            one_side_errD = one_sided(f_enc_X_D.mean(0) - f_enc_Y_D.mean(0))

            # compute L2-loss of AE
            L2_AE_X_D = util.match(x.view(batch_size, -1), f_dec_X_D, 'L2')
            L2_AE_Y_D = util.match(y.view(batch_size, -1), f_dec_Y_D, 'L2')

            errD = torch.sqrt(mmd2_D) + lambda_rg * one_side_errD \
                - lambda_AE_X * L2_AE_X_D - lambda_AE_Y * L2_AE_Y_D
            errD.backward(mone)
def run_gan(args, loss, batch_size, epsilon=1, niter_sink=1):

    os.system('mkdir {0}_{1}_eps{2}_niter{3}_batch{4}'.format(args.experiment, loss, epsilon, niter_sink, batch_size))

    args.manual_seed = 1126
    np.random.seed(seed=args.manual_seed)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    cudnn.benchmark = True

    # Get data
    trn_dataset = util.get_data(args, train_flag=True)
    trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

    # construct encoder/decoder modules
    hidden_dim = args.nz
    G_decoder = base_module.Decoder(args.image_size, args.nc, k=args.nz, ngf=64)
    D_encoder = base_module.Encoder(args.image_size, args.nc, k=hidden_dim, ndf=64)
    D_decoder = base_module.Decoder(args.image_size, args.nc, k=hidden_dim, ngf=64)

    netG = NetG(G_decoder)
    netD = NetD(D_encoder, D_decoder)
    one_sided = ONE_SIDED()
    #print("netG:", netG)
    #print("netD:", netD)
    #print("oneSide:", one_sided)

    netG.apply(base_module.weights_init)
    netD.apply(base_module.weights_init)
    one_sided.apply(base_module.weights_init)

    # sigma for MMD
    base = 1.0
    sigma_list = [1, 2, 4, 8, 16]
    sigma_list = [sigma / base for sigma in sigma_list]

    # put variable into cuda device
    fixed_noise = torch.cuda.FloatTensor(10**4, args.nz, 1, 1).normal_(0, 1)
    one = torch.cuda.FloatTensor([1])
    mone = one * -1
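    # 'one' / 'mone' act as gradient seeds: errD.backward(mone) ascends the critic
    # objective (netD maximizes errD), while errG.backward(one) descends it for netG.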
    if args.cuda:
        netG.cuda()
        netD.cuda()
        one_sided.cuda()
    fixed_noise = Variable(fixed_noise, requires_grad=False)

    # setup optimizer
    optimizerG = torch.optim.RMSprop(netG.parameters(), lr=args.lr)
    optimizerD = torch.optim.RMSprop(netD.parameters(), lr=args.lr)


    time = timeit.default_timer()
    gen_iterations = 0
    for t in range(args.max_iter):
        data_iter = iter(trn_loader)
        i = 0
        while (i < len(trn_loader)):
            # ---------------------------
            #        Optimize over NetD
            # ---------------------------
            for p in netD.parameters():
                p.requires_grad = True

            
            if i == len(trn_loader):
                break

            # clamp parameters of NetD encoder to a cube
            # do not clamp parameters of NetD decoder!!!
            for p in netD.encoder.parameters():
                p.data.clamp_(-0.01, 0.01)

            data = next(data_iter)
            i += 1
            netD.zero_grad()

            x_cpu, _ = data
            x = Variable(x_cpu.cuda())
            batch_size = x.size(0)

            f_enc_X_D, f_dec_X_D = netD(x)

            noise = torch.cuda.FloatTensor(batch_size, args.nz, 1, 1).normal_(0, 1)
            noise = Variable(noise, volatile=True)  # total freeze netG
            y = Variable(netG(noise).data)

            f_enc_Y_D, f_dec_Y_D = netD(y)

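            # Both Sinkhorn branches below use the debiased divergence
            # 2*S(X, Y) - S(Y, Y) - S(X, X); 'primal' and 'dual' presumably differ
            # only in how the entropic OT problem is evaluated.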
            if loss == 'sinkhorn_primal':

                sink_D = 2*sinkhorn_loss_primal(f_enc_X_D, f_enc_Y_D, epsilon,batch_size,niter_sink) \
                        - sinkhorn_loss_primal(f_enc_Y_D, f_enc_Y_D, epsilon, batch_size,niter_sink) \
                        - sinkhorn_loss_primal(f_enc_X_D, f_enc_X_D, epsilon, batch_size,niter_sink)
                errD = sink_D 

            elif loss == 'sinkhorn_dual' :

                sink_D = 2*sinkhorn_loss_dual(f_enc_X_D, f_enc_Y_D, epsilon,batch_size,niter_sink) \
                        - sinkhorn_loss_dual(f_enc_Y_D, f_enc_Y_D, epsilon, batch_size,niter_sink) \
                        - sinkhorn_loss_dual(f_enc_X_D, f_enc_X_D, epsilon, batch_size,niter_sink)
                errD = sink_D 

            else:
                mmd2_D = mix_rbf_mmd2(f_enc_X_D, f_enc_Y_D, sigma_list)
                mmd2_D = F.relu(mmd2_D)
                errD = mmd2_D 


            errD.backward(mone)
            optimizerD.step()

            # ---------------------------
            #        Optimize over NetG
            # ---------------------------
            for p in netD.parameters():
                p.requires_grad = False


            netG.zero_grad()

            x_cpu, _ = data
            x = Variable(x_cpu.cuda())
            batch_size = x.size(0)

            f_enc_X, f_dec_X = netD(x)

            noise = torch.cuda.FloatTensor(batch_size, args.nz, 1, 1).normal_(0, 1)
            noise = Variable(noise)
            y = netG(noise)

            f_enc_Y, f_dec_Y = netD(y)

          
            ###### Sinkhorn loss #########

            if loss == 'sinkhorn_primal':
            
                sink_G = 2*sinkhorn_loss_primal(f_enc_X, f_enc_Y, epsilon,batch_size,niter_sink) \
                        - sinkhorn_loss_primal(f_enc_Y, f_enc_Y, epsilon, batch_size,niter_sink) \
                        - sinkhorn_loss_primal(f_enc_X, f_enc_X, epsilon, batch_size,niter_sink)
                errG = sink_G 

            elif loss == 'sinkhorn_dual':
            
                sink_G = 2*sinkhorn_loss_dual(f_enc_X, f_enc_Y, epsilon,batch_size,niter_sink) \
                        - sinkhorn_loss_dual(f_enc_Y, f_enc_Y, epsilon, batch_size,niter_sink) \
                        - sinkhorn_loss_dual(f_enc_X, f_enc_X, epsilon, batch_size,niter_sink)
                errG = sink_G 

            else :
                mmd2_G = mix_rbf_mmd2(f_enc_X, f_enc_Y, sigma_list)
                mmd2_G = F.relu(mmd2_G)
                errG = mmd2_G 

            errG.backward(one)
            optimizerG.step()

            gen_iterations += 1

            if gen_iterations % 20 == 1:
                print('generator iterations =' + str(gen_iterations))

            if gen_iterations % 500 == 1:
                y_fixed = netG(fixed_noise)
                y_fixed.data = y_fixed.data.mul(0.5).add(0.5)
                imgfilename = '{0}_{1}_eps{2}_niter{3}_batch{5}/imglist_{4}'.format(
                    args.experiment, loss, epsilon, niter_sink, gen_iterations, batch_size)
                torch.save(y_fixed.data, imgfilename)
                print('images saved! generator iterations =' + str(gen_iterations))

            if gen_iterations > 10**5:
                print('done!')
                break

        if gen_iterations > 10**5:
            print('done!')
            break
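A hedged usage sketch for the function above. The 'args' namespace and its fields (experiment, nz, nc, image_size, cuda, lr, max_iter, ...) come from the surrounding script and are assumed here; 'mmd' is just a placeholder label that falls through to the MMD branch.

# Illustrative invocation only; 'args' must already be the parsed argparse namespace
# that run_gan expects (experiment, nz, nc, image_size, cuda, lr, max_iter, ...).
for loss in ['mmd', 'sinkhorn_primal', 'sinkhorn_dual']:
    run_gan(args, loss, batch_size=64, epsilon=1, niter_sink=1)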