def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "images": Last(), "loss": Mean(), "lr": Last(), } # loop over batches ################################################################################################ model.train() for images, meta, targets in tqdm( data_loader, desc="fold {}, epoch {}/{}, train".format(config.fold, epoch, config.train.epochs), ): images, meta, targets = ( images.to(DEVICE), {k: meta[k].to(DEVICE) for k in meta}, targets.to(DEVICE), ) # images, targets = mix_up(images, targets, alpha=1.) logits = model(images, meta) loss = compute_loss(input=logits, target=targets, config=config) metrics["images"].update(images.data.cpu()) metrics["loss"].update(loss.data.cpu()) metrics["lr"].update(np.squeeze(scheduler.get_last_lr())) optimizer.zero_grad() loss.mean().backward() optimizer.step() scheduler.step() # compute metrics ################################################################################################## with torch.no_grad(): metrics = {k: metrics[k].compute_and_reset() for k in metrics} writer = SummaryWriter(os.path.join(config.experiment_path, "train")) writer.add_image( "images", torchvision.utils.make_grid( metrics["images"], nrow=compute_nrow(metrics["images"]), normalize=True ), global_step=epoch, ) writer.add_scalar("loss", metrics["loss"], global_step=epoch) writer.add_scalar("lr", metrics["lr"], global_step=epoch) writer.flush() writer.close()
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "loss": Mean(), "lr": Last(), "p/norm": Mean(), "z/norm": Mean(), "p/std": Mean(), "z/std": Mean(), } model.train() for images1, images2 in tqdm( data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs), ): images1, images2 = images1.to(DEVICE), images2.to(DEVICE) p1, z1 = model(images1) p2, z2 = model(images2) loss = (sum([ compute_loss(p=p1, z=z2.detach()), compute_loss(p=p2, z=z1.detach()), ]) / 2) metrics["loss"].update(loss.detach()) metrics["lr"].update(np.squeeze(scheduler.get_last_lr())) for k, v in [("p", p1), ("p", p2), ("z", z1), ("z", z2)]: v = v.detach() metrics["{}/norm".format(k)].update(v.norm(dim=1)) v = F.normalize(v, dim=1) metrics["{}/std".format(k)].update(v.std(dim=0)) optimizer.zero_grad() loss.mean().backward() optimizer.step() scheduler.step() writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): images = torch.cat([images1, images2], 3) metrics = {k: metrics[k].compute_and_reset() for k in metrics} for k in metrics: writer.add_scalar(k, metrics[k], global_step=epoch) writer.add_image( "images", torchvision.utils.make_grid(images[:16], nrow=1, normalize=True), global_step=epoch, ) writer.flush() writer.close()
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "loss": Mean(), "lr": Last(), } model.train() for images, targets in tqdm(data_loader, desc="epoch {}/{}, train".format( epoch, config.epochs)): images, targets = images.to(DEVICE), targets.to(DEVICE) logits = model(images) loss = F.cross_entropy(input=logits, target=targets, reduction="none") metrics["loss"].update(loss.data.cpu().numpy()) metrics["lr"].update(np.squeeze(scheduler.get_lr())) optimizer.zero_grad() loss.mean().backward() optimizer.step() scheduler.step() writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "images", torchvision.utils.make_grid(denormalize(images), nrow=compute_nrow(images), normalize=True), global_step=epoch, ) writer.add_histogram("params", flatten_weights(model.parameters()), global_step=epoch) writer.flush() writer.close()
def train_epoch(model, optimizer, scheduler, data_loader, box_coder, class_names, epoch, config):
    metrics = {
        "loss": Mean(),
        "loss/class": Mean(),
        "loss/loc": Mean(),
        "loss/cent": Mean(),
        "learning_rate": Last(),
    }

    model.train()
    optimizer.zero_grad()
    for i, batch in tqdm(
        enumerate(data_loader, 1), desc="epoch {} train".format(epoch), total=len(data_loader)
    ):
        images, targets, dets_true = apply_recursively(lambda x: x.to(DEVICE), batch)

        output = model(images)
        loss_dict = compute_loss(input=output, target=targets)
        loss = sum(loss_dict.values())

        metrics["loss"].update(loss.data.cpu())
        for k in loss_dict:
            metrics["loss/{}".format(k)].update(loss_dict[k].data.cpu())
        metrics["learning_rate"].update(np.squeeze(scheduler.get_lr()))

        (loss.mean() / config.train.acc_steps).backward()
        if i % config.train.acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

    with torch.no_grad():
        metrics = {k: metrics[k].compute_and_reset() for k in metrics}

        writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
        for k in metrics:
            writer.add_scalar(k, metrics[k], global_step=epoch)

        images = denormalize(images, mean=MEAN, std=STD)
        dets_true = [
            box_coder.decode(foreground_binary_coding(c, 80), r, s, images.size()[2:])
            for c, r, s in zip(*targets)
        ]
        dets_pred = [
            box_coder.decode(c.sigmoid(), r, s.sigmoid(), images.size()[2:])
            for c, r, s in zip(*output)
        ]
        true = [draw_boxes(i, d, class_names) for i, d in zip(images, dets_true)]
        pred = [draw_boxes(i, d, class_names) for i, d in zip(images, dets_pred)]
        writer.add_image("detections/true", torchvision.utils.make_grid(true, nrow=4), global_step=epoch)
        writer.add_image("detections/pred", torchvision.utils.make_grid(pred, nrow=4), global_step=epoch)

        writer.flush()
        writer.close()

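
# apply_recursively is assumed to map a function over every tensor in a nested container
# (tuples/lists/dicts), which is how the whole detection batch is moved to the device in
# one call above. A hypothetical sketch:
def apply_recursively_sketch(fn, value):
    if isinstance(value, (list, tuple)):
        return type(value)(apply_recursively_sketch(fn, v) for v in value)
    if isinstance(value, dict):
        return {k: apply_recursively_sketch(fn, v) for k, v in value.items()}
    return fn(value)
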
def main():
    args = build_parser().parse_args()
    config = build_default_config()
    config.merge_from_file(args.config_path)
    config.experiment_path = args.experiment_path
    config.render = not args.no_render
    config.freeze()
    del args

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = VecEnv([lambda: build_env(config) for _ in range(config.workers)])
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer, config.log_interval)
    env = wrappers.torch.Torch(env, device=DEVICE)
    env.seed(config.seed)

    policy_model = ModelDQN(config.model, env.observation_space, env.action_space).to(DEVICE)
    target_model = ModelDQN(config.model, env.observation_space, env.action_space).to(DEVICE)
    target_model.load_state_dict(policy_model.state_dict())
    optimizer = build_optimizer(config.opt, policy_model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/reward": Mean(),
    }

    # ==================================================================================================================
    # training loop
    policy_model.train()
    target_model.eval()
    episode = 0
    s = env.reset()
    e_base = 0.95
    e_step = np.exp(np.log(0.05 / e_base) / config.episodes)

    bar = tqdm(total=config.episodes, desc="training")
    history = History()
    while episode < config.episodes:
        with torch.no_grad():
            for _ in range(config.horizon):
                av = policy_model(s)
                a = sample_action(av, e_base * e_step ** episode)
                s_prime, r, d, meta = env.step(a)
                history.append(
                    state=s.cpu(),
                    action=a.cpu(),
                    reward=r.cpu(),
                    done=d.cpu(),
                    state_prime=s_prime.cpu(),
                )
                # history.append(state=s, action=a, reward=r, done=d, state_prime=s_prime)
                s = s_prime

                (indices,) = torch.where(d)
                for i in indices:
                    metrics["eps"].update(1)
                    metrics["ep/length"].update(meta[i]["episode"]["l"])
                    metrics["ep/reward"].update(meta[i]["episode"]["r"])

                    episode += 1
                    scheduler.step()
                    bar.update(1)

                    if episode % 10 == 0:
                        target_model.load_state_dict(policy_model.state_dict())

                    if episode % config.log_interval == 0 and episode > 0:
                        for k in metrics:
                            writer.add_scalar(
                                k, metrics[k].compute_and_reset(), global_step=episode
                            )
                        writer.add_scalar("e", e_base * e_step ** episode, global_step=episode)
                        # rollout, returns and action_values still hold the values from the
                        # previous optimization step at this point
                        writer.add_histogram(
                            "rollout/action", rollout.actions, global_step=episode
                        )
                        writer.add_histogram(
                            "rollout/reward", rollout.rewards, global_step=episode
                        )
                        writer.add_histogram("rollout/return", returns, global_step=episode)
                        writer.add_histogram(
                            "rollout/action_value", action_values, global_step=episode
                        )

        rollout = history.full_rollout()

        action_values = policy_model(rollout.states)
        action_values = action_values * one_hot(rollout.actions, action_values.size(-1))
        action_values = action_values.sum(-1)

        with torch.no_grad():
            action_values_prime = target_model(rollout.states_prime)
            action_values_prime, _ = action_values_prime.detach().max(-1)
        returns = one_step_discounted_return(
            rollout.rewards, action_values_prime, rollout.dones, gamma=config.gamma
        )

        # critic
        errors = returns - action_values
        critic_loss = errors ** 2

        loss = (critic_loss * 0.5).mean(1)

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        # training
        optimizer.zero_grad()
        loss.mean().backward()
        nn.utils.clip_grad_norm_(policy_model.parameters(), 0.5)
        optimizer.step()

    bar.close()
    env.close()

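
# one_step_discounted_return and sample_action are assumed helpers. Plausible sketches of
# the standard one-step TD target and epsilon-greedy action selection (hypothetical,
# matching how they are called above):
import torch


def one_step_return_sketch(rewards, values_prime, dones, gamma):
    # TD(0) target: r + gamma * max_a' Q_target(s', a'), cut off at episode ends.
    return rewards + gamma * values_prime * (~dones).float()


def sample_action_sketch(action_values, eps):
    # With probability eps pick a uniform random action, otherwise the greedy one.
    greedy = action_values.argmax(-1)
    random = torch.randint(action_values.size(-1), size=greedy.size())
    explore = torch.rand(greedy.size()) < eps
    return torch.where(explore, random, greedy)
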
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "loss/x": Mean(), "loss/u": Mean(), "weight/u": Last(), "lr": Last(), } model.train() for (images_x, targets_x), (images_u_0, images_u_1) in tqdm( data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)): # prepare data ################################################################################################# images_x, targets_x, images_u_0, images_u_1 = ( images_x.to(DEVICE), targets_x.to(DEVICE), images_u_0.to(DEVICE), images_u_1.to(DEVICE), ) targets_x = one_hot(targets_x, NUM_CLASSES) # mix-match #################################################################################################### with torch.no_grad(): (images_x, targets_x), (images_u, targets_u) = mix_match(x=(images_x, targets_x), u=(images_u_0, images_u_1), model=model, config=config) probs_x, probs_u = model(torch.cat([images_x, images_u])).split( [images_x.size(0), images_u.size(0)]) # x ############################################################################################################ loss_x = compute_loss_x(input=probs_x, target=targets_x) metrics["loss/x"].update(loss_x.data.cpu().numpy()) # u ############################################################################################################ loss_u = compute_loss_u(input=probs_u, target=targets_u) metrics["loss/u"].update(loss_u.data.cpu().numpy()) # opt step ##################################################################################################### metrics["lr"].update(np.squeeze(scheduler.get_last_lr())) weight_u = config.train.mix_match.weight_u * min( (epoch - 1) / config.epochs_warmup, 1.0) metrics["weight/u"].update(weight_u) optimizer.zero_grad() (loss_x.mean() + weight_u * loss_u.mean()).backward() optimizer.step() scheduler.step() if epoch % config.log_interval != 0: return writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "images_x", torchvision.utils.make_grid(images_x, nrow=compute_nrow(images_x), normalize=True), global_step=epoch, ) writer.add_image( "images_u", torchvision.utils.make_grid(images_u, nrow=compute_nrow(images_u), normalize=True), global_step=epoch, ) writer.flush() writer.close()
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "loss": Mean(), "lr": Last(), } model.train() for (text, text_mask), (audio, audio_mask) in tqdm( data_loader, desc="epoch {}/{}, train".format(epoch, config.train.epochs)): text, audio, text_mask, audio_mask = [ x.to(DEVICE) for x in [text, audio, text_mask, audio_mask] ] output, pre_output, target, target_mask, weight = model( text, text_mask, audio, audio_mask) loss = masked_mse(output, target, target_mask) + masked_mse( pre_output, target, target_mask) metrics["loss"].update(loss.data.cpu()) metrics["lr"].update(np.squeeze(scheduler.get_last_lr())) optimizer.zero_grad() loss.mean().backward() if config.train.clip_grad_norm is not None: torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.clip_grad_norm) optimizer.step() scheduler.step() writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): gl_true = griffin_lim(target, model.spectra) gl_pred = griffin_lim(output, model.spectra) output, pre_output, target, weight = [ x.unsqueeze(1) for x in [output, pre_output, target, weight] ] nrow = compute_nrow(target) for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "target", torchvision.utils.make_grid(target, nrow=nrow, normalize=True), global_step=epoch, ) writer.add_image( "output", torchvision.utils.make_grid(output, nrow=nrow, normalize=True), global_step=epoch, ) writer.add_image( "pre_output", torchvision.utils.make_grid(pre_output, nrow=nrow, normalize=True), global_step=epoch, ) writer.add_image( "weight", torchvision.utils.make_grid(weight, nrow=nrow, normalize=True), global_step=epoch, ) for i in tqdm(range(min(text.size(0), 4)), desc="writing audio"): writer.add_audio("audio/{}".format(i), audio[i], sample_rate=config.sample_rate, global_step=epoch) writer.add_audio( "griffin-lim-true/{}".format(i), gl_true[i], sample_rate=config.sample_rate, global_step=epoch, ) writer.add_audio( "griffin-lim-pred/{}".format(i), gl_pred[i], sample_rate=config.sample_rate, global_step=epoch, ) writer.flush() writer.close()
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config):
    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))

    metrics = {
        "loss": Mean(),
        "lr": Last(),
    }

    model.train()
    for images, targets in tqdm(
        data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)
    ):
        images, targets = images.to(DEVICE), targets.to(DEVICE)
        targets = image_one_hot(targets, NUM_CLASSES)

        logits = model(images)
        loss = compute_loss(input=logits, target=targets)

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))

        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        scheduler.step()

    with torch.no_grad():
        mask_true = F.interpolate(draw_masks(targets.argmax(1, keepdim=True)), scale_factor=1)
        mask_pred = F.interpolate(draw_masks(logits.argmax(1, keepdim=True)), scale_factor=1)

        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "images",
            torchvision.utils.make_grid(
                denormalize(images), nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "mask_true",
            torchvision.utils.make_grid(mask_true, nrow=compute_nrow(mask_true), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "mask_pred",
            torchvision.utils.make_grid(mask_pred, nrow=compute_nrow(mask_pred), normalize=True),
            global_step=epoch,
        )
        writer.add_image(
            "images_true",
            torchvision.utils.make_grid(
                denormalize(images) + mask_true, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "images_pred",
            torchvision.utils.make_grid(
                denormalize(images) + mask_pred, nrow=compute_nrow(images), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()

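
# image_one_hot is assumed to turn an integer mask of shape (N, H, W) into a float
# one-hot tensor of shape (N, C, H, W), which is what the later argmax(1) implies.
# A hypothetical sketch:
import torch.nn.functional as F


def image_one_hot_sketch(targets, num_classes):
    return F.one_hot(targets.long(), num_classes).permute(0, 3, 1, 2).float()
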
def train_epoch(model, data_loader, optimizer, scheduler, epoch, config): metrics = { "x_loss": Mean(), "u_loss": Mean(), "u_loss_mask": Mean(), "lr": Last(), } model.train() for (x_w_images, x_targets), (u_w_images, u_s_images) in tqdm( data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs) ): x_w_images, x_targets, u_w_images, u_s_images = ( x_w_images.to(DEVICE), x_targets.to(DEVICE), u_w_images.to(DEVICE), u_s_images.to(DEVICE), ) x_w_logits, u_w_logits, u_s_logits = model( torch.cat([x_w_images, u_w_images, u_s_images], 0) ).split([x_w_images.size(0), u_w_images.size(0), u_s_images.size(0)]) # x ############################################################################################################ x_loss = F.cross_entropy(input=x_w_logits, target=x_targets, reduction="none") metrics["x_loss"].update(x_loss.data.cpu().numpy()) # u ############################################################################################################ u_loss_mask, u_targets = F.softmax(u_w_logits.detach(), 1).max(1) u_loss_mask = (u_loss_mask >= config.train.tau).float() u_loss = u_loss_mask * F.cross_entropy( input=u_s_logits, target=u_targets, reduction="none" ) metrics["u_loss"].update(u_loss.data.cpu().numpy()) metrics["u_loss_mask"].update(u_loss_mask.data.cpu().numpy()) # opt step ##################################################################################################### metrics["lr"].update(np.squeeze(scheduler.get_lr())) optimizer.zero_grad() (x_loss.mean() + config.train.u_weight * u_loss.mean()).backward() optimizer.step() scheduler.step() if epoch % config.log_interval != 0: return writer = SummaryWriter(os.path.join(config.experiment_path, "train")) with torch.no_grad(): for k in metrics: writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch) writer.add_image( "x_w_images", torchvision.utils.make_grid( denormalize(x_w_images), nrow=compute_nrow(x_w_images), normalize=True ), global_step=epoch, ) writer.add_image( "u_w_images", torchvision.utils.make_grid( denormalize(u_w_images), nrow=compute_nrow(u_w_images), normalize=True ), global_step=epoch, ) writer.add_image( "u_s_images", torchvision.utils.make_grid( denormalize(u_s_images), nrow=compute_nrow(u_s_images), normalize=True ), global_step=epoch, ) writer.flush() writer.close()
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)
    del config_path, kwargs

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = VecEnv([lambda: build_env(config) for _ in range(config.workers)])
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer, config.log_interval)
    env = wrappers.Torch(env, dtype=torch.float, device=DEVICE)
    env.seed(config.seed)

    model = Model(config.model, env.observation_space, env.action_space)
    model = model.to(DEVICE)
    if config.restore_path is not None:
        model.load_state_dict(torch.load(config.restore_path))
    optimizer = build_optimizer(config.opt, model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/return": Mean(),
        "rollout/entropy": Mean(),
    }

    # ==================================================================================================================
    # training loop
    model.train()
    episode = 0
    s = env.reset()

    bar = tqdm(total=config.episodes, desc="training")
    while episode < config.episodes:
        history = History()

        with torch.no_grad():
            for _ in range(config.horizon):
                a, _ = model(s)
                a = a.sample()
                s_prime, r, d, info = env.step(a)
                history.append(state=s, action=a, reward=r, done=d, state_prime=s_prime)
                s = s_prime

                (indices,) = torch.where(d)
                for i in indices:
                    metrics["eps"].update(1)
                    metrics["ep/length"].update(info[i]["episode"]["l"])
                    metrics["ep/return"].update(info[i]["episode"]["r"])

                    episode += 1
                    scheduler.step()
                    bar.update(1)

                    if episode % config.log_interval == 0 and episode > 0:
                        for k in metrics:
                            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=episode)
                        # rollout, returns, values and advantages still hold the values
                        # from the previous optimization step at this point
                        writer.add_histogram("rollout/action", rollout.actions, global_step=episode)
                        writer.add_histogram("rollout/reward", rollout.rewards, global_step=episode)
                        writer.add_histogram("rollout/return", returns, global_step=episode)
                        writer.add_histogram("rollout/value", values, global_step=episode)
                        writer.add_histogram("rollout/advantage", advantages, global_step=episode)

                        torch.save(
                            model.state_dict(),
                            os.path.join(config.experiment_path, "model_{}.pth".format(episode)),
                        )

        rollout = history.full_rollout()

        dist, values = model(rollout.states)
        with torch.no_grad():
            _, value_prime = model(rollout.states_prime[:, -1])
        returns = n_step_bootstrapped_return(
            rollout.rewards, value_prime, rollout.dones, discount=config.gamma
        )

        # critic
        errors = returns - values
        critic_loss = errors ** 2

        # actor
        advantages = errors.detach()
        log_prob = dist.log_prob(rollout.actions)
        entropy = dist.entropy()
        if isinstance(env.action_space, gym.spaces.Box):
            log_prob = log_prob.sum(-1)
            entropy = entropy.sum(-1)
        assert log_prob.dim() == entropy.dim() == 2
        actor_loss = -log_prob * advantages - config.entropy_weight * entropy

        # loss
        loss = (actor_loss + 0.5 * critic_loss).mean(1)

        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_lr()))
        metrics["rollout/entropy"].update(dist.entropy().data.cpu().numpy())

        # training
        optimizer.zero_grad()
        loss.mean().backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

    bar.close()
    env.close()

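
# n_step_bootstrapped_return is assumed to compute the usual n-step return over the time
# dimension of the rollout, bootstrapping from the value of the last next-state and
# cutting at episode boundaries. A hypothetical sketch for tensors shaped (batch, time):
import torch


def n_step_return_sketch(rewards, value_prime, dones, discount):
    returns = torch.zeros_like(rewards)
    ret = value_prime
    for t in reversed(range(rewards.size(1))):
        ret = rewards[:, t] + discount * ret * (~dones[:, t]).float()
        returns[:, t] = ret
    return returns
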
def train_epoch(model, data_loader, opt_teacher, opt_student, sched_teacher, sched_student, epoch, config):
    metrics = {
        "teacher/loss": Mean(),
        "teacher/grad_norm": Mean(),
        "teacher/lr": Last(),
        "student/loss": Mean(),
        "student/grad_norm": Mean(),
        "student/lr": Last(),
    }

    model.train()
    for (x_image, x_target), (u_image,) in tqdm(
        data_loader, desc="epoch {}/{}, train".format(epoch, config.epochs)
    ):
        x_image, x_target, u_image = x_image.to(DEVICE), x_target.to(DEVICE), u_image.to(DEVICE)

        with higher.innerloop_ctx(model.student, opt_student) as (h_model_student, h_opt_student):
            # student ##################################################################################################
            loss_student = cross_entropy(
                input=h_model_student(u_image), target=model.teacher(u_image)
            ).mean()
            metrics["student/loss"].update(loss_student.data.cpu().numpy())
            metrics["student/lr"].update(np.squeeze(sched_student.get_last_lr()))

            def grad_callback(grads):
                metrics["student/grad_norm"].update(grad_norm(grads).data.cpu().numpy())
                return grads

            h_opt_student.step(loss_student.mean(), grad_callback=grad_callback)
            sched_student.step()

            # teacher ##################################################################################################
            loss_teacher = (
                cross_entropy(input=model.teacher(x_image), target=one_hot(x_target, NUM_CLASSES)).mean()
                + cross_entropy(input=h_model_student(x_image), target=one_hot(x_target, NUM_CLASSES)).mean()
            )
            metrics["teacher/loss"].update(loss_teacher.data.cpu().numpy())
            metrics["teacher/lr"].update(np.squeeze(sched_teacher.get_last_lr()))

            opt_teacher.zero_grad()
            loss_teacher.mean().backward()
            opt_teacher.step()
            metrics["teacher/grad_norm"].update(
                grad_norm(p.grad for p in model.teacher.parameters()).data.cpu().numpy()
            )
            sched_teacher.step()

            # copy student weights #####################################################################################
            with torch.no_grad():
                for p, p_prime in zip(model.student.parameters(), h_model_student.parameters()):
                    p.copy_(p_prime)

    if epoch % config.log_interval != 0:
        return

    writer = SummaryWriter(os.path.join(config.experiment_path, "train"))
    with torch.no_grad():
        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image(
            "x_image",
            torchvision.utils.make_grid(
                denormalize(x_image), nrow=compute_nrow(x_image), normalize=True
            ),
            global_step=epoch,
        )
        writer.add_image(
            "u_image",
            torchvision.utils.make_grid(
                denormalize(u_image), nrow=compute_nrow(u_image), normalize=True
            ),
            global_step=epoch,
        )

    writer.flush()
    writer.close()

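
# grad_norm is assumed to return the total L2 norm over an iterable of gradient tensors
# (it is called both on higher's grads and on p.grad for the teacher parameters).
# A hypothetical sketch:
import torch


def grad_norm_sketch(grads):
    return torch.norm(torch.stack([g.detach().norm() for g in grads if g is not None]))
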
def train_epoch(model, data_loader, fold_probs, optimizer, scheduler, epoch, config):
    writer = SummaryWriter(
        os.path.join(config.experiment_path, 'F{}'.format(config.fold), 'train'))

    metrics = {
        'loss': Mean(),
        'loss_hist': Concat(),
        'entropy': Mean(),
        'lr': Last(),
    }

    model.train()
    for images, targets, indices in tqdm(
            data_loader, desc='[F{}][epoch {}] train'.format(config.fold, epoch)):
        images, targets, indices = images.to(DEVICE), targets.to(DEVICE), indices.to(DEVICE)

        if epoch >= config.train.self_distillation.start_epoch:
            targets = weighted_sum(
                targets, fold_probs[indices], config.train.self_distillation.target_weight)

        if config.train.cutmix is not None:
            if np.random.uniform() > (epoch - 1) / (config.epochs - 1):
                images, targets = utils.cutmix(images, targets, config.train.cutmix)

        logits, etc = model(images)
        loss = compute_loss(input=logits, target=targets, config=config.train)

        metrics['loss'].update(loss.data.cpu().numpy())
        metrics['loss_hist'].update(loss.data.cpu().numpy())
        metrics['entropy'].update(compute_entropy(logits).data.cpu().numpy())
        metrics['lr'].update(np.squeeze(scheduler.get_lr()))

        loss.mean().backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # FIXME:
        if epoch >= config.train.self_distillation.start_epoch:
            probs = torch.cat([i.softmax(-1) for i in split_target(logits.detach())], -1)
            fold_probs[indices] = weighted_sum(
                fold_probs[indices], probs, config.train.self_distillation.pred_ewa)

    for k in metrics:
        if k.endswith('_hist'):
            writer.add_histogram(k, metrics[k].compute_and_reset(), global_step=epoch)
        else:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
    writer.add_image(
        'images',
        torchvision.utils.make_grid(images, nrow=compute_nrow(images), normalize=True),
        global_step=epoch)
    if 'stn' in etc:
        writer.add_image(
            'stn',
            torchvision.utils.make_grid(etc['stn'], nrow=compute_nrow(etc['stn']), normalize=True),
            global_step=epoch)

    writer.flush()
    writer.close()

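
# compute_entropy is assumed to measure the entropy of the predictive distribution,
# which is what the 'entropy' metric tracks above. A hypothetical per-sample sketch over
# the last (class) dimension:
import torch.nn.functional as F


def entropy_sketch(logits):
    log_probs = F.log_softmax(logits, -1)
    return -(log_probs.exp() * log_probs).sum(-1)
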
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)
    del config_path, kwargs

    writer = SummaryWriter(config.experiment_path)

    seed_torch(config.seed)
    env = wrappers.Batch(build_env(config))
    if config.render:
        env = wrappers.TensorboardBatchMonitor(env, writer, config.log_interval)
    env = wrappers.torch.Torch(env, device=DEVICE)
    env.seed(config.seed)

    model = Model(config.model, env.observation_space, env.action_space)
    model = model.to(DEVICE)
    if config.restore_path is not None:
        model.load_state_dict(torch.load(config.restore_path))
    optimizer = build_optimizer(config.opt, model.parameters())
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.episodes)

    metrics = {
        "loss": Mean(),
        "lr": Last(),
        "eps": FPS(),
        "ep/length": Mean(),
        "ep/return": Mean(),
        "rollout/reward": Mean(),
        "rollout/advantage": Mean(),
        "rollout/entropy": Mean(),
    }

    # training loop ====================================================================================================
    for episode in tqdm(range(config.episodes), desc="training"):
        hist = History()
        s = env.reset()
        h = model.zero_state(1)
        d = torch.ones(1, dtype=torch.bool)

        model.eval()
        with torch.no_grad():
            while True:
                trans = hist.append_transition()
                trans.record(state=s, hidden=h, done=d)

                a, _, h = model(s, h, d)
                a = a.sample()
                s, r, d, info = env.step(a)
                trans.record(action=a, reward=r)

                if d:
                    break

        # optimization =================================================================================================
        model.train()

        # build rollout
        rollout = hist.full_rollout()

        # loss
        loss = compute_loss(env, model, rollout, metrics, config)

        # metrics
        metrics["loss"].update(loss.data.cpu().numpy())
        metrics["lr"].update(np.squeeze(scheduler.get_last_lr()))

        # training
        optimizer.zero_grad()
        loss.mean().backward()
        if config.grad_clip_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)
        optimizer.step()
        scheduler.step()

        metrics["eps"].update(1)
        metrics["ep/length"].update(info[0]["episode"]["l"])
        metrics["ep/return"].update(info[0]["episode"]["r"])

        if episode % config.log_interval == 0 and episode > 0:
            for k in metrics:
                writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=episode)

            torch.save(
                model.state_dict(),
                os.path.join(config.experiment_path, "model_{}.pth".format(episode)),
            )

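
# History / append_transition / full_rollout are assumed to collect per-step fields and
# then stack them into tensors with a time dimension. A hypothetical minimal sketch of
# that interface (the real container also carries state_prime etc. in the other scripts):
from types import SimpleNamespace

import torch


class TransitionSketch:
    def __init__(self):
        self.fields = {}

    def record(self, **kwargs):
        self.fields.update(kwargs)


class HistorySketch:
    def __init__(self):
        self.transitions = []

    def append_transition(self):
        trans = TransitionSketch()
        self.transitions.append(trans)
        return trans

    def full_rollout(self):
        # stack every recorded field along a time dimension: (batch, time, ...)
        keys = self.transitions[0].fields
        stacked = {k: torch.stack([t.fields[k] for t in self.transitions], 1) for k in keys}
        return SimpleNamespace(**stacked)
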
def main(config_path, **kwargs):
    config = load_config(config_path, **kwargs)

    transform, update_transform = build_transform()

    if config.dataset == "mnist":
        dataset = torchvision.datasets.MNIST(config.dataset_path, transform=transform, download=True)
    elif config.dataset == "celeba":
        dataset = torchvision.datasets.ImageFolder(config.dataset_path, transform=transform)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        drop_last=True,
    )

    model = nn.ModuleDict({
        "discriminator": Discriminator(config.image_size),
        "generator": Generator(config.image_size, config.latent_size),
    })
    model.to(DEVICE)
    if config.restore_path is not None:
        model.load_state_dict(torch.load(config.restore_path))

    discriminator_opt = torch.optim.Adam(
        model.discriminator.parameters(), lr=config.opt.lr, betas=config.opt.beta, eps=1e-8)
    generator_opt = torch.optim.Adam(
        model.generator.parameters(), lr=config.opt.lr, betas=config.opt.beta, eps=1e-8)

    noise_dist = torch.distributions.Normal(0, 1)

    writer = SummaryWriter(config.experiment_path)
    metrics = {
        "loss/discriminator": Mean(),
        "loss/generator": Mean(),
        "level": Last(),
        "alpha": Last(),
    }

    for epoch in range(1, config.epochs + 1):
        model.train()

        level, _ = compute_level(
            epoch - 1, config.epochs, 0, len(data_loader), config.image_size, config.grow_min_level)
        update_transform(int(4 * 2 ** level))

        for i, (real, _) in enumerate(
                tqdm(data_loader, desc="epoch {} training".format(epoch))):
            _, a = compute_level(
                epoch - 1,
                config.epochs,
                i,
                len(data_loader),
                config.image_size,
                config.grow_min_level,
            )
            real = real.to(DEVICE)

            # discriminator ############################################################################################
            discriminator_opt.zero_grad()

            # real
            scores = model.discriminator(real, level=level, a=a)
            loss = F.softplus(-scores)
            loss.mean().backward()
            loss_real = loss

            # fake
            noise = noise_dist.sample((config.batch_size, config.latent_size)).to(DEVICE)
            fake = model.generator(noise, level=level, a=a)
            assert real.size() == fake.size()
            scores = model.discriminator(fake, level=level, a=a)
            loss = F.softplus(scores)
            loss.mean().backward()
            loss_fake = loss

            discriminator_opt.step()
            metrics["loss/discriminator"].update((loss_real + loss_fake).data.cpu().numpy())

            # generator ################################################################################################
            generator_opt.zero_grad()

            # fake
            noise = noise_dist.sample((config.batch_size, config.latent_size)).to(DEVICE)
            fake = model.generator(noise, level=level, a=a)
            assert real.size() == fake.size()
            scores = model.discriminator(fake, level=level, a=a)
            loss = F.softplus(-scores)
            loss.mean().backward()

            generator_opt.step()
            metrics["loss/generator"].update(loss.data.cpu().numpy())

            metrics["level"].update(level)
            metrics["alpha"].update(a)

        for k in metrics:
            writer.add_scalar(k, metrics[k].compute_and_reset(), global_step=epoch)
        writer.add_image("real", utils.make_grid((real + 1) / 2), global_step=epoch)
        writer.add_image("fake", utils.make_grid((fake + 1) / 2), global_step=epoch)

        torch.save(
            model.state_dict(),
            os.path.join(config.experiment_path, "model_{}.pth".format(epoch)))

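
# build_transform returns both a transform and a callable that changes its output
# resolution, so the data pipeline can follow the current growth level
# (update_transform(int(4 * 2 ** level)) above). A hypothetical sketch of that closure:
import torchvision.transforms as T


def build_transform_sketch():
    resize = T.Resize(4)

    def update(size):
        resize.size = size

    transform = T.Compose([
        resize,
        T.ToTensor(),
        T.Normalize([0.5], [0.5]),  # map images to [-1, 1], matching (real + 1) / 2 in the log
    ])
    return transform, update
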