Exemple #1
0
 def build_model(self):
     """Create the DSN network and the Adam optimizer that trains it.

     The network is moved to ``self.device`` and initialised with
     ``xavier_weights_init``; Adam is configured from ``self.lr``,
     betas ``[self.beta1, self.beta2]`` and ``self.weight_decay``.
     """
     self.model = DSN().to(self.device)
     self.model.apply(xavier_weights_init)
     adam_kwargs = dict(lr=self.lr,
                        betas=[self.beta1, self.beta2],
                        weight_decay=self.weight_decay)
     self.optimizer = torch.optim.Adam(self.model.parameters(), **adam_kwargs)
def main(_):
    """Build a DSN and Solver and run the stage selected by ``FLAGS.mode``.

    Creates ``FLAGS.model_save_path`` and ``FLAGS.sample_save_path`` if they
    are missing. Mode ``train_all`` trains and evaluates one DSN per
    experiment (Exp3 ... Exp7), each on its own start/end image range, and
    resets the default TensorFlow graph after every experiment.
    """
    model = DSN(mode=FLAGS.mode, learning_rate=0.0003)
    solver = Solver(model, svhn_dir='svhn', mnist_dir='mnist',
                    model_save_path=FLAGS.model_save_path,
                    sample_save_path=FLAGS.sample_save_path)

    # create directories if not exist
    if not tf.gfile.Exists(FLAGS.model_save_path):
        tf.gfile.MakeDirs(FLAGS.model_save_path)
    if not tf.gfile.Exists(FLAGS.sample_save_path):
        tf.gfile.MakeDirs(FLAGS.sample_save_path)

    if FLAGS.mode == 'pretrain':
        solver.pretrain()
    elif FLAGS.mode == 'train_sampler':
        solver.train_sampler()
    elif FLAGS.mode == 'train_dsn':
        solver.train_dsn()
    elif FLAGS.mode == 'eval_dsn':
        solver.eval_dsn()
    elif FLAGS.mode == 'test':
        solver.test()
    elif FLAGS.mode == 'train_convdeconv':
        solver.train_convdeconv()
    elif FLAGS.mode == 'train_gen_images':
        solver.train_gen_images()
    elif FLAGS.mode == 'end_to_end':
        solver.train_end_to_end()
    elif FLAGS.mode == 'train_all':
        experiments = zip([3200, 4800, 6400, 8000, 9600],
                          [4800, 6400, 8000, 9600, 11200],
                          ['Exp3', 'Exp4', 'Exp5', 'Exp6', 'Exp7'])
        for start, end, name in experiments:
            # Each experiment trains on its own start/end image range.
            model = DSN(mode='train_dsn', learning_rate=0.0001)
            solver = Solver(model, svhn_dir='svhn', mnist_dir='mnist',
                            model_save_path=FLAGS.model_save_path,
                            sample_save_path=FLAGS.sample_save_path,
                            start_img=start, end_img=end)
            solver.train_dsn()

            model = DSN(mode='eval_dsn')
            solver = Solver(model, svhn_dir='svhn', mnist_dir='mnist',
                            model_save_path=FLAGS.model_save_path,
                            sample_save_path=FLAGS.sample_save_path)
            solver.eval_dsn(name=name)

            # Start the next experiment from a clean graph.
            tf.reset_default_graph()
    else:
        print('Unrecognized mode.')
Exemple #3
0
def main(_):
    """Build a DSN/Solver on GPU ``FLAGS.gpu`` and run ``FLAGS.mode``.

    ``FLAGS.splits`` is expected to look like ``<src>2<trg>``; the parts
    before and after the first ``'2'`` name the source and target data.
    """
    with tf.device('/gpu:' + FLAGS.gpu):
        model = DSN(mode=FLAGS.mode, learning_rate=0.001)
        split_parts = FLAGS.splits.split('2')
        src_split = split_parts[0]
        trg_split = split_parts[1]
        solver = Solver(model,
                        batch_size=64,
                        src_dir=src_split,
                        trg_dir=trg_split)

        # Mode name -> Solver method name (looked up lazily via getattr).
        mode_to_method = {
            'pretrain': 'pretrain',
            'train_sampler': 'train_sampler',
            'train_dsn': 'train_dsn',
            'eval_dsn': 'eval_dsn',
            'test': 'test',
            'features': 'features',
            'test_ensemble': 'test_ensemble',
            'train_adda_shared': 'train_adda_shared',
            'train_adda': 'train_adda_shared',
        }
        method_name = mode_to_method.get(FLAGS.mode)
        if method_name is None:
            print('Unrecognized mode.')
        else:
            getattr(solver, method_name)()
