def main():
    """One-shot fine-tuning of a TET-GAN checkpoint on a single style.

    Options come from ``FinetuneOptions``. The model is built, initialised from
    ``opts.load_model_name`` and trained for ``opts.outer_iter`` rounds; each
    round draws a fresh list of batch files via ``load_oneshot_batchfnames``
    and runs ``opts.epoch`` epochs over it. The final weights are written to
    ``opts.save_model_name``.

    If ``opts.supervise == 1``, the ``x`` returned by ``prepare_batch`` is used
    directly. Otherwise no ground-truth ``x`` is used: each epoch first updates
    the style autoencoder on ``y_real`` and then trains with an auxiliary
    ``x`` produced by ``tetGAN.desty_forward(y_real[0])``.
    """
    # parse options
    parser = FinetuneOptions()
    opts = parser.parse()

    # data loader
    print('--- load parameter ---')
    outer_iter = opts.outer_iter
    epochs = opts.epoch
    batchsize = opts.batchsize
    datasize = opts.datasize
    stylename = opts.style_name

    # model
    print('--- create model ---')
    use_gpu = opts.gpu != 0
    tetGAN = TETGAN(gpu=use_gpu)
    if use_gpu:
        tetGAN.cuda()
    # Deserialize onto the CPU; load_state_dict then copies every tensor onto
    # the device of the matching parameter. Without map_location, a checkpoint
    # saved from a CUDA run cannot be loaded on a CPU-only machine.
    tetGAN.load_state_dict(torch.load(opts.load_model_name, map_location='cpu'))
    tetGAN.train()

    print('--- training ---')
    if opts.supervise == 1:
        # supervised one shot learning
        for i in range(outer_iter):
            fnames = load_oneshot_batchfnames(stylename, batchsize, datasize)
            for epoch in range(epochs):
                for fname in fnames:
                    x, y_real, y = prepare_batch(fname, 3, 0, 0, 0, opts.gpu)
                    losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0], None, 3, 0)
                # Losses reported are those of the last batch of the epoch.
                print('Iter[%d/%d], Epoch [%d/%d]' % (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))
    else:
        # unsupervised one shot learning
        for i in range(outer_iter):
            fnames = load_oneshot_batchfnames(stylename, batchsize, datasize)
            for epoch in range(epochs):
                # Pass 1: update the style autoencoder on the stylized images
                # only (no ground truth x provided).
                for fname in fnames:
                    _, y_real, _ = prepare_batch(fname, 3, 0, 0, 0, opts.gpu)
                    Lsrec = tetGAN.update_style_autoencoder(y_real[0])
                # Pass 2: train the full model with an auxiliary x recovered
                # from y_real (no ground truth x provided).
                for fname in fnames:
                    _, y_real, y = prepare_batch(fname, 3, 0, 0, 0, opts.gpu)
                    # NOTE(review): the network is still in train mode here;
                    # if it contains BatchNorm layers their running stats are
                    # updated by this no_grad forward. Confirm this is intended.
                    with torch.no_grad():
                        x_auxiliary = tetGAN.desty_forward(y_real[0])
                    losses = tetGAN.one_pass(x_auxiliary, None, y[0], None, y_real[0], None, 3, 0)
                print('Iter[%d/%d], Epoch [%d/%d]' % (i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f, Lsrec: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4], Lsrec))

    print('--- save ---')
    torch.save(tetGAN.state_dict(), opts.save_model_name)
