def maybe_log(self):
    num_steps = self.env.get_total_steps()

    if self.log_freq is not None and num_steps > 0 and num_steps % self.log_freq == 0:
        self.tensorboard_writer.add_scalar('Epsilon', self.dqn.epsilon_value(), num_steps)
        if len(self.huber_loss) > 0:
            self.tensorboard_writer.add_scalar('Huber loss', np.mean(self.huber_loss), num_steps)
        self.tensorboard_writer.add_scalar('FPS', num_steps / (time.time() - self.start_time), num_steps)
        self.huber_loss = []  # clear the loss values and start recollecting them again

    # Periodically save DQN models
    if self.checkpoint_freq is not None and num_steps > 0 and num_steps % self.checkpoint_freq == 0:
        ckpt_model_name = f'dqn_{self.config["env_id"]}_ckpt_steps_{num_steps}.pth'
        torch.save(utils.get_training_state(self.config, self.dqn), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Log the gradients
    if self.grads_log_freq is not None and self.learner_cnt > 0 and self.learner_cnt % self.grads_log_freq == 0:
        total_grad_l2_norm = 0

        for cnt, (name, weight_or_bias_parameters) in enumerate(self.dqn.named_parameters()):
            grad_l2_norm = weight_or_bias_parameters.grad.data.norm(p=2).item()
            self.tensorboard_writer.add_scalar(f'grad_norms/{name}', grad_l2_norm, self.learner_cnt)
            total_grad_l2_norm += grad_l2_norm ** 2

        # As if we concatenated all of the params into a single vector and took the L2 norm
        total_grad_l2_norm = total_grad_l2_norm ** (1 / 2)
        self.tensorboard_writer.add_scalar('grad_norms/total', total_grad_l2_norm, self.learner_cnt)
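The total_grad_l2_norm computation above relies on the identity that summing squared per-parameter norms and taking a square root equals the L2 norm of all gradients concatenated into one vector. A minimal standalone check of that identity (toy_model is a hypothetical stand-in for the DQN):

import torch

toy_model = torch.nn.Linear(4, 2)  # hypothetical stand-in for self.dqn
toy_model(torch.randn(8, 4)).sum().backward()  # populate the .grad fields

# Sum of squared per-parameter L2 norms, then a square root...
total = sum(p.grad.norm(p=2).item() ** 2 for p in toy_model.parameters()) ** 0.5
# ...equals the L2 norm of all gradients flattened into a single vector
flat = torch.cat([p.grad.flatten() for p in toy_model.parameters()]).norm(p=2).item()
assert abs(total - flat) < 1e-4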
def train_dqn(config):
    env = utils.get_env_wrapper(config['env_id'])
    # Note the negation - the config flag and the constructor argument have inverted meanings
    replay_buffer = ReplayBuffer(config['replay_buffer_size'], crash_if_no_mem=not config['dont_crash_if_no_mem'])

    utils.set_random_seeds(env, config['seed'])

    linear_schedule = utils.LinearSchedule(
        config['epsilon_start_value'],
        config['epsilon_end_value'],
        config['epsilon_duration'])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dqn = DQN(env, number_of_actions=env.action_space.n, epsilon_schedule=linear_schedule).to(device)
    target_dqn = DQN(env, number_of_actions=env.action_space.n).to(device)

    # Don't get confused by the actor-learner terminology, DQN is not an actor-critic method, but conceptually
    # we can split the learning process into collecting experience/acting in the env and learning from that experience
    actor_learner = ActorLearner(config, env, replay_buffer, dqn, target_dqn, env.reset())

    while actor_learner.get_number_of_env_steps() < config['num_of_training_steps']:
        num_env_steps = actor_learner.get_number_of_env_steps()
        if config['console_log_freq'] is not None and num_env_steps % config['console_log_freq'] == 0:
            actor_learner.log_to_console()

        actor_learner.collect_experience()

        if num_env_steps > config['num_warmup_steps']:
            actor_learner.learn_from_experience()

    # Save the best DQN model overall (the one that gave the highest reward in an episode)
    torch.save(
        utils.get_training_state(config, actor_learner.best_dqn_model),
        os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['env_id'])))
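utils.LinearSchedule is only referenced above; a minimal sketch of what such a linear epsilon schedule typically computes (an assumption - the real class may differ in details): epsilon is annealed linearly from the start value to the end value over a fixed number of steps and held constant afterwards.

class LinearSchedule:
    """Minimal sketch - linearly anneal epsilon, then hold it at the end value."""
    def __init__(self, start_value, end_value, duration):
        self.start_value, self.end_value, self.duration = start_value, end_value, duration

    def __call__(self, step):
        fraction = min(step / self.duration, 1.0)  # clamp to 1 once duration is exceeded
        return self.start_value + fraction * (self.end_value - self.start_value)

# e.g. LinearSchedule(1.0, 0.1, 50000)(25000) -> 0.55 (halfway through the annealing)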
def train_gat(config):
    global BEST_VAL_ACC, BEST_VAL_LOSS

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1: load the graph data
    node_features, node_labels, edge_index, train_indices, val_indices, test_indices = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        layer_type=config['layer_type'],
        log_attention_weights=False  # no need to store attentions, used only in playground.py while visualizing
    ).to(device)

    # Step 3: Prepare other training related utilities (loss & optimizer and decorator function)
    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        node_features,
        node_labels,
        edge_index,
        train_indices,
        val_indices,
        test_indices,
        config['patience_period'],
        time.time())

    BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT = [0, 0, 0]  # reset vars used for early stopping

    # Step 4: Start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5: Potentially test your model
    # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you
    # report your final loss and accuracy on the test dataset. Friends don't let friends overfit to the test data. <3
    if config['should_test']:
        test_acc = main_loop(phase=LoopPhase.TEST)
        config['test_acc'] = test_acc
        print(f'Test accuracy = {test_acc}')
    else:
        config['test_acc'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(utils.get_training_state(config, gat), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
def main_loop(phase, epoch=0):
    global BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT, writer

    # Certain modules behave differently depending on whether we're training the model or not.
    # e.g. nn.Dropout - we only want to drop model weights during the training.
    if phase == LoopPhase.TRAIN:
        gat.train()
    else:
        gat.eval()

    node_indices = get_node_indices(phase)
    gt_node_labels = get_node_labels(phase)  # gt stands for ground truth

    # Do a forward pass and extract only the relevant node scores (train/val or test ones)
    # Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
    # shape = (N, C) where N is the number of nodes in the split (train/val/test) and C is the number of classes
    nodes_unnormalized_scores = gat(graph_data)[0].index_select(node_dim, node_indices)

    # Example: let's take an output for a single node on Cora - it's a vector of size 7 and it contains unnormalized
    # scores like: V = [-1.393, 3.0765, -2.4445, 9.6219, 2.1658, -5.5243, -4.6247]
    # What PyTorch's cross entropy loss does is for every such vector it first applies a softmax, and so we'll
    # have V transformed into: [1.6421e-05, 1.4338e-03, 5.7378e-06, 0.99797, 5.7673e-04, 2.6376e-07, 6.4848e-07]
    # Secondly, whatever the correct class is (say it's 3), it will then take the element at position 3,
    # 0.99797 in this case, and the loss will be -log(0.99797). It does this for every node and applies a mean.
    # You can see that as the probability of the correct class for most nodes approaches 1 we get to 0 loss! <3
    loss = cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

    if phase == LoopPhase.TRAIN:
        optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
        loss.backward()  # compute the gradients for every trainable weight in the computational graph
        optimizer.step()  # apply the gradients to weights

    # Finds the index of the maximum (unnormalized) score for every node and that's the class prediction for that node.
    # Compare those to the true (ground truth) labels and find the fraction of correct predictions -> accuracy metric.
    class_predictions = torch.argmax(nodes_unnormalized_scores, dim=-1)
    accuracy = torch.sum(torch.eq(class_predictions, gt_node_labels).long()).item() / len(gt_node_labels)

    #
    # Logging
    #

    if phase == LoopPhase.TRAIN:
        # Log metrics
        if config['enable_tensorboard']:
            writer.add_scalar('training_loss', loss.item(), epoch)
            writer.add_scalar('training_acc', accuracy, epoch)

        # Save model checkpoint
        if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0:
            ckpt_model_name = f"gat_ckpt_epoch_{epoch + 1}.pth"
            config['test_acc'] = -1
            torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    elif phase == LoopPhase.VAL:
        # Log metrics
        if config['enable_tensorboard']:
            writer.add_scalar('val_loss', loss.item(), epoch)
            writer.add_scalar('val_acc', accuracy, epoch)

        # Log to console
        if config['console_log_freq'] is not None and epoch % config['console_log_freq'] == 0:
            print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] | epoch={epoch + 1} | val acc={accuracy}')

        # The "patience" logic - should we break out from the training loop? As long as either the validation acc
        # keeps going up or the val loss keeps going down we won't stop.
        if accuracy > BEST_VAL_ACC or loss.item() < BEST_VAL_LOSS:
            BEST_VAL_ACC = max(accuracy, BEST_VAL_ACC)  # keep track of the best validation accuracy so far
            BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)
            PATIENCE_CNT = 0  # reset the counter every time we encounter a new best accuracy
        else:
            PATIENCE_CNT += 1  # otherwise keep counting

        if PATIENCE_CNT >= patience_period:
            raise Exception('Stopping the training, the universe has no more patience for this training.')

    else:
        return accuracy  # in the case of the test phase we just report back the test accuracy
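The worked cross-entropy example from the comments above can be verified directly, using the same vector V and correct class 3:

import torch
import torch.nn.functional as F

V = torch.tensor([[-1.393, 3.0765, -2.4445, 9.6219, 2.1658, -5.5243, -4.6247]])
target = torch.tensor([3])  # the correct class

print(F.softmax(V, dim=-1)[0, 3].item())  # ~0.99797
print(F.cross_entropy(V, target).item())  # -log(0.99797) ~ 0.00203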
def train_vanilla_gan(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU

    # Prepare MNIST data loader (it will download MNIST the first time you run it)
    mnist_data_loader = utils.get_mnist_data_loader(training_config['batch_size'])

    # Fetch feed-forward nets (place them on GPU if present) and optimizers which will tweak their weights
    discriminator_net, generator_net = utils.get_gan(device, GANType.VANILLA.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    # 1s will configure BCELoss into -log(x) whereas 0s will configure it to -log(1-x)
    # So that means we can effectively use binary cross-entropy loss to achieve adversarial loss!
    adversarial_loss = nn.BCELoss()
    real_images_gt = torch.ones((training_config['batch_size'], 1), device=device)
    fake_images_gt = torch.zeros((training_config['batch_size'], 1), device=device)

    # For logging purposes
    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)  # track G's quality during training
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()  # start measuring time

    # GAN training loop, it's always smart to first train the discriminator so as to avoid mode collapse!
    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
            real_images = real_images.to(device)  # place imagery on GPU (if present)

            #
            # Train discriminator: maximize V = log(D(x)) + log(1-D(G(z))) or equivalently minimize -V
            # Note: D = discriminator, x = real images, G = generator, z = latent Gaussian vectors, G(z) = fake images
            #

            # Zero out .grad variables in the discriminator network (otherwise we would have corrupt results)
            discriminator_opt.zero_grad()

            # -log(D(x)) <- we minimize this by making D(x)/discriminator_net(real_images) as close to 1 as possible
            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt)

            # G(z) | G == generator_net and z == utils.get_gaussian_latent_batch(batch_size, device)
            fake_images = generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device))
            # D(G(z)), we call detach() so that we don't calculate gradients for the generator during backward()
            fake_images_predictions = discriminator_net(fake_images.detach())
            # -log(1 - D(G(z))) <- we minimize this by making D(G(z)) as close to 0 as possible
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()  # this will populate .grad vars in the discriminator net
            discriminator_opt.step()  # perform D weights update according to optimizer's strategy

            #
            # Train generator: minimize V1 = log(1-D(G(z))) or equivalently maximize V2 = log(D(G(z))) (or min of -V2)
            # The original expression (V1) had problems with diminishing gradients for G when D is too good.
            #

            # If you want to cause mode collapse, probably the easiest way to do that would be to add "for i in range(n)"
            # here (simply train G more frequently than D), n = 10 worked for me, other values will also work - experiment.

            # Zero out .grad variables in the generator network (otherwise we would have corrupt results)
            generator_opt.zero_grad()

            # D(G(z)) (see above for explanations)
            generated_images_predictions = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))
            # By placing real_images_gt here we minimize -log(D(G(z))) which happens when D approaches 1
            # i.e. we're tricking D into thinking that these generated images are real!
            generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

            generator_loss.backward()  # this will populate .grad vars in the G net (also in D but we won't use those)
            generator_opt.step()  # perform G weights update according to optimizer's strategy

            #
            # Logging and checkpoint creation
            #

            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars('losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, len(mnist_data_loader) * epoch + batch_idx + 1)

                # Save debug imagery to tensorboard also (some redundancy but it may be more beginner-friendly)
                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(log_generated_images_resized, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(mnist_data_loader)}]')

            # Save intermediate generator images (more convenient like this than through tensorboard)
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(scale_factor=2.5, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Save the latest generator in the binaries directory
    torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
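The remark above that targets of 1 configure BCELoss into -log(x) while targets of 0 configure it into -log(1-x) is easy to check in isolation:

import torch
import torch.nn as nn

bce = nn.BCELoss()
x = torch.tensor([0.9])  # a discriminator-like output in (0, 1)

print(bce(x, torch.ones(1)).item())   # -log(0.9) ~ 0.105
print(bce(x, torch.zeros(1)).item())  # -log(0.1) ~ 2.303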
def train_gan(training_config):
    writer = SummaryWriter()
    device = torch.device("cpu")

    # Download MNIST dataset into the data directory
    mnist_data_loader = utils.get_mnist_data_loader(training_config['batch_size'])

    discriminator_net, generator_net = utils.get_gan(device, GANType.CLASSIC.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    adversarial_loss = nn.BCELoss()
    real_image_gt = torch.ones((training_config['batch_size'], 1), device=device)
    fake_image_gt = torch.zeros((training_config['batch_size'], 1), device=device)

    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()
    utils.print_training_info_to_console(training_config)

    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
            real_images = real_images.to(device)

            # Train discriminator
            discriminator_opt.zero_grad()

            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_image_gt)

            fake_images = generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device))
            fake_images_predictions = discriminator_net(fake_images.detach())
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_image_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()
            discriminator_opt.step()

            # Train generator
            generator_opt.zero_grad()

            generated_images_prediction = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))
            generator_loss = adversarial_loss(generated_images_prediction, real_image_gt)

            generator_loss.backward()
            generator_opt.step()

            # Logging and checkpoint creation
            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars('Losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, len(mnist_data_loader) * epoch + batch_idx + 1)

                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(log_generated_images_resized, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(mnist_data_loader)}]')

            # Save intermediate generator images
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"Classic_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Save the latest generator in the binaries directory
    torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
def train_gat_ppi(config):
    """
    Very similar to Cora's training script. The main differences are:
    1. Using dataloaders since we're dealing with an inductive setting - multiple graphs per batch
    2. Doing multi-class classification (BCEWithLogitsLoss) and reporting micro-F1 instead of accuracy
    3. Model architecture and hyperparams are a bit different (as reported in the GAT paper)

    """
    global BEST_VAL_PERF, BEST_VAL_LOSS

    # Checking whether you have a strong GPU. Since PPI training requires almost 8 GBs of VRAM
    # I've added the option to force the use of CPU even though you have a GPU on your system (but it's too weak).
    device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu")

    # Step 1: prepare the data loaders
    data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        layer_type=config['layer_type'],
        log_attention_weights=False  # no need to store attentions, used only in playground.py for visualizations
    ).to(device)

    # Step 3: Prepare other training related utilities (loss & optimizer and decorator function)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        config['patience_period'],
        time.time())

    BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT = [0, 0, 0]  # reset vars used for early stopping

    # Step 4: Start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5: Potentially test your model
    # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you
    # report your final loss and micro-F1 on the test dataset. Friends don't let friends overfit to the test data. <3
    if config['should_test']:
        micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')
    else:
        config['test_perf'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(
        utils.get_training_state(config, gat),
        os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['dataset_name'])))
def main_loop(phase, data_loader, epoch=0):
    global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT, writer

    # Certain modules behave differently depending on whether we're training the model or not.
    # e.g. nn.Dropout - we only want to drop model weights during the training.
    if phase == LoopPhase.TRAIN:
        gat.train()
    else:
        gat.eval()

    # Iterate over batches of graph data (2 graphs per batch was used in the original paper for the PPI dataset).
    # We merge them into a single graph with 2 connected components, that's the main idea. After that
    # implementation #3 is agnostic to the fact that those are multiple and not a single graph!
    for batch_idx, (node_features, gt_node_labels, edge_index) in enumerate(data_loader):
        # Push the batch onto GPU - note that PPI is too big to load the whole dataset onto a normal GPU,
        # it takes almost 8 GBs of VRAM to train it on a GPU
        edge_index = edge_index.to(device)
        node_features = node_features.to(device)
        gt_node_labels = gt_node_labels.to(device)

        # I pack data into tuples because GAT uses nn.Sequential which expects this format
        graph_data = (node_features, edge_index)

        # Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
        # shape = (N, C) where N is the number of nodes in the batch and C is the number of classes (121 for PPI)
        # GAT imp #3 is agnostic to the fact that we actually have multiple graphs
        # (it sees a single graph with multiple connected components)
        nodes_unnormalized_scores = gat(graph_data)[0]

        # Example: because PPI has 121 labels let's make a simple toy example that will show how the loss works.
        # Let's say we have 3 labels instead and a single node's unnormalized (raw GAT output) scores are [-3, 0, 3]
        # What this loss will do is first apply a sigmoid and so we'll end up with: [0.048, 0.5, 0.95]
        # Next it will apply a binary cross entropy across all of these and find the average, and that's it!
        # So if the true classes were [0, 0, 1] the loss would be (-log(1-0.048) + -log(1-0.5) + -log(0.95))/3.
        # You can see that the logarithm takes 2 forms depending on whether the true label is 0 or 1,
        # either -log(1-x) or -log(x) respectively. Easy-peasy. <3
        loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

        if phase == LoopPhase.TRAIN:
            optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
            loss.backward()  # compute the gradients for every trainable weight in the computational graph
            optimizer.step()  # apply the gradients to weights

        # Calculate the main metric - micro F1
        # Convert unnormalized scores into predictions. Explanation:
        # If the unnormalized score is bigger than 0 that means that sigmoid would have a value higher than 0.5
        # (by sigmoid's definition) and thus we have predicted 1 for that label, otherwise we have predicted 0.
        pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()
        gt = gt_node_labels.cpu().numpy()
        micro_f1 = f1_score(gt, pred, average='micro')

        #
        # Logging
        #

        global_step = len(data_loader) * epoch + batch_idx
        if phase == LoopPhase.TRAIN:
            # Log metrics
            if config['enable_tensorboard']:
                writer.add_scalar('training_loss', loss.item(), global_step)
                writer.add_scalar('training_micro_f1', micro_f1, global_step)

            # Log to console
            if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                      f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.')

            # Save model checkpoint
            if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                config['test_perf'] = -1  # test perf not calculated yet, note: perf means the main metric (micro-F1) here
                torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

        elif phase == LoopPhase.VAL:
            # Log metrics
            if config['enable_tensorboard']:
                writer.add_scalar('val_loss', loss.item(), global_step)
                writer.add_scalar('val_micro_f1', micro_f1, global_step)

            # Log to console
            if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                print(f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                      f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}')

            # The "patience" logic - should we break out from the training loop? As long as either the validation
            # micro-F1 keeps going up or the val loss keeps going down we won't stop.
            if micro_f1 > BEST_VAL_PERF or loss.item() < BEST_VAL_LOSS:
                BEST_VAL_PERF = max(micro_f1, BEST_VAL_PERF)  # keep track of the best validation micro_f1 so far
                BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)  # and the minimal loss
                PATIENCE_CNT = 0  # reset the counter every time we encounter a new best micro_f1
            else:
                PATIENCE_CNT += 1  # otherwise keep counting

            if PATIENCE_CNT >= patience_period:
                raise Exception('Stopping the training, the universe has no more patience for this training.')

        else:
            return micro_f1  # in the case of the test phase we just report back the test micro_f1
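The 3-label toy example from the comments above, computed with BCEWithLogitsLoss (the loss that train_gat_ppi wires in as sigmoid_cross_entropy_loss):

import torch
import torch.nn as nn

scores = torch.tensor([[-3.0, 0.0, 3.0]])  # raw (unnormalized) GAT-style output
labels = torch.tensor([[0.0, 0.0, 1.0]])   # the true classes

# sigmoid(scores) = [0.048, 0.5, 0.952], so the loss is
# (-log(1-0.048) + -log(1-0.5) + -log(0.952)) / 3 ~ 0.263
print(nn.BCEWithLogitsLoss(reduction='mean')(scores, labels).item())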
device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu") # Step 1: prepare the data loaders data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device) # Step 2: prepare the model gat = GAT( num_of_layers=config['num_of_layers'], num_heads_per_layer=config['num_heads_per_layer'], num_features_per_layer=config['num_features_per_layer'], add_skip_connection=config['add_skip_connection'], bias=config['bias'], dropout=config['dropout'], layer_type=config['layer_type'], log_attention_weights=False # no need to store attentions, used only in playground.py for visualizations ).to(device) # Step 3: Prepare other training related utilities (loss & optimizer and decorator function) loss_fn = nn.BCEWithLogitsLoss(reduction='mean') optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay']) # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops main_loop = get_main_loop( config, gat, loss_fn, optimizer, config['patience_period'], time.time()) BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT = [0, 0, 0] # reset vars used for early stopping # Step 4: Start the training procedure for epoch in range(config['num_of_epochs']): # Training loop main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch) # Validation loop with torch.no_grad(): try: main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch) except Exception as e: # "patience has run out" exception :O print(str(e)) break # break out from the training loop # Step 5: Potentially test your model # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you # report your final loss and micro-F1 on the test dataset. Friends don't let friends overfit to the test data. <3 if config['should_test']: micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test) config['test_perf'] = micro_f1 print('*' * 50) print(f'Test micro-F1 = {micro_f1}') else: config['test_perf'] = -1 # Save the latest GAT in the binaries directory torch.save( utils.get_training_state(config, gat), os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['dataset_name'])) ) def get_training_args(): parser = argparse.ArgumentParser() # Training related parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=200) parser.add_argument("--patience_period", type=int, help="number of epochs with no improvement on val before terminating", default=100) parser.add_argument("--lr", type=float, help="model learning rate", default=5e-3) parser.add_argument("--weight_decay", type=float, help="L2 regularization on model weights", default=0) parser.add_argument("--should_test", action='store_true', help='should test the model on the test dataset? (no by default)') parser.add_argument("--force_cpu", action='store_true', help='use CPU if your GPU is too small (no by default)') # Dataset related (note: we need the dataset name for metadata and related stuff, and not for picking the dataset) parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='dataset to use for training', default=DatasetType.PPI.name) parser.add_argument("--batch_size", type=int, help='number of graphs in a batch', default=2) parser.add_argument("--should_visualize", action='store_true', help='should visualize the dataset? 
(no by default)') # Logging/debugging/checkpoint related (helps a lot with experimentation) parser.add_argument("--enable_tensorboard", action='store_true', help="enable tensorboard logging (no by default)") parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq (None for no logging)", default=10) parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq (None for no logging)", default=5) args = parser.parse_args() # I'm leaving the hyperparam values as reported in the paper, but I experimented a bit and the comments suggest # how you can make GAT achieve an even higher micro-F1 or make it smaller gat_config = { # GNNs, contrary to CNNs, are often shallow (it ultimately depends on the graph properties) "num_of_layers": 3, # PPI has got 42% of nodes with all 0 features - that's why 3 layers are useful "num_heads_per_layer": [4, 4, 6], # other values may give even better results from the reported ones "num_features_per_layer": [PPI_NUM_INPUT_FEATURES, 256, 256, PPI_NUM_CLASSES], # 64 would also give ~0.975 uF1! "add_skip_connection": True, # skip connection is very important! (keep it otherwise micro-F1 is almost 0) "bias": True, # bias doesn't matter that much "dropout": 0.0, # dropout hurts the performance (best to keep it at 0) "layer_type": LayerType.IMP3 # the only implementation that supports the inductive setting } # Wrapping training configuration into a dictionary training_config = dict() for arg in vars(args): training_config[arg] = getattr(args, arg) training_config['ppi_load_test_only'] = False # load both train/val/test data loaders (don't change it) # Add additional config information training_config.update(gat_config) return training_config if __name__ == '__main__': # Train the graph attention network (GAT) train_gat_ppi(get_training_args())
def train_val_loop(is_train, token_ids_loader, epoch):
    global num_of_trg_tokens_processed, global_train_step, global_val_step, writer

    if is_train:
        baseline_transformer.train()
    else:
        baseline_transformer.eval()

    device = next(baseline_transformer.parameters()).device

    #
    # Main loop - start of the CORE PART
    #
    for batch_idx, token_ids_batch in enumerate(token_ids_loader):
        src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(token_ids_batch)
        src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device)

        # log because the KL loss expects log probabilities (just an implementation detail)
        predicted_log_distributions = baseline_transformer(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask)
        smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt)  # these are regular probabilities

        if is_train:
            custom_lr_optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph

        loss = kl_div_loss(predicted_log_distributions, smooth_target_distributions)

        if is_train:
            loss.backward()  # compute the gradients for every trainable weight in the computational graph
            custom_lr_optimizer.step()  # apply the gradients to weights

        # End of CORE PART

        #
        # Logging and metrics
        #

        if is_train:
            global_train_step += 1
            num_of_trg_tokens_processed += num_trg_tokens

            if training_config['enable_tensorboard']:
                writer.add_scalar('training_loss', loss.item(), global_train_step)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'Transformer training: time elapsed= {(time.time() - time_start):.2f} [s] '
                      f'| epoch={epoch + 1} | batch= {batch_idx + 1} '
                      f'| target tokens/batch= {num_of_trg_tokens_processed / training_config["console_log_freq"]}')

                num_of_trg_tokens_processed = 0

            # Save model checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"transformer_ckpt_epoch_{epoch + 1}.pth"
                torch.save(utils.get_training_state(training_config, baseline_transformer), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))
        else:
            global_val_step += 1

            if training_config['enable_tensorboard']:
                writer.add_scalar('val_loss', loss.item(), global_val_step)
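The comment above notes that the model outputs log distributions because the KL loss expects log probabilities, while the label-smoothed targets stay as regular probabilities. The same contract, shown in isolation (target_probs is just a stand-in for the smooth target distributions):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5)                           # raw decoder-like outputs
target_probs = F.softmax(torch.randn(2, 5), dim=-1)  # stand-in for smooth target distributions

# input must be log-probabilities, the target stays as regular probabilities
loss = F.kl_div(F.log_softmax(logits, dim=-1), target_probs, reduction='batchmean')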
def train_transformer(training_config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1: Prepare data loaders
    train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(
        training_config['dataset_path'],
        training_config['language_direction'],
        training_config['dataset_name'],
        training_config['batch_size'],
        device)

    pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN]  # pad token id is the same for the target as well
    src_vocab_size = len(src_field_processor.vocab)
    trg_vocab_size = len(trg_field_processor.vocab)

    # Step 2: Prepare the model (original transformer) and push to GPU
    baseline_transformer = Transformer(
        model_dimension=BASELINE_MODEL_DIMENSION,
        src_vocab_size=src_vocab_size,
        trg_vocab_size=trg_vocab_size,
        number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
        number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
        dropout_probability=BASELINE_MODEL_DROPOUT_PROB
    ).to(device)

    # Step 3: Prepare other training related utilities
    kl_div_loss = nn.KLDivLoss(reduction='batchmean')  # gives a better BLEU score than "mean"

    # Makes smooth target distributions as opposed to conventional one-hot distributions
    # My feeling is that this is a really dumb and arbitrary heuristic but time will tell.
    label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size, device)

    # Check out playground.py for an intuitive visualization of how the LR changes with time/training steps, easy stuff.
    custom_lr_optimizer = CustomLRAdamOptimizer(
        Adam(baseline_transformer.parameters(), betas=(0.9, 0.98), eps=1e-9),
        BASELINE_MODEL_DIMENSION,
        training_config['num_warmup_steps'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    train_val_loop = get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time.time())

    # Step 4: Start the training
    for epoch in range(training_config['num_of_epochs']):
        # Training loop
        train_val_loop(is_train=True, token_ids_loader=train_token_ids_loader, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            train_val_loop(is_train=False, token_ids_loader=val_token_ids_loader, epoch=epoch)

            bleu_score = utils.calculate_bleu_score(baseline_transformer, val_token_ids_loader, trg_field_processor)
            if training_config['enable_tensorboard']:
                writer.add_scalar('bleu_score', bleu_score, epoch)

    # Save the latest transformer in the binaries directory
    torch.save(utils.get_training_state(training_config, baseline_transformer), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
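CustomLRAdamOptimizer is only referenced above; assuming it follows the learning-rate schedule from the original transformer paper (which the playground.py pointer suggests), the per-step learning rate would look roughly like this sketch:

def transformer_lr(step, model_dimension, num_warmup_steps):
    # lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5):
    # linear warmup for the first num_warmup_steps, then inverse square-root decay
    step = max(step, 1)  # avoid division by zero on the very first step
    return model_dimension ** -0.5 * min(step ** -0.5, step * num_warmup_steps ** -1.5)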
def train_gat_ppi(config):
    # Global vars used for early stopping: the best validation F1 value and the best validation loss
    global BEST_VAL_MICRO_F1, BEST_VAL_LOSS

    device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu")

    # Step 1: load the data
    data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT_ppi(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        log_attention_weights=False
    ).to(device)

    # Step 3: prepare the training utilities
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # Returns the main loop function, which improves code reuse
    main_loop = get_main_loop(
        config=config,
        gat=gat,
        sigmoid_cross_entropy_loss=loss_fn,
        optimizer=optimizer,
        patience_period=config['patience_period'],
        time_start=time.time())

    BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT = [0, 0, 0]  # reset

    # Step 4: start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch)
            except Exception as e:
                print(str(e))
                break

    # Step 5: test
    if config['should_test']:
        micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')
    else:
        config['test_perf'] = -1

    # Save the latest GAT model binary
    torch.save(
        utils.get_training_state(config, gat),
        os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['dataset_name'])))
def main_loop(phase, data_loader, epoch=0):
    global BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT, writer

    if phase == LoopPhase.TRAIN:
        gat.train()
    else:
        gat.eval()

    # Iterate over batches of graph data. The original paper used 2 graphs per batch; here the 2 graphs
    # are merged into a single graph, which is equivalent to one graph with 2 connected components.
    for batch_idx, (node_features, gt_node_labels, edge_index) in enumerate(data_loader):
        edge_index = edge_index.to(device)
        node_features = node_features.to(device)
        gt_node_labels = gt_node_labels.to(device)

        graph_data = (node_features, edge_index)  # pack the data

        # The final output scores, before the sigmoid. Since each label is a binary (0 or 1)
        # classification problem we use a sigmoid-based loss.
        nodes_unnormalized_scores = gat(graph_data)[0]
        loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

        if phase == LoopPhase.TRAIN:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Compute micro-F1: if a score is greater than 0 then its sigmoid is greater than 0.5, so we predict 1
        pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()
        gt = gt_node_labels.cpu().numpy()
        micro_f1 = f1_score(gt, pred, average='micro')

        # Logging
        global_step = len(data_loader) * epoch + batch_idx
        if phase == LoopPhase.TRAIN:
            # Log metrics
            if config['enable_tensorboard']:
                writer.add_scalar('training_loss', loss.item(), global_step)
                writer.add_scalar('training_micro_f1', micro_f1, global_step)

            # Log to console (once per logging period, reporting the current batch of this epoch)
            if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                      f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.')

            # Save a checkpoint
            if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                config['test_perf'] = -1  # test performance not measured yet
                torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

        elif phase == LoopPhase.VAL:
            if config['enable_tensorboard']:
                writer.add_scalar('val_loss', loss.item(), global_step)
                writer.add_scalar('val_micro_f1', micro_f1, global_step)

            if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                print(f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                      f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}')

            # Keep track of the best params (early stopping logic)
            if micro_f1 > BEST_VAL_MICRO_F1 or loss.item() < BEST_VAL_LOSS:
                BEST_VAL_MICRO_F1 = max(micro_f1, BEST_VAL_MICRO_F1)
                BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)
                PATIENCE_CNT = 0
            else:
                PATIENCE_CNT += 1

            if PATIENCE_CNT >= patience_period:
                raise Exception('Stopping the training, the universe has no more patience for this training.')

        else:
            return micro_f1  # pure testing, just return the F1 value