Exemple #4
0
def main(_):
    """Build a DSN/Solver for the ``<src>2<trg>`` splits and run ``FLAGS.mode``."""
    model = DSN(mode=FLAGS.mode, learning_rate=0.0001)
    split_parts = FLAGS.splits.split('2')
    src_split = split_parts[0]
    trg_split = split_parts[1]
    solver = Solver(model,
                    batch_size=128,
                    src_dir=src_split,
                    trg_dir=trg_split)

    # Mode name -> Solver method name (looked up lazily via getattr).
    mode_to_method = {
        'pretrain': 'pretrain',
        'train_sampler': 'train_sampler',
        'train_dsn': 'train_dsn',
        'eval_dsn': 'eval_dsn',
        'test': 'test',
        'test_ensemble': 'test_ensemble',
    }
    method_name = mode_to_method.get(FLAGS.mode)
    if method_name is None:
        print('Unrecognized mode.')
    else:
        getattr(solver, method_name)()
Exemple #5
0
def main(_):
    """Load a DSN in ``eval_dsn`` mode and run the solver's t-SNE check."""
    split_parts = FLAGS.splits.split('2')
    src_split = split_parts[0]
    trg_split = split_parts[1]

    from model import DSN
    solver = Solver(DSN(mode='eval_dsn', learning_rate=0.0003),
                    src_dir=src_split,
                    trg_dir=trg_split)
    solver.check_TSNE()
