def train(epoch):
    """Run one training epoch and report progress.

    Relies on module-level state: ``model``, ``optimizer``, ``features``,
    ``train_pos_adj``, ``adj``, ``labels``, index tensors and ``args``.

    Parameters
    ----------
    epoch : int
        Zero-based epoch counter, used only for logging.

    Returns
    -------
    float
        Validation loss when ``args.fastmode`` is off; otherwise the
        training loss. (BUGFIX: the original returned ``loss_val``
        unconditionally, raising ``NameError`` under ``--fastmode``
        because the validation pass — and hence ``loss_val`` — was
        skipped.)
    """
    torch.autograd.set_detect_anomaly(True)
    t = time.time()

    # Forward pass in training mode (dropout active).
    model.train()
    optimizer.zero_grad()
    output = model(features, train_pos_adj)
    # output = model.encode(features, train_pos_edge_index)

    # Loss & optimizer step.
    if args.link:
        # Link-prediction objective on the positive training edges.
        loss_train = recon_loss(output, train_pos_edge_index)
        # TODO: add variational loss
    else:
        # Node-classification objective.
        loss_train = F.nll_loss(output[idx_train], labels[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    # NOTE: a commented-out experimental NMF optimization step
    # (nmf_optim / nmf_loss on W/H factor lists) was removed here as dead code.

    loss_val = None
    if not args.fastmode:
        # Evaluate validation-set performance separately;
        # model.eval() deactivates dropout during the validation run.
        model.eval()
        with torch.no_grad():
            output = model(features, adj)
            # output = model.encode(features, train_pos_edge_index)
            # TODO: autoencoder validation
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = accuracy(output[idx_val], labels[idx_val])

    print(
        'Epoch: {:04d}'.format(epoch + 1),
        'loss_train: {:.4f}'.format(loss_train.data.item()),
        # 'acc_train: {:.4f}'.format(acc_train.data.item()),
        # 'loss_val: {:.4f}'.format(loss_val.data.item()),
        # 'acc_val: {:.4f}'.format(acc_val.data.item()),
        'time: {:.4f}s'.format(time.time() - t))

    # Guarded return: fall back to the training loss when validation
    # was skipped (fastmode) instead of crashing on an unbound name.
    if loss_val is None:
        return loss_train.data.item()
    return loss_val.data.item()
def train(args):
    """Train a CycleGAN (generators g/f, discriminators d_x/d_y) with NNabla.

    Builds the static computation graphs for training and test, sets up one
    Adam solver per sub-network, then runs the epoch loop: generator step,
    image-pool history update, and a step for each discriminator.  Snapshots
    parameters and monitors images/losses once per epoch.
    """
    # Settings
    # NOTE(review): batch size is hard-coded to 1 and image size to 256x256
    # here; presumably intentional for this CycleGAN setup — confirm.
    b, c, h, w = 1, 3, 256, 256
    beta1 = 0.5
    beta2 = 0.999
    pool_size = 50
    lambda_recon = args.lambda_recon
    lambda_idt = args.lambda_idt
    base_lr = args.learning_rate
    init_method = args.init_method

    # Context: fall back to CPU when no extension (e.g. cudnn) is requested.
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(
        extension_module, device_id=args.device_id,
        type_config=args.type_config)
    nn.set_default_context(ctx)

    # Inputs: raw image variables plus augmented views used for training.
    x_raw = nn.Variable([b, c, h, w], need_grad=False)
    y_raw = nn.Variable([b, c, h, w], need_grad=False)
    x_real = image_augmentation(x_raw)
    y_real = image_augmentation(y_raw)
    # History buffers fed from the ImagePool for discriminator training.
    x_history = nn.Variable([b, c, h, w])
    y_history = nn.Variable([b, c, h, w])
    x_real_test = nn.Variable([b, c, h, w], need_grad=False)
    y_real_test = nn.Variable([b, c, h, w], need_grad=False)

    # Models for training
    # Generate: g maps domain X -> Y, f maps domain Y -> X.
    y_fake = models.g(x_real, unpool=args.unpool, init_method=init_method)
    x_fake = models.f(y_real, unpool=args.unpool, init_method=init_method)
    # Keep fake images in memory so their .d can be read after forward
    # (for the ImagePool and image monitors).
    y_fake.persistent, x_fake.persistent = True, True
    # Reconstruct (cycle): f(g(x)) ~ x and g(f(y)) ~ y.
    x_recon = models.f(y_fake, unpool=args.unpool, init_method=init_method)
    y_recon = models.g(x_fake, unpool=args.unpool, init_method=init_method)
    # Discriminate real/fake samples in each domain.
    d_y_fake = models.d_y(y_fake, init_method=init_method)
    d_x_fake = models.d_x(x_fake, init_method=init_method)
    d_y_real = models.d_y(y_real, init_method=init_method)
    d_x_real = models.d_x(x_real, init_method=init_method)
    d_y_history = models.d_y(y_history, init_method=init_method)
    d_x_history = models.d_x(x_history, init_method=init_method)

    # Models for test (separate graph fed from the *_test variables).
    y_fake_test = models.g(
        x_real_test, unpool=args.unpool, init_method=init_method)
    x_fake_test = models.f(
        y_real_test, unpool=args.unpool, init_method=init_method)
    y_fake_test.persistent, x_fake_test.persistent = True, True
    # Reconstruct
    x_recon_test = models.f(
        y_fake_test, unpool=args.unpool, init_method=init_method)
    y_recon_test = models.g(
        x_fake_test, unpool=args.unpool,
        init_method=init_method)

    # Losses
    # Reconstruction (cycle-consistency) loss over both domains.
    loss_recon = models.recon_loss(x_recon, x_real) \
        + models.recon_loss(y_recon, y_real)
    # Generator loss: LSGAN terms plus weighted cycle loss.
    loss_gen = models.lsgan_loss(d_y_fake) \
        + models.lsgan_loss(d_x_fake) \
        + lambda_recon * loss_recon
    # Optional identity loss (g(y) ~ y, f(x) ~ x), weighted by
    # lambda_recon * lambda_idt.
    if lambda_idt != 0:
        logger.info("Identity loss was added.")
        # Identity
        y_idt = models.g(y_real, unpool=args.unpool, init_method=init_method)
        x_idt = models.f(x_real, unpool=args.unpool, init_method=init_method)
        loss_idt = models.recon_loss(x_idt, x_real) \
            + models.recon_loss(y_idt, y_real)
        loss_gen += lambda_recon * lambda_idt * loss_idt
    # Discriminator losses use pooled history images as the fake samples.
    loss_dis_y = models.lsgan_loss(d_y_history, d_y_real)
    loss_dis_x = models.lsgan_loss(d_x_history, d_x_real)

    # Solvers: one Adam per sub-network, each bound to its parameter scope.
    solver_gen = S.Adam(base_lr, beta1, beta2)
    solver_dis_x = S.Adam(base_lr, beta1, beta2)
    solver_dis_y = S.Adam(base_lr, beta1, beta2)
    with nn.parameter_scope('generator'):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope('discriminator'):
        with nn.parameter_scope("x"):
            solver_dis_x.set_parameters(nn.get_parameters())
        with nn.parameter_scope("y"):
            solver_dis_y.set_parameters(nn.get_parameters())

    # Datasets: one source/iterator per (split, domain) pair.
    rng = np.random.RandomState(313)
    ds_train_B = cycle_gan_data_source(
        args.dataset, train=True, domain="B", shuffle=True, rng=rng)
    ds_train_A = cycle_gan_data_source(
        args.dataset, train=True, domain="A", shuffle=True, rng=rng)
    ds_test_B = cycle_gan_data_source(
        args.dataset, train=False, domain="B", shuffle=False, rng=rng)
    ds_test_A = cycle_gan_data_source(
        args.dataset, train=False, domain="A", shuffle=False, rng=rng)
    di_train_B = cycle_gan_data_iterator(ds_train_B, args.batch_size)
    di_train_A = cycle_gan_data_iterator(ds_train_A, args.batch_size)
    di_test_B = cycle_gan_data_iterator(ds_test_B, args.batch_size)
    di_test_A = cycle_gan_data_iterator(ds_test_A, args.batch_size)

    # Monitors
    monitor = Monitor(args.monitor_path)

    def make_monitor(name):
        # Scalar-series monitor logging every call.
        return MonitorSeries(name, monitor, interval=1)

    monitor_loss_gen = make_monitor('generator_loss')
    monitor_loss_dis_x = make_monitor('discriminator_B_domain_loss')
    monitor_loss_dis_y = make_monitor('discriminator_A_domain_loss')

    def make_monitor_image(name):
        # Image monitor mapping [-1, 1] outputs back to [0, 255].
        return MonitorImage(name, monitor, interval=1,
                            normalize_method=lambda x: (x + 1.0) * 127.5)

    monitor_train_gx = make_monitor_image('fake_images_train_A')
    monitor_train_fy = make_monitor_image('fake_images_train_B')
    monitor_train_x_recon = make_monitor_image('fake_images_B_recon_train')
    monitor_train_y_recon = make_monitor_image('fake_images_A_recon_train')
    monitor_test_gx = make_monitor_image('fake_images_test_A')
    monitor_test_fy = make_monitor_image('fake_images_test_B')
    monitor_test_x_recon = make_monitor_image('fake_images_recon_test_B')
    monitor_test_y_recon = make_monitor_image('fake_images_recon_test_A')
    monitor_train_list = [
        (monitor_train_gx, y_fake),
        (monitor_train_fy, x_fake),
        (monitor_train_x_recon, x_recon),
        (monitor_train_y_recon, y_recon),
        (monitor_loss_gen, loss_gen),
        (monitor_loss_dis_x, loss_dis_x),
        (monitor_loss_dis_y, loss_dis_y),
    ]
    monitor_test_list = [
        (monitor_test_gx, y_fake_test),
        (monitor_test_fy, x_fake_test),
        (monitor_test_x_recon, x_recon_test),
        (monitor_test_y_recon, y_recon_test)]

    # ImagePool: replay buffers of past fakes for discriminator training.
    pool_x = ImagePool(pool_size)
    pool_y = ImagePool(pool_size)

    # Training loop
    epoch = 0
    n_images = np.max([ds_train_B.size, ds_train_A.size]
                      )  # num. images for each domain
    max_iter = args.max_epoch * n_images // args.batch_size
    for i in range(max_iter):
        # Validation / end-of-epoch bookkeeping.
        if int((i+1) % (n_images // args.batch_size)) == 0:
            logger.info("Mode:Test,Epoch:{}".format(epoch))
            # Monitor for train
            # NOTE(review): the loop variable shadows the Monitor object
            # `monitor` defined above — harmless here since `monitor` is
            # not used afterwards, but worth renaming.
            for monitor, v in monitor_train_list:
                monitor.add(i, v.d)
            # Use training graph since there are no test mode
            x_data, _ = di_test_B.next()
            y_data, _ = di_test_A.next()
            x_real_test.d = x_data
            y_real_test.d = y_data
            x_recon_test.forward()
            y_recon_test.forward()
            # Monitor for test
            for monitor, v in monitor_test_list:
                monitor.add(i, v.d)
            # Save model
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
            # Learning rate decay, applied once per epoch to all solvers.
            for solver in [solver_gen, solver_dis_x, solver_dis_y]:
                linear_decay(solver, base_lr, epoch, args.max_epoch)
            epoch += 1

        # Get data
        x_data, _ = di_train_B.next()
        y_data, _ = di_train_A.next()
        x_raw.d = x_data
        y_raw.d = y_data

        # Train Generators
        loss_gen.forward(clear_no_need_grad=False)
        solver_gen.zero_grad()
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Insert and Get to/from pool (fills the history variables the
        # discriminator losses read).
        x_history.d = pool_x.insert_then_get(x_fake.d)
        y_history.d = pool_y.insert_then_get(y_fake.d)

        # Train Discriminator Y
        loss_dis_y.forward(clear_no_need_grad=False)
        solver_dis_y.zero_grad()
        loss_dis_y.backward(clear_buffer=True)
        solver_dis_y.update()

        # Train Discriminator X
        loss_dis_x.forward(clear_no_need_grad=False)
        solver_dis_x.zero_grad()
        loss_dis_x.backward(clear_buffer=True)
        solver_dis_x.update()
def train(args):
    """Train a MUNIT-style model (content/style encoders, decoders,
    multi-scale discriminators) with multi-process data parallelism.

    Each MPI process drives one GPU; gradients are averaged with
    ``comm.all_reduce`` and only local rank 0 writes monitors/snapshots.
    """
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # NOTE: "Paralell" is the actual (misspelled) NNabla API name.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank  # NOTE(review): assigned but never used below.
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domain A)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction (persistent so .d stays readable
    # for monitoring after forward/backward with buffer clearing).
    loss_recon_x_style_a = recon_loss(
        x_style_rec_a, z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(
        x_content_rec_b, x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(
        x_style_rec_b, z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(
        x_content_rec_a, x_content_a).apply(persistent=True)

    # adversarial: sum the LSGAN losses over all discriminator scales.
    def f(x, y): return x + y
    loss_gen_a = reduce(
        f, [lsgan_loss(p_f) for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(
        f, [lsgan_loss(p_f) for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries(
        "Recon Loss Image A", monitor, interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries(
        "Recon Loss Content B", monitor, interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries(
        "Recon Loss Style A", monitor, interval=10)
    monitor_loss_recon_x_b = MonitorSeries(
        "Recon Loss Image B", monitor, interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries(
        "Recon Loss Content A", monitor, interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries(
        "Recon Loss Style B", monitor, interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adaversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train", monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train", monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator: distinct seeds per device so devices see different data.
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        # ii is the global iteration count across all devices.
        ii = i * n_devices

        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators; freeze gradients through the fake images so
        # the discriminator step does not backprop into the generators.
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule: multiplicative decay at fixed global intervals.
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    # Final monitor/snapshot after the loop finishes.
    # NOTE(review): indentation was lost in the original; this block reads
    # as a post-loop final save (it reuses the last loop values of i/ii and
    # would raise NameError if max_iter // n_devices == 0) — confirm intent.
    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))