def main():
    """Train TET-GAN from freshly initialised weights.

    Options come from ``TrainOptions``; every entry of ``opts.train_path`` is
    listed as a style. With ``opts.progressive == 1`` training runs three
    stages (levels 1, 2 and 3 -- 64*64, 128*128 and 256*256 per the original
    comment). Otherwise it trains directly at level 3. Every stage runs
    ``opts.outer_iter`` rounds, and each round loads a fresh list of batch
    files and trains ``opts.epoch`` epochs over it. The final weights are
    written to ``opts.save_model_name``.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('--- load parameter ---')
    outer_iter = opts.outer_iter
    # Number of rounds over which the level-1 jitter ramps up to 1.0 and the
    # level-2/3 weight w decays to 0.0; at least 1 so it is never a zero divisor.
    fade_iter = max(1.0, float(outer_iter / 2))
    epochs = opts.epoch
    batchsize = opts.batchsize
    datasize = opts.datasize
    datarange = opts.datarange
    augementratio = opts.augementratio
    centercropratio = opts.centercropratio

    # model
    print('--- create model ---')
    tetGAN = TETGAN(gpu=(opts.gpu != 0))
    if opts.gpu != 0:
        tetGAN.cuda()
    tetGAN.init_networks(weights_init)
    tetGAN.train()

    print('--- training ---')
    stylenames = os.listdir(opts.train_path)
    print('List of %d styles:' % (len(stylenames)), *stylenames, sep=' ')

    def run_stage(prefix, batch, size, step):
        # Shared stage loop: outer_iter rounds x epochs epochs; step(i, fname)
        # trains on one batch file and returns the losses of that batch.
        for i in range(outer_iter):
            fnames = load_trainset_batchfnames(opts.train_path, batch, datarange, size)
            for epoch in range(epochs):
                for fname in fnames:
                    losses = step(i, fname)
                # Losses reported are those of the last batch of the epoch.
                print('%sIter[%d/%d], Epoch [%d/%d]' % (prefix, i + 1, outer_iter, epoch + 1, epochs))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))

    if opts.progressive == 1:
        # progressive training. From level1 64*64, to level2 128*128, to level3 256*256

        def level1_step(i, fname):
            # Jitter passed to prepare_batch grows linearly from 0 to 1.0.
            jitter = min(1.0, i / fade_iter)
            x, y_real, y = prepare_batch(fname, 1, jitter, centercropratio, augementratio, opts.gpu)
            return tetGAN.one_pass(x[0], None, y[0], None, y_real[0], None, 1, None)

        def blended_step(level):
            # Levels 2 and 3 also feed the lower-level tensors x[1]/y[1]/y_real[1];
            # the last one_pass argument w decays linearly from 1.0 to 0.0.
            def step(i, fname):
                w = max(0.0, 1 - i / fade_iter)
                x, y_real, y = prepare_batch(fname, level, 1, centercropratio, augementratio, opts.gpu)
                return tetGAN.one_pass(x[0], x[1], y[0], y[1], y_real[0], y_real[1], level, w)
            return step

        # level 1
        run_stage('Level1, ', batchsize * 4, datasize * 2, level1_step)
        # level 2
        run_stage('Level2, ', batchsize * 2, datasize * 2, blended_step(2))
        # level 3
        run_stage('Level3, ', batchsize, datasize, blended_step(3))
    else:
        # directly train on level3 256*256
        def direct_step(_, fname):
            x, y_real, y = prepare_batch(fname, 3, 1, centercropratio, augementratio, opts.gpu)
            return tetGAN.one_pass(x[0], None, y[0], None, y_real[0], None, 3, 0)

        run_stage('', batchsize, datasize, direct_step)

    print('--- save ---')
    torch.save(tetGAN.state_dict(), opts.save_model_name)
def main():
    """Train TET-GAN at level 1, or few-shot tune it on a gray-texture class.

    Options come from ``TrainOptions``. If ``opts.dataset_class`` contains
    ``'base_gray_texture'`` or ``'skeleton_gray_texture'`` (a "texture class"),
    several things change:

    * Weights are initialised from ``opts.model``.
    * Training uses ``<train_path>/<dataset_class>/style``.
    * ``few_size``, ``batchsize`` and ``epochs`` are overridden per class.
    * At 20 evenly spaced epochs a checkpoint is saved and every image in
      ``.../val`` is stylised with a randomly chosen style image; the results
      go to ``<result_dir>/<dataset_class>``.

    Otherwise training uses ``.../train`` with the command-line settings, and a
    checkpoint is saved every 2 epochs. Checkpoints are written under the
    relative ``save/`` directory.
    """
    # parse options
    parser = TrainOptions()
    opts = parser.parse()

    # data loader
    print('--- load parameter ---')
    epochs = opts.epoch
    batchsize = opts.batchsize
    augementratio = opts.augementratio
    centercropratio = opts.centercropratio

    # model
    print('--- create model ---')
    use_gpu = opts.gpu != 0
    tetGAN = TETGAN(gpu=use_gpu)
    if use_gpu:
        tetGAN.cuda()
    tetGAN.init_networks(weights_init)
    num_params = sum(param.numel() for param in tetGAN.parameters())
    print('Total number of parameters in TET-GAN: %.3f M' % (num_params / 1e6))

    print('--- training ---')
    is_base = 'base_gray_texture' in opts.dataset_class
    is_skeleton = 'skeleton_gray_texture' in opts.dataset_class
    texture_class = is_base or is_skeleton
    if texture_class:
        # Deserialize onto the CPU; load_state_dict copies each tensor onto the
        # matching parameter's device, so CUDA-saved checkpoints also load on
        # CPU-only machines.
        tetGAN.load_state_dict(torch.load(opts.model, map_location='cpu'))
        dataset_path = os.path.join(opts.train_path, opts.dataset_class, 'style')
        val_dataset_path = os.path.join(opts.train_path, opts.dataset_class, 'val')
        # Per-class settings; these override the command-line batchsize/epoch.
        if is_base:
            few_size, batchsize, epochs = 6, 2, 1500
        else:  # is_skeleton
            few_size, batchsize, epochs = 30, 10, 300
        fnames = load_trainset_batchfnames_dualnet(dataset_path, batchsize, few_size=few_size)
        val_fnames = sorted(os.listdir(val_dataset_path))
        # NOTE(review): os.listdir() is sliced *before* sorting, so which
        # few_size files end up here depends on filesystem order. Confirm it
        # matches the subset picked by load_trainset_batchfnames_dualnet.
        style_fnames = sorted(os.listdir(dataset_path)[:few_size])
    else:
        dataset_path = os.path.join(opts.train_path, opts.dataset_class, 'train')
        fnames = load_trainset_batchfnames_dualnet(dataset_path, batchsize)
    tetGAN.train()

    train_files = os.listdir(dataset_path)
    print('List of %d styles:' % (len(train_files)))

    result_dir = os.path.join(opts.result_dir, opts.dataset_class)
    # makedirs also creates a missing opts.result_dir (os.mkdir would raise).
    os.makedirs(result_dir, exist_ok=True)
    # torch.save below writes into 'save/' and raises if it does not exist.
    os.makedirs('save', exist_ok=True)

    # Texture classes checkpoint/validate 20 times per run. Integer division
    # gives the same interval as the former float one for the fixed epoch
    # counts (1500 -> 75, 300 -> 15) and can never be zero or fractional.
    save_every = max(1, epochs // 20)

    for epoch in range(epochs):
        for idx, fname in enumerate(fnames):
            x, y_real, y = prepare_batch(fname, 1, 1, centercropratio, augementratio, opts.gpu)
            losses = tetGAN.one_pass(x[0], None, y[0], None, y_real[0], None, 1, 0)
            if (idx + 1) % 100 == 0:
                print('Epoch [%d/%d], Iter [%d/%d]' % (epoch + 1, epochs, idx + 1, len(fnames)))
                print(
                    'Lrec: %.3f, Ldadv: %.3f, Ldesty: %.3f, Lsadv: %.3f, Lsty: %.3f'
                    % (losses[0], losses[1], losses[2], losses[3], losses[4]))

        if texture_class:
            if (epoch + 1) % save_every != 0:
                continue
            outname = 'save/' + 'val_epoch' + str(
                epoch + 1) + '_' + opts.dataset_class + '_' + opts.save_model_name
            print('--- save model Epoch [%d/%d] ---' % (epoch + 1, epochs))
            torch.save(tetGAN.state_dict(), outname)

            print('--- validating model [%d/%d] ---' % (epoch + 1, epochs))
            for val_fname in val_fnames:
                v_fname = os.path.join(val_dataset_path, val_fname)
                # Pair each validation image with a randomly chosen style image.
                random.shuffle(style_fnames)
                s_fname = os.path.join(dataset_path, style_fnames[0])
                # NOTE(review): the network is still in train mode during this
                # validation; if it has BatchNorm layers their running stats
                # are updated by these forwards. Confirm this is intended.
                with torch.no_grad():
                    val_content = load_image_dualnet(v_fname, load_type=1)
                    val_sty = load_image_dualnet(s_fname, load_type=0)
                    if use_gpu:
                        val_content = val_content.cuda()
                        val_sty = val_sty.cuda()
                    result = tetGAN(val_content, val_sty)
                    if use_gpu:
                        result = to_data(result)
                    # NOTE(review): result images use the 0-based epoch while
                    # the checkpoint above uses epoch + 1.
                    result_filename = os.path.join(
                        result_dir, str(epoch) + '_' + val_fname)
                    print(result_filename)
                    save_image(result[0], result_filename)
        elif (epoch + 1) % 2 == 0:
            outname = 'save/' + 'epoch' + str(epoch + 1) + '_' + opts.save_model_name
            print('--- save model ---')
            torch.save(tetGAN.state_dict(), outname)