Exemple #6
0
def main(config):
    """Classify every ``*.png`` in ``config.img_dir`` and write a CSV.

    Loads DSN weights from ``config.ckp_path`` (key ``'state_dict'``), runs
    each image (converted to RGB and normalised to mean/std 0.5) through the
    model with ``mode=config.mode``, and writes ``image_name,label`` rows in
    sorted filename order to ``config.save_path``. Requires CUDA.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    model = DSN().cuda()

    state = torch.load(config.ckp_path)
    model.load_state_dict(state['state_dict'])

    filenames = sorted(glob.glob(os.path.join(config.img_dir, '*.png')))

    out_filename = config.save_path
    # dirname() is '' for a bare filename and os.makedirs('') raises, so
    # only create the parent directory when there is one.
    out_dir = os.path.dirname(out_filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    model.eval()
    with open(out_filename, 'w') as out_file:
        out_file.write('image_name,label\n')
        with torch.no_grad():
            for fn in filenames:
                # Close each image file promptly instead of leaking handles.
                with Image.open(fn) as img:
                    data = transform(img.convert('RGB'))
                data = torch.unsqueeze(data, 0).cuda()
                output, _, _, _, _ = model(data, mode=config.mode)
                # Predicted class = index of the largest output score.
                pred = output.argmax(dim=1)
                out_file.write(
                    os.path.basename(fn) + ',' + str(pred.item()) + '\n')
Exemple #7
0
def main(_):
    """Build a DSN/Solver (batch size 32) and run the ``FLAGS.mode`` stage."""
    solver = Solver(DSN(mode=FLAGS.mode, learning_rate=0.00001), batch_size=32)

    # Mode name -> Solver method name (looked up lazily via getattr).
    mode_to_method = {
        'pretrain': 'pretrain',
        'train_sampler': 'train_sampler',
        'train_dsn': 'train_dsn',
        'eval_dsn': 'eval_dsn',
        'test': 'test',
        'features': 'features',
        'test_ensemble': 'test_ensemble',
        'train_adda_shared': 'train_adda_shared',
        'train_adda': 'train_adda_shared',
    }
    method_name = mode_to_method.get(FLAGS.mode)
    if method_name is None:
        print('Unrecognized mode.')
    else:
        getattr(solver, method_name)()
Exemple #8
0
def main(_):
    """Pin GPU 3, build a DSN/Solver and run the ``FLAGS.mode`` stage.

    Seeds numpy's RNG with 291, creates the model/sample save directories
    if missing, then dispatches on ``FLAGS.mode``. Mode ``train_all`` trains
    and evaluates one DSN per experiment (Exp3 ... Exp7), each on its own
    start/end image range, resetting the TensorFlow graph between runs.
    """
    npr.seed(291)

    GPU_ID = 3

    os.environ[
        "CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152 on stackoverflow
    os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)

    model = DSN(mode=FLAGS.mode, learning_rate=0.0003)
    solver = Solver(model,
                    svhn_dir='/data/svhn',
                    syn_dir='/data/syn',
                    model_save_path=FLAGS.model_save_path,
                    sample_save_path=FLAGS.sample_save_path)

    # create directories if not exist
    if not tf.gfile.Exists(FLAGS.model_save_path):
        tf.gfile.MakeDirs(FLAGS.model_save_path)
    if not tf.gfile.Exists(FLAGS.sample_save_path):
        tf.gfile.MakeDirs(FLAGS.sample_save_path)

    if FLAGS.mode == 'pretrain':
        solver.pretrain()
    elif FLAGS.mode == 'train_sampler':
        solver.train_sampler()
    elif FLAGS.mode == 'train_dsn':
        solver.train_dsn()
    elif FLAGS.mode == 'eval_dsn':
        solver.eval_dsn()
    elif FLAGS.mode == 'test':
        solver.test()
    elif FLAGS.mode == 'train_convdeconv':
        solver.train_convdeconv()
    elif FLAGS.mode == 'train_gen_images':
        solver.train_gen_images()
    elif FLAGS.mode == 'train_all':
        experiments = zip([3200, 4800, 6400, 8000, 9600],
                          [4800, 6400, 8000, 9600, 11200],
                          ['Exp3', 'Exp4', 'Exp5', 'Exp6', 'Exp7'])
        for start, end, name in experiments:
            # Each experiment trains on its own start/end image range.
            model = DSN(mode='train_dsn', learning_rate=0.0001)
            solver = Solver(model,
                            svhn_dir='svhn',
                            mnist_dir='mnist',
                            model_save_path=FLAGS.model_save_path,
                            sample_save_path=FLAGS.sample_save_path,
                            start_img=start,
                            end_img=end)
            solver.train_dsn()

            model = DSN(mode='eval_dsn')
            solver = Solver(model,
                            svhn_dir='svhn',
                            mnist_dir='mnist',
                            model_save_path=FLAGS.model_save_path,
                            sample_save_path=FLAGS.sample_save_path)
            solver.eval_dsn(name=name)

            # Start the next experiment from a clean graph.
            tf.reset_default_graph()
    else:
        print('Unrecognized mode.')
Exemple #9
0
                print ('Step: [%d/%d] test acc [%.3f]' \
                    %(t+1, self.pretrain_iter, test_trg_acc))

                print confusion_matrix(test_labels, trg_pred)

                acc.append(test_trg_acc)
                with open('test_acc.pkl', 'wb') as f:
                    cPickle.dump(acc, f, cPickle.HIGHEST_PROTOCOL)

                #~ gen_acc = sess.run(fetches=[model.trg_accuracy, model.trg_pred],
                #~ feed_dict={model.src_images: gen_images,
                #~ model.src_labels: gen_labels,
                #~ model.trg_images: gen_images,
                #~ model.trg_labels: gen_labels})

                #~ print ('Step: [%d/%d] src train acc [%.2f]  src test acc [%.2f] trg test acc [%.2f]' \
                #~ %(t+1, self.pretrain_iter, gen_acc))

                time.sleep(10.1)


if __name__ == '__main__':
    # Script entry: build the DSN in evaluation mode and run the t-SNE check.
    # Alternative analysis (disabled): solver.find_closest_samples()
    from model import DSN

    model = DSN(learning_rate=0.0003, mode='eval_dsn')
    solver = Solver(model)
    solver.check_TSNE()
                    })
                src_acc = sess.run(model.src_accuracy,
                                   feed_dict={
                                       model.src_images:
                                       src_images[:20000],
                                       model.src_labels:
                                       src_labels[:20000],
                                       model.trg_images:
                                       trg_test_images[trg_rand_idxs],
                                       model.trg_labels:
                                       trg_test_labels[trg_rand_idxs]
                                   })

                print ('Step: [%d/%d] src train acc [%.3f]  src test acc [%.3f] trg test acc [%.3f]' \
                    %(t+1, self.pretrain_iter, src_acc, test_src_acc, test_trg_acc))

                print confusion_matrix(trg_test_labels[trg_rand_idxs],
                                       trg_pred)

                acc.append(test_trg_acc)
                with open(self.protocol + '_' + algorithm + '.pkl', 'wb') as f:
                    cPickle.dump(acc, f, cPickle.HIGHEST_PROTOCOL)


if __name__ == '__main__':
    # Script entry: evaluate the DSN and produce the t-SNE visualisation.
    from model import DSN

    model = DSN(mode='eval_dsn')
    solver = Solver(model)
    solver.check_TSNE()
Exemple #11
0
class Solver(object):
    """Trainer for a DSN model: training loop, validation and checkpoints.

    Args:
        src_trainset_loader: DataLoader of labelled source batches
            ``(images, class_labels)``.
        src_valset_loader: DataLoader over the labelled source validation set.
        tgt_trainset_loader: DataLoader of target batches ``(images, _)``;
            labels are ignored. Only used when ``config.src_only`` is false.
        config: namespace-like object with the hyper-parameters and paths
            read in ``__init__`` (``num_iters``, ``lr``, ``ckp_dir``, ...).
    """

    def __init__(self,
                 src_trainset_loader,
                 src_valset_loader,
                 tgt_trainset_loader=None,
                 config=None):
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.src_trainset_loader = src_trainset_loader
        self.src_valset_loader = src_valset_loader
        self.tgt_trainset_loader = tgt_trainset_loader
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.step_decay_weight = config.step_decay_weight
        self.active_domain_loss_step = config.active_domain_loss_step
        self.resume_iters = config.resume_iters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.weight_decay = config.weight_decay
        self.alpha_weight = config.alpha_weight
        self.beta_weight = config.beta_weight
        self.gamma_weight = config.gamma_weight
        self.src_only = config.src_only
        self.exp_name = config.name
        os.makedirs(config.ckp_dir, exist_ok=True)
        self.ckp_dir = os.path.join(config.ckp_dir, self.exp_name)
        os.makedirs(self.ckp_dir, exist_ok=True)
        self.example_dir = os.path.join(self.ckp_dir, "output")
        os.makedirs(self.example_dir, exist_ok=True)
        self.log_interval = config.log_interval
        self.save_interval = config.save_interval
        self.use_wandb = config.use_wandb

        self.build_model()

    def build_model(self):
        """Create the DSN on ``self.device`` (Xavier-initialised) and Adam."""
        self.model = DSN().to(self.device)
        self.model.apply(xavier_weights_init)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=self.lr,
                                          betas=[self.beta1, self.beta2],
                                          weight_decay=self.weight_decay)

    def save_checkpoint(self, step):
        """Save model and optimizer state to ``<ckp_dir>/<step + 1>-dsn.pth``."""
        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }
        new_checkpoint_path = os.path.join(self.ckp_dir,
                                           '{}-dsn.pth'.format(step + 1))
        torch.save(state, new_checkpoint_path)
        print('model saved to %s' % new_checkpoint_path)

    def load_checkpoint(self, resume_iters):
        """Restore model and optimizer from ``<ckp_dir>/<resume_iters>-dsn.pth``."""
        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        new_checkpoint_path = os.path.join(self.ckp_dir,
                                           '{}-dsn.pth'.format(resume_iters))
        # map_location lets a checkpoint written on GPU be restored on the
        # CPU fallback device selected in __init__.
        state = torch.load(new_checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])
        print('model loaded from %s' % new_checkpoint_path)

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.optimizer.zero_grad()

    def _domain_coefficient(self, iteration):
        """Return the ``p`` passed to the model once domain losses are active.

        Ramps from 0 towards 1 as ``2 / (1 + exp(-10 q)) - 1``, where ``q`` is
        the fraction of post-warm-up iterations completed. (Presumably the
        DSN's gradient-reversal scale — confirm against the model.)
        """
        q = float(iteration - self.active_domain_loss_step) / (
            self.num_iters - self.active_domain_loss_step)
        return 2. / (1. + np.exp(-10 * q)) - 1

    @staticmethod
    def _next_batch(data_iter, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when exhausted."""
        try:
            return next(data_iter), data_iter
        except StopIteration:
            data_iter = iter(loader)
            return next(data_iter), data_iter

    def _save_reconstructions(self, data, mode, iteration, suffix):
        """Save ``data`` stacked with its all/share/private reconstructions.

        Written to ``<example_dir>/<iteration + 1>_<suffix>.png``.
        """
        # Visualisation only: no autograd graph is needed.
        with torch.no_grad():
            recs = [self.model(data, mode=mode, rec_scheme=scheme)[4]
                    for scheme in ('all', 'share', 'private')]
        vutils.save_image(torch.cat([data] + recs),
                          os.path.join(self.example_dir,
                                       '%d_%s.png' % (iteration + 1, suffix)),
                          nrow=16,
                          normalize=True)

    def train(self):
        """Run the training loop until ``num_iters`` iterations are done.

        Each iteration sums, for a source batch, the classification loss plus
        ``alpha_weight`` * reconstruction and ``beta_weight`` * difference
        losses, and (after ``active_domain_loss_step``) ``gamma_weight`` *
        domain loss; unless ``src_only`` is set, the same reconstruction,
        difference and domain terms are added for a target batch. Losses are
        logged every ``log_interval`` iterations; every ``save_interval``
        iterations the model is validated, checkpointed and reconstruction
        previews are written to ``example_dir``.
        """
        task_criterion = nn.CrossEntropyLoss()
        recon_criterion = SIMSE()
        diff_criterion = DiffLoss()
        sim_criterion = nn.CrossEntropyLoss()

        # Fixed batches used for the reconstruction previews.
        fix_src_data, _ = next(iter(self.src_valset_loader))
        fix_src_data = fix_src_data.to(self.device)
        fix_tgt_data = None
        if not self.src_only:
            fix_tgt_data, _ = next(iter(self.tgt_trainset_loader))
            fix_tgt_data = fix_tgt_data.to(self.device)

        best_acc = 0
        best_loss = 1e15
        iteration = 0
        if self.resume_iters:
            print("resuming step %d ..." % self.resume_iters)
            iteration = self.resume_iters
            self.load_checkpoint(self.resume_iters)
            best_loss, best_acc = self.eval()

        # Create iterators up front; _next_batch restarts them on
        # exhaustion and lets any other error propagate.
        src_data_iter = iter(self.src_trainset_loader)
        tgt_data_iter = (None if self.src_only else
                         iter(self.tgt_trainset_loader))

        while iteration < self.num_iters:
            self.model.train()
            self.optimizer.zero_grad()
            loss = 0.0

            # Domain losses (and the ramped ``p``) start after warm-up.
            domain_active = iteration > self.active_domain_loss_step
            model_kwargs = ({'p': self._domain_coefficient(iteration)}
                            if domain_active else {})

            if self.src_only:
                tgt_domain_loss = torch.zeros(1)
                tgt_recon_loss = torch.zeros(1)
                tgt_diff_loss = torch.zeros(1)
            else:
                (tgt_data, _), tgt_data_iter = self._next_batch(
                    tgt_data_iter, self.tgt_trainset_loader)
                tgt_data = tgt_data.to(self.device)

                (_, tgt_domain_output, tgt_private_code, tgt_shared_code,
                 tgt_recon) = self.model(tgt_data, mode='target',
                                         **model_kwargs)
                if domain_active:
                    # Target samples carry domain label 1.
                    tgt_domain_label = torch.ones((len(tgt_data), ),
                                                  dtype=torch.long,
                                                  device=self.device)
                    tgt_domain_loss = sim_criterion(tgt_domain_output,
                                                    tgt_domain_label)
                    loss += self.gamma_weight * tgt_domain_loss
                else:
                    tgt_domain_loss = torch.zeros(1)

                tgt_recon_loss = recon_criterion(tgt_recon, tgt_data)
                tgt_diff_loss = diff_criterion(tgt_private_code,
                                               tgt_shared_code)

                loss += (self.alpha_weight * tgt_recon_loss +
                         self.beta_weight * tgt_diff_loss)

            (src_data, src_class_label), src_data_iter = self._next_batch(
                src_data_iter, self.src_trainset_loader)
            src_data = src_data.to(self.device)
            src_class_label = src_class_label.to(self.device)

            (src_class_output, src_domain_output, src_private_code,
             src_shared_code, src_recon) = self.model(src_data, mode='source',
                                                      **model_kwargs)
            if domain_active:
                # Source samples carry domain label 0.
                src_domain_label = torch.zeros((src_data.size(0), ),
                                               dtype=torch.long,
                                               device=self.device)
                src_domain_loss = sim_criterion(src_domain_output,
                                                src_domain_label)
                loss += self.gamma_weight * src_domain_loss
            else:
                src_domain_loss = torch.zeros(1)

            src_class_loss = task_criterion(src_class_output, src_class_label)
            src_recon_loss = recon_criterion(src_recon, src_data)
            src_diff_loss = diff_criterion(src_private_code, src_shared_code)

            loss += (src_class_loss + self.alpha_weight * src_recon_loss +
                     self.beta_weight * src_diff_loss)

            loss.backward()
            self.optimizer = exp_lr_scheduler(
                optimizer=self.optimizer,
                step=iteration,
                init_lr=self.lr,
                lr_decay_step=self.num_iters_decay,
                step_decay_weight=self.step_decay_weight)
            self.optimizer.step()

            # Output training stats
            if (iteration + 1) % self.log_interval == 0:
                print(
                    'Iteration: {:5d} / {:d} loss: {:.6f} loss_src_class: {:.6f} loss_src_domain: {:.6f} loss_src_recon: {:.6f} loss_src_diff: {:.6f} loss_tgt_domain: {:.6f} loss_tgt_recon: {:.6f} loss_tgt_diff: {:.6f}'
                    .format(iteration + 1, self.num_iters, loss.item(),
                            src_class_loss.item(), src_domain_loss.item(),
                            src_recon_loss.item(), src_diff_loss.item(),
                            tgt_domain_loss.item(), tgt_recon_loss.item(),
                            tgt_diff_loss.item()))

                if self.use_wandb:
                    import wandb
                    wandb.log(
                        {
                            "loss": loss.item(),
                            "loss_src_class": src_class_loss.item(),
                            "loss_src_domain": src_domain_loss.item(),
                            "loss_src_recon": src_recon_loss.item(),
                            "loss_src_diff": src_diff_loss.item(),
                            "loss_tgt_domain": tgt_domain_loss.item(),
                            "loss_tgt_recon": tgt_recon_loss.item(),
                            "loss_tgt_diff": tgt_diff_loss.item()
                        },
                        step=iteration + 1)

            # Validate, save model checkpoints and reconstruction previews
            if (iteration + 1) % self.save_interval == 0 and iteration > 0:
                val_loss, val_acc = self.eval()
                if self.use_wandb:
                    import wandb
                    wandb.log({
                        "val_loss": val_loss,
                        "val_acc": val_acc
                    },
                              step=iteration + 1,
                              commit=False)

                self.save_checkpoint(iteration)

                if (val_acc > best_acc):
                    print('val acc: %.2f > %.2f' % (val_acc, best_acc))
                    best_acc = val_acc
                if (val_loss < best_loss):
                    print('val loss: %.4f < %.4f' % (val_loss, best_loss))
                    best_loss = val_loss

                self._save_reconstructions(fix_src_data, 'source', iteration,
                                           'src')
                if not self.src_only:
                    self._save_reconstructions(fix_tgt_data, 'target',
                                               iteration, 'tgt')

            iteration += 1

    def eval(self):
        """Evaluate classification on the source validation set.

        Leaves the model in eval mode.

        Returns:
            tuple: ``(val_loss, val_acc)`` — mean per-sample cross-entropy and
            accuracy in percent.
        """
        # Sum per-sample losses so a short final batch is not over-weighted
        # when averaging over the dataset.
        criterion = nn.CrossEntropyLoss(reduction='sum')
        self.model.eval()
        val_loss = 0.0
        correct = 0.0
        num_samples = len(self.src_valset_loader.dataset)
        with torch.no_grad():
            for data, label in self.src_valset_loader:
                data, label = data.to(self.device), label.to(self.device)
                output, _, _, _, _ = self.model(data,
                                                mode='source',
                                                rec_scheme='all')
                val_loss += criterion(output, label).item()
                # argmax on raw scores; exp() is monotonic and only risks
                # overflow ties.
                pred = output.argmax(dim=1)
                correct += pred.eq(label.view_as(pred)).sum().item()

        val_loss /= num_samples
        val_acc = 100. * correct / num_samples
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.
              format(val_loss, correct, num_samples, val_acc))

        return val_loss, val_acc
Exemple #12
0
                                                   shuffle=True,
                                                   num_workers=8)

# Target domain: MNIST training split.
dataset_target = datasets.MNIST(
    root=target_dataset,
    train=True,
    transform=img_tgt_transform,
)

datasetloader_target = torch.utils.data.DataLoader(dataset=dataset_target,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=8)

# load models
my_net = DSN(n_class=10, code_size=3072, channels=n_channels)
my_net.apply(weights_init)

loss_class = nn.CrossEntropyLoss()
loss_rec = func.mean_pairwise_square_loss()
loss_diff = func.difference_loss()
if cuda:
    # Move the model to the GPU before building the optimizer, as the
    # torch.optim documentation recommends, so the optimizer is created over
    # the parameters that will actually be trained.
    my_net = my_net.cuda()
    loss_class = loss_class.cuda()
    loss_rec = loss_rec.cuda()
    loss_diff = loss_diff.cuda()

# setup optimizer
optimizer = optim.Adam(my_net.parameters(), lr=lr, weight_decay=weight_decay)
#loss coefficients