Code example #1
    def maybe_log(self):
        num_steps = self.env.get_total_steps()

        if self.log_freq is not None and num_steps > 0 and num_steps % self.log_freq == 0:
            self.tensorboard_writer.add_scalar('Epsilon',
                                               self.dqn.epsilon_value(),
                                               num_steps)
            if len(self.huber_loss) > 0:
                self.tensorboard_writer.add_scalar('Huber loss',
                                                   np.mean(self.huber_loss),
                                                   num_steps)
            self.tensorboard_writer.add_scalar(
                'FPS', num_steps / (time.time() - self.start_time), num_steps)

            self.huber_loss = []  # clear the loss values and start collecting them again

        # Periodically save DQN models
        if self.checkpoint_freq is not None and num_steps > 0 and num_steps % self.checkpoint_freq == 0:
            ckpt_model_name = f'dqn_{self.config["env_id"]}_ckpt_steps_{num_steps}.pth'
            torch.save(utils.get_training_state(self.config, self.dqn),
                       os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

        # Log the gradients
        if self.grads_log_freq is not None and self.learner_cnt > 0 and self.learner_cnt % self.grads_log_freq == 0:
            total_grad_l2_norm = 0

            for name, weight_or_bias_parameters in self.dqn.named_parameters():
                grad_l2_norm = weight_or_bias_parameters.grad.data.norm(
                    p=2).item()
                self.tensorboard_writer.add_scalar(f'grad_norms/{name}',
                                                   grad_l2_norm,
                                                   self.learner_cnt)
                total_grad_l2_norm += grad_l2_norm**2

            # As if we had concatenated all of the params into a single vector and taken its L2 norm
            total_grad_l2_norm = total_grad_l2_norm**(1 / 2)
            self.tensorboard_writer.add_scalar('grad_norms/total',
                                               total_grad_l2_norm,
                                               self.learner_cnt)
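
A quick numerical check of the comment above: summing squared per-parameter norms and then taking a square root is the same as taking the L2 norm of all gradients concatenated into one vector. A minimal standalone sketch (plain tensors stand in for the .grad fields):

import torch

# Two fake parameter gradients, as if read from named_parameters()
grads = [torch.randn(3, 4), torch.randn(5)]

# Sum of squared per-parameter L2 norms, followed by a square root...
total = sum(g.norm(p=2).item() ** 2 for g in grads) ** 0.5

# ...equals the L2 norm of everything flattened into a single vector
concatenated = torch.cat([g.flatten() for g in grads])
assert abs(total - concatenated.norm(p=2).item()) < 1e-4
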
Code example #2
def train_dqn(config):
    env = utils.get_env_wrapper(config['env_id'])
    replay_buffer = ReplayBuffer(
        config['replay_buffer_size'],
        crash_if_no_mem=not config['dont_crash_if_no_mem'])  # note the negation - the config flag is phrased the other way around

    utils.set_random_seeds(env, config['seed'])

    linear_schedule = utils.LinearSchedule(config['epsilon_start_value'],
                                           config['epsilon_end_value'],
                                           config['epsilon_duration'])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dqn = DQN(env,
              number_of_actions=env.action_space.n,
              epsilon_schedule=linear_schedule).to(device)
    target_dqn = DQN(env, number_of_actions=env.action_space.n).to(device)

    # Don't get confused by the actor-learner terminology, DQN is not an actor-critic method, but conceptually
    # we can split the learning process into collecting experience/acting in the env and learning from that experience
    actor_learner = ActorLearner(config, env, replay_buffer, dqn, target_dqn,
                                 env.reset())

    while actor_learner.get_number_of_env_steps() < config['num_of_training_steps']:

        num_env_steps = actor_learner.get_number_of_env_steps()
        if config['console_log_freq'] is not None and num_env_steps % config[
                'console_log_freq'] == 0:
            actor_learner.log_to_console()

        actor_learner.collect_experience()

        if num_env_steps > config['num_warmup_steps']:
            actor_learner.learn_from_experience()

    torch.save(  # save the best DQN model overall (gave the highest reward in an episode)
        utils.get_training_state(config, actor_learner.best_dqn_model),
        os.path.join(BINARIES_PATH,
                     utils.get_available_binary_name(config['env_id'])))
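
utils.LinearSchedule is not shown in these snippets; judging from how it is constructed above (start value, end value, duration), it is presumably a simple linear interpolation. A minimal sketch consistent with that call site (the repo's actual implementation may differ):

class LinearSchedule:
    def __init__(self, start_value, end_value, duration):
        self.start_value = start_value
        self.end_value = end_value
        self.duration = duration

    def __call__(self, num_steps):
        # Linearly anneal from start_value to end_value over `duration` steps, then stay flat
        progress = min(num_steps / self.duration, 1.0)
        return self.start_value + progress * (self.end_value - self.start_value)
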
Code example #3
def train_gat(config):
    global BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT  # PATIENCE_CNT must be global too, otherwise the reset below only rebinds a local

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1: load the graph data
    node_features, node_labels, edge_index, train_indices, val_indices, test_indices = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        layer_type=config['layer_type'],
        log_attention_weights=False  # no need to store attentions, used only in playground.py while visualizing
    ).to(device)

    # Step 3: Prepare other training related utilities (loss & optimizer and decorator function)
    loss_fn = nn.CrossEntropyLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        node_features,
        node_labels,
        edge_index,
        train_indices,
        val_indices,
        test_indices,
        config['patience_period'],
        time.time())

    BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT = [0, float('inf'), 0]  # reset vars used for early stopping (val loss starts at +inf so the first epoch always improves on it)

    # Step 4: Start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5: Potentially test your model
    # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you
    # report your final loss and accuracy on the test dataset. Friends don't let friends overfit to the test data. <3
    if config['should_test']:
        test_acc = main_loop(phase=LoopPhase.TEST)
        config['test_acc'] = test_acc
        print(f'Test accuracy = {test_acc}')
    else:
        config['test_acc'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(utils.get_training_state(config, gat), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
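
LoopPhase, used to switch between the train/val/test branches here and in the snippets below, is just an enum. A minimal sketch consistent with its usage (the actual definition lives elsewhere in the repo):

from enum import Enum

class LoopPhase(Enum):
    TRAIN = 0
    VAL = 1
    TEST = 2
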
Code example #4
    def main_loop(phase, epoch=0):
        global BEST_VAL_ACC, BEST_VAL_LOSS, PATIENCE_CNT, writer

        # Certain modules behave differently depending on whether we're training the model or not.
        # e.g. nn.Dropout - we only want to apply dropout during training.
        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        node_indices = get_node_indices(phase)
        gt_node_labels = get_node_labels(phase)  # gt stands for ground truth

        # Do a forward pass and extract only the relevant node scores (train/val or test ones)
        # Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
        # shape = (N, C) where N is the number of nodes in the split (train/val/test) and C is the number of classes
        nodes_unnormalized_scores = gat(graph_data)[0].index_select(node_dim, node_indices)

        # Example: let's take an output for a single node on Cora - it's a vector of size 7 and it contains unnormalized
        # scores like: V = [-1.393,  3.0765, -2.4445,  9.6219,  2.1658, -5.5243, -4.6247]
        # What PyTorch's cross entropy loss does is for every such vector it first applies a softmax, and so we'll
        # have the V transformed into: [1.6421e-05, 1.4338e-03, 5.7378e-06, 0.99797, 5.7673e-04, 2.6376e-07, 6.4848e-07]
        # secondly, whatever the correct class is (say it's 3), it will then take the element at position 3,
        # 0.99797 in this case, and the loss will be -log(0.99797). It does this for every node and applies a mean.
        # You can see that as the probability of the correct class for most nodes approaches 1 we get to 0 loss! <3
        loss = cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

        if phase == LoopPhase.TRAIN:
            optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
            loss.backward()  # compute the gradients for every trainable weight in the computational graph
            optimizer.step()  # apply the gradients to weights

        # Finds the index of maximum (unnormalized) score for every node and that's the class prediction for that node.
        # Compare those to true (ground truth) labels and find the fraction of correct predictions -> accuracy metric.
        class_predictions = torch.argmax(nodes_unnormalized_scores, dim=-1)
        accuracy = torch.sum(torch.eq(class_predictions, gt_node_labels).long()).item() / len(gt_node_labels)

        #
        # Logging
        #

        if phase == LoopPhase.TRAIN:
            # Log metrics
            if config['enable_tensorboard']:
                writer.add_scalar('training_loss', loss.item(), epoch)
                writer.add_scalar('training_acc', accuracy, epoch)

            # Save model checkpoint
            if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0:
                ckpt_model_name = f"gat_ckpt_epoch_{epoch + 1}.pth"
                config['test_acc'] = -1
                torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

        elif phase == LoopPhase.VAL:
            # Log metrics
            if config['enable_tensorboard']:
                writer.add_scalar('val_loss', loss.item(), epoch)
                writer.add_scalar('val_acc', accuracy, epoch)

            # Log to console
            if config['console_log_freq'] is not None and epoch % config['console_log_freq'] == 0:
                print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] | epoch={epoch + 1} | val acc={accuracy}')

            # The "patience" logic - should we break out from the training loop? If either validation acc keeps going up
            # or the val loss keeps going down we won't stop
            if accuracy > BEST_VAL_ACC or loss.item() < BEST_VAL_LOSS:
                BEST_VAL_ACC = max(accuracy, BEST_VAL_ACC)  # keep track of the best validation accuracy so far
                BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)
                PATIENCE_CNT = 0  # reset the counter every time we encounter new best accuracy
            else:
                PATIENCE_CNT += 1  # otherwise keep counting

            if PATIENCE_CNT >= patience_period:
                raise Exception('Stopping the training, the universe has no more patience for this training.')

        else:
            return accuracy  # in the case of test phase we just report back the test accuracy
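
The long cross-entropy comment above is easy to verify numerically: PyTorch's cross entropy is a softmax followed by the negative log of the probability assigned to the correct class. A standalone check using the exact vector V from the comment:

import torch
import torch.nn.functional as F

scores = torch.tensor([[-1.393, 3.0765, -2.4445, 9.6219, 2.1658, -5.5243, -4.6247]])
label = torch.tensor([3])  # the correct class from the comment above

# cross_entropy == -log(softmax(scores)[correct class]), averaged over the batch
manual = -torch.log(F.softmax(scores, dim=-1)[0, 3])
assert torch.isclose(F.cross_entropy(scores, label), manual)
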
Code example #5
def train_vanilla_gan(training_config):
    writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU

    # Prepare MNIST data loader (it will download MNIST the first time you run it)
    mnist_data_loader = utils.get_mnist_data_loader(training_config['batch_size'])

    # Fetch feed-forward nets (place them on GPU if present) and optimizers which will tweak their weights
    discriminator_net, generator_net = utils.get_gan(device, GANType.VANILLA.name)
    discriminator_opt, generator_opt = utils.get_optimizers(discriminator_net, generator_net)

    # 1s will configure BCELoss into -log(x) whereas 0s will configure it to -log(1-x)
    # So that means we can effectively use binary cross-entropy loss to implement the adversarial loss!
    adversarial_loss = nn.BCELoss()
    real_images_gt = torch.ones((training_config['batch_size'], 1), device=device)
    fake_images_gt = torch.zeros((training_config['batch_size'], 1), device=device)

    # For logging purposes
    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)  # Track G's quality during training
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()  # start measuring time

    # GAN training loop, it's always smart to first train the discriminator so as to avoid mode collapse!
    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):

            real_images = real_images.to(device)  # Place imagery on GPU (if present)

            #
            # Train discriminator: maximize V = log(D(x)) + log(1-D(G(z))) or equivalently minimize -V
            # Note: D = discriminator, x = real images, G = generator, z = latent Gaussian vectors, G(z) = fake images
            #

            # Zero out .grad variables in discriminator network (otherwise we would have corrupt results)
            discriminator_opt.zero_grad()

            # -log(D(x)) <- we minimize this by making D(x)/discriminator_net(real_images) as close to 1 as possible
            real_discriminator_loss = adversarial_loss(discriminator_net(real_images), real_images_gt)

            # G(z) | G == generator_net and z == utils.get_gaussian_latent_batch(batch_size, device)
            fake_images = generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device))
            # D(G(z)), we call detach() so that we don't calculate gradients for the generator during backward()
            fake_images_predictions = discriminator_net(fake_images.detach())
            # -log(1 - D(G(z))) <- we minimize this by making D(G(z)) as close to 0 as possible
            fake_discriminator_loss = adversarial_loss(fake_images_predictions, fake_images_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()  # this will populate .grad vars in the discriminator net
            discriminator_opt.step()  # perform D weights update according to optimizer's strategy

            #
            # Train generator: minimize V1 = log(1-D(G(z))) or equivalently maximize V2 = log(D(G(z))) (or min of -V2)
            # The original expression (V1) had problems with diminishing gradients for G when D is too good.
            #

            # If you want to cause mode collapse, probably the easiest way to do that would be to add "for i in range(n)"
            # here (simply train G more frequently than D); n = 10 worked for me, but other values will also work - experiment.

            # Zero out .grad variables in the generator network (otherwise we would have corrupt results)
            generator_opt.zero_grad()

            # D(G(z)) (see above for explanations)
            generated_images_predictions = discriminator_net(generator_net(utils.get_gaussian_latent_batch(training_config['batch_size'], device)))
            # By placing real_images_gt here we minimize -log(D(G(z))), which goes to 0 as D(G(z)) approaches 1
            # i.e. we're tricking D into thinking that these generated images are real!
            generator_loss = adversarial_loss(generated_images_predictions, real_images_gt)

            generator_loss.backward()  # this will populate .grad vars in the G net (also in D but we won't use those)
            generator_opt.step()  # perform G weights update according to optimizer's strategy

            #
            # Logging and checkpoint creation
            #

            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars('losses/g-and-d', {'g': generator_loss.item(), 'd': discriminator_loss.item()}, len(mnist_data_loader) * epoch + batch_idx + 1)
                # Save debug imagery to tensorboard also (some redundancy but it may be more beginner-friendly)
                if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(scale_factor=2, mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(log_generated_images_resized, nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                        writer.add_image('intermediate generated imagery', intermediate_imagery_grid, len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                print(f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(mnist_data_loader)}]')

            # Save intermediate generator images (more convenient like this than through tensorboard)
            if training_config['debug_imagery_log_freq'] is not None and batch_idx % training_config['debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(scale_factor=2.5, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized, os.path.join(training_config['debug_path'], f'{str(img_cnt).zfill(6)}.jpg'), nrow=int(np.sqrt(ref_batch_size)), normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"vanilla_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    # Save the latest generator in the binaries directory
    torch.save(utils.get_training_state(generator_net, GANType.VANILLA.name), os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
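
The BCELoss comment near the top of this snippet (targets of 1 give -log(x), targets of 0 give -log(1-x)) can be checked directly. A tiny standalone demonstration:

import torch
import torch.nn as nn

bce = nn.BCELoss()
x = torch.tensor([0.9])  # a discriminator output, i.e. a probability

# Target 1 -> loss is -log(x); target 0 -> loss is -log(1 - x)
assert torch.isclose(bce(x, torch.ones(1)), -torch.log(x)[0])
assert torch.isclose(bce(x, torch.zeros(1)), -torch.log(1 - x)[0])
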
Code example #6
def train_gan(training_config):
    writer = SummaryWriter()
    device = torch.device("cpu")

    # Download MNIST dataset in the directory data
    mnist_data_loader = utils.get_mnist_data_loader(
        training_config['batch_size'])

    discriminator_net, generator_net = utils.get_gan(device,
                                                     GANType.CLASSIC.name)
    discriminator_opt, generator_opt = utils.get_optimizers(
        discriminator_net, generator_net)

    adversarial_loss = nn.BCELoss()
    real_image_gt = torch.ones((training_config['batch_size'], 1),
                               device=device)
    fake_image_gt = torch.zeros((training_config['batch_size'], 1),
                                device=device)

    ref_batch_size = 16
    ref_noise_batch = utils.get_gaussian_latent_batch(ref_batch_size, device)
    discriminator_loss_values = []
    generator_loss_values = []
    img_cnt = 0

    ts = time.time()

    utils.print_training_info_to_console(training_config)
    for epoch in range(training_config['num_epochs']):
        for batch_idx, (real_images, _) in enumerate(mnist_data_loader):
            real_images = real_images.to(device)

            # Train discriminator
            discriminator_opt.zero_grad()

            real_discriminator_loss = adversarial_loss(
                discriminator_net(real_images), real_image_gt)

            fake_images = generator_net(
                utils.get_gaussian_latent_batch(training_config['batch_size'],
                                                device))
            fake_images_predictions = discriminator_net(fake_images.detach())
            fake_discriminator_loss = adversarial_loss(fake_images_predictions,
                                                       fake_image_gt)

            discriminator_loss = real_discriminator_loss + fake_discriminator_loss
            discriminator_loss.backward()
            discriminator_opt.step()

            # Train generator
            generator_opt.zero_grad()

            generated_images_prediction = discriminator_net(
                generator_net(
                    utils.get_gaussian_latent_batch(
                        training_config['batch_size'], device)))

            generator_loss = adversarial_loss(generated_images_prediction,
                                              real_image_gt)

            generator_loss.backward()
            generator_opt.step()

            # Logging and checkpoint creation
            generator_loss_values.append(generator_loss.item())
            discriminator_loss_values.append(discriminator_loss.item())

            if training_config['enable_tensorboard']:
                writer.add_scalars(
                    'Losses/g-and-d', {
                        'g': generator_loss.item(),
                        'd': discriminator_loss.item()
                    },
                    len(mnist_data_loader) * epoch + batch_idx + 1)

                if training_config[
                        'debug_imagery_log_freq'] is not None and batch_idx % training_config[
                            'debug_imagery_log_freq'] == 0:
                    with torch.no_grad():
                        log_generated_images = generator_net(ref_noise_batch)
                        log_generated_images_resized = nn.Upsample(
                            scale_factor=2,
                            mode='nearest')(log_generated_images)
                        intermediate_imagery_grid = make_grid(
                            log_generated_images_resized,
                            nrow=int(np.sqrt(ref_batch_size)),
                            normalize=True)
                        writer.add_image(
                            'intermediate generated imagery',
                            intermediate_imagery_grid,
                            len(mnist_data_loader) * epoch + batch_idx + 1)

            if training_config[
                    'console_log_freq'] is not None and batch_idx % training_config[
                        'console_log_freq'] == 0:
                print(
                    f'GAN training: time elapsed = {(time.time() - ts):.2f} [s] | epoch={epoch + 1} | batch= [{batch_idx + 1}/{len(mnist_data_loader)}]'
                )

            # Save intermediate generator images
            if training_config[
                    'debug_imagery_log_freq'] is not None and batch_idx % training_config[
                        'debug_imagery_log_freq'] == 0:
                with torch.no_grad():
                    log_generated_images = generator_net(ref_noise_batch)
                    log_generated_images_resized = nn.Upsample(
                        scale_factor=2, mode='nearest')(log_generated_images)
                    save_image(log_generated_images_resized,
                               os.path.join(training_config['debug_path'],
                                            f'{str(img_cnt).zfill(6)}.jpg'),
                               nrow=int(np.sqrt(ref_batch_size)),
                               normalize=True)
                    img_cnt += 1

            # Save generator checkpoint
            if training_config['checkpoint_freq'] is not None and (
                    epoch + 1
            ) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                ckpt_model_name = f"Classic_ckpt_epoch_{epoch + 1}_batch_{batch_idx + 1}.pth"
                torch.save(
                    utils.get_training_state(generator_net,
                                             GANType.CLASSIC.name),
                    os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

    torch.save(utils.get_training_state(generator_net, GANType.CLASSIC.name),
               os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
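
utils.get_gaussian_latent_batch is not shown in either GAN snippet; for a vanilla GAN it is typically just standard-normal noise of shape (batch size, latent dimension). A sketch under that assumption (the latent dimension of 100 is a common choice, not taken from these snippets):

import torch

LATENT_SPACE_DIM = 100  # assumed; the real constant lives in utils

def get_gaussian_latent_batch(batch_size, device):
    # z ~ N(0, I), one latent vector per image we want to generate
    return torch.randn((batch_size, LATENT_SPACE_DIM), device=device)
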
Code example #7
def train_gat_ppi(config):
    """
    Very similar to Cora's training script. The main differences are:
    1. Using dataloaders since we're dealing with an inductive setting - multiple graphs per batch
    2. Doing multi-class classification (BCEWithLogitsLoss) and reporting micro-F1 instead of accuracy
    3. Model architecture and hyperparams are a bit different (as reported in the GAT paper)

    """
    global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT  # PATIENCE_CNT must be global too, otherwise the reset below only rebinds a local

    # Checking whether you have a strong GPU. Since PPI training requires almost 8 GBs of VRAM
    # I've added the option to force CPU usage even if the GPU on your system is simply too weak.
    device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu")

    # Step 1: prepare the data loaders
    data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        layer_type=config['layer_type'],
        log_attention_weights=False  # no need to store attentions, used only in playground.py for visualizations
    ).to(device)

    # Step 3: Prepare other training related utilities (loss & optimizer and decorator function)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        config['patience_period'],
        time.time())

    BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT = [0, float('inf'), 0]  # reset vars used for early stopping (val loss starts at +inf so the first epoch always improves on it)

    # Step 4: Start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5: Potentially test your model
    # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you
    # report your final loss and micro-F1 on the test dataset. Friends don't let friends overfit to the test data. <3
    if config['should_test']:
        micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')
    else:
        config['test_perf'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(
        utils.get_training_state(config, gat),
        os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['dataset_name']))
    )
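
Micro-F1, mentioned in the docstring as the metric for this multi-label setting, pools every (node, label) decision into one set of binary predictions before computing precision and recall. A toy illustration with scikit-learn:

import numpy as np
from sklearn.metrics import f1_score

# 2 nodes x 3 labels (PPI actually has 121 labels per node)
gt = np.array([[0, 0, 1], [1, 0, 1]])
pred = np.array([[0, 1, 1], [1, 0, 0]])

# Micro averaging counts TP/FP/FN jointly over all node-label pairs:
# here TP=2, FP=1, FN=1, so precision = recall = F1 = 2/3
print(f1_score(gt, pred, average='micro'))  # ~0.667
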
Code example #8
    def main_loop(phase, data_loader, epoch=0):
        global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT, writer

        # Certain modules behave differently depending on whether we're training the model or not.
        # e.g. nn.Dropout - we only want to apply dropout during training.
        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        # Iterate over batches of graph data (2 graphs per batch was used in the original paper for the PPI dataset)
        # We merge them into a single graph with 2 connected components, that's the main idea. After that
        # the implementation #3 is agnostic to the fact that those are multiple and not a single graph!
        for batch_idx, (node_features, gt_node_labels, edge_index) in enumerate(data_loader):
            # Push the batch onto the GPU - note PPI is too big to load the whole dataset onto an average GPU
            # (it takes almost 8 GBs of VRAM to train it on a GPU)
            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            gt_node_labels = gt_node_labels.to(device)

            # I pack data into tuples because GAT uses nn.Sequential which expects this format
            graph_data = (node_features, edge_index)

            # Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
            # shape = (N, C) where N is the number of nodes in the batch and C is the number of classes (121 for PPI)
            # GAT imp #3 is agnostic to the fact that we actually have multiple graphs
            # (it sees a single graph with multiple connected components)
            nodes_unnormalized_scores = gat(graph_data)[0]

            # Example: because PPI has 121 labels let's make a simple toy example that will show how the loss works.
            # Let's say we have 3 labels instead and a single node's unnormalized (raw GAT output) scores are [-3, 0, 3]
            # What this loss will do is first it will apply a sigmoid and so we'll end up with: [0.048, 0.5, 0.95]
            # next it will apply a binary cross entropy across all of these and find the average, and that's it!
            # So if the true classes were [0, 0, 1] the loss would be (-log(1-0.048) + -log(1-0.5) + -log(0.95))/3.
            # You can see that the logarithm takes 2 forms depending on whether the true label is 0 or 1,
            # either -log(1-x) or -log(x) respectively. Easy-peasy. <3
            loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

            if phase == LoopPhase.TRAIN:
                optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
                loss.backward()  # compute the gradients for every trainable weight in the computational graph
                optimizer.step()  # apply the gradients to weights

            # Calculate the main metric - micro F1

            # Convert unnormalized scores into predictions. Explanation:
            # If the unnormalized score is bigger than 0 that means that sigmoid would have a value higher than 0.5
            # (by sigmoid's definition) and thus we have predicted 1 for that label otherwise we have predicted 0.
            pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()
            gt = gt_node_labels.cpu().numpy()
            micro_f1 = f1_score(gt, pred, average='micro')

            #
            # Logging
            #

            global_step = len(data_loader) * epoch + batch_idx
            if phase == LoopPhase.TRAIN:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(), global_step)
                    writer.add_scalar('training_micro_f1', micro_f1, global_step)

                # Log to console
                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.')

                # Save model checkpoint
                if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                    config['test_perf'] = -1  # test perf not calculated yet, note: perf means main metric micro-F1 here
                    torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

            elif phase == LoopPhase.VAL:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_step)
                    writer.add_scalar('val_micro_f1', micro_f1, global_step)

                # Log to console
                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    print(f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}')

                # The "patience" logic - should we break out from the training loop? If either validation micro-F1
                # keeps going up or the val loss keeps going down we won't stop
                if micro_f1 > BEST_VAL_PERF or loss.item() < BEST_VAL_LOSS:
                    BEST_VAL_PERF = max(micro_f1, BEST_VAL_PERF)  # keep track of the best validation micro_f1 so far
                    BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)  # and the minimal loss
                    PATIENCE_CNT = 0  # reset the counter every time we encounter new best micro_f1
                else:
                    PATIENCE_CNT += 1  # otherwise keep counting

                if PATIENCE_CNT >= patience_period:
                    raise Exception('Stopping the training, the universe has no more patience for this training.')

            else:
                return micro_f1  # in the case of test phase we just report back the test micro_f1
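
The toy loss example in the comments above (raw scores [-3, 0, 3] with true classes [0, 0, 1]) can be verified against nn.BCEWithLogitsLoss directly:

import torch
import torch.nn as nn

scores = torch.tensor([-3.0, 0.0, 3.0])
targets = torch.tensor([0.0, 0.0, 1.0])

probs = torch.sigmoid(scores)  # ~[0.047, 0.5, 0.953]
manual = (-torch.log(1 - probs[0]) - torch.log(1 - probs[1]) - torch.log(probs[2])) / 3
assert torch.isclose(nn.BCEWithLogitsLoss(reduction='mean')(scores, targets), manual)
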
Code example #9
    def main_loop(phase, data_loader, epoch=0):
        global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT, writer

        # Certain modules behave differently depending on whether we're training the model or not.
        # e.g. nn.Dropout - we only want to apply dropout during training.
        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        # Iterate over batches of graph data (2 graphs per batch was used in the original paper for the PPI dataset)
        # We merge them into a single graph with 2 connected components, that's the main idea. After that
        # the implementation #3 is agnostic to the fact that those are multiple and not a single graph!
        for batch_idx, (node_features, gt_node_labels, edge_index) in enumerate(data_loader):
            # Push the batch onto the GPU - note PPI is too big to load the whole dataset onto an average GPU
            # (it takes almost 8 GBs of VRAM to train it on a GPU)
            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            gt_node_labels = gt_node_labels.to(device)

            # I pack data into tuples because GAT uses nn.Sequential which expects this format
            graph_data = (node_features, edge_index)

            # Note: [0] just extracts the node_features part of the data (index 1 contains the edge_index)
            # shape = (N, C) where N is the number of nodes in the batch and C is the number of classes (121 for PPI)
            # GAT imp #3 is agnostic to the fact that we actually have multiple graphs
            # (it sees a single graph with multiple connected components)
            nodes_unnormalized_scores = gat(graph_data)[0]

            # Example: because PPI has 121 labels let's make a simple toy example that will show how the loss works.
            # Let's say we have 3 labels instead and a single node's unnormalized (raw GAT output) scores are [-3, 0, 3]
            # What this loss will do is first it will apply a sigmoid and so we'll end up with: [0.048, 0.5, 0.95]
            # next it will apply a binary cross entropy across all of these and find the average, and that's it!
            # So if the true classes were [0, 0, 1] the loss would be (-log(1-0.048) + -log(1-0.5) + -log(0.95))/3.
            # You can see that the logarithm takes 2 forms depending on whether the true label is 0 or 1,
            # either -log(1-x) or -log(x) respectively. Easy-peasy. <3
            loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores, gt_node_labels)

            if phase == LoopPhase.TRAIN:
                optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph (.grad fields)
                loss.backward()  # compute the gradients for every trainable weight in the computational graph
                optimizer.step()  # apply the gradients to weights

            # Calculate the main metric - micro F1

            # Convert unnormalized scores into predictions. Explanation:
            # If the unnormalized score is bigger than 0 that means that sigmoid would have a value higher than 0.5
            # (by sigmoid's definition) and thus we have predicted 1 for that label otherwise we have predicted 0.
            pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()
            gt = gt_node_labels.cpu().numpy()
            micro_f1 = f1_score(gt, pred, average='micro')

            #
            # Logging
            #

            global_step = len(data_loader) * epoch + batch_idx
            if phase == LoopPhase.TRAIN:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(), global_step)
                    writer.add_scalar('training_micro_f1', micro_f1, global_step)

                # Log to console
                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    print(f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.')

                # Save model checkpoint
                if config['checkpoint_freq'] is not None and (epoch + 1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                    config['test_perf'] = -1  # test perf not calculated yet, note: perf means main metric micro-F1 here
                    torch.save(utils.get_training_state(config, gat), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

            elif phase == LoopPhase.VAL:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_step)
                    writer.add_scalar('val_micro_f1', micro_f1, global_step)

                # Log to console
                if config['console_log_freq'] is not None and batch_idx % config['console_log_freq'] == 0:
                    print(f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                          f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}')

                # The "patience" logic - should we break out from the training loop? If either validation micro-F1
                # keeps going up or the val loss keeps going down we won't stop
                if micro_f1 > BEST_VAL_PERF or loss.item() < BEST_VAL_LOSS:
                    BEST_VAL_PERF = max(micro_f1, BEST_VAL_PERF)  # keep track of the best validation micro_f1 so far
                    BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)  # and the minimal loss
                    PATIENCE_CNT = 0  # reset the counter every time we encounter new best micro_f1
                else:
                    PATIENCE_CNT += 1  # otherwise keep counting

                if PATIENCE_CNT >= patience_period:
                    raise Exception('Stopping the training, the universe has no more patience for this training.')

            else:
                return micro_f1  # in the case of test phase we just report back the test micro_f1

    return main_loop  # return the decorated function


def train_gat_ppi(config):
    """
    Very similar to Cora's training script. The main differences are:
    1. Using dataloaders since we're dealing with an inductive setting - multiple graphs per batch
    2. Doing multi-class classification (BCEWithLogitsLoss) and reporting micro-F1 instead of accuracy
    3. Model architecture and hyperparams are a bit different (as reported in the GAT paper)

    """
    global BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT  # PATIENCE_CNT must be global too, otherwise the reset below only rebinds a local

    # Checking whether you have a strong GPU. Since PPI training requires almost 8 GBs of VRAM
    # I've added the option to force CPU usage even if the GPU on your system is simply too weak.
    device = torch.device("cuda" if torch.cuda.is_available() and not config['force_cpu'] else "cpu")

    # Step 1: prepare the data loaders
    data_loader_train, data_loader_val, data_loader_test = load_graph_data(config, device)

    # Step 2: prepare the model
    gat = GAT(
        num_of_layers=config['num_of_layers'],
        num_heads_per_layer=config['num_heads_per_layer'],
        num_features_per_layer=config['num_features_per_layer'],
        add_skip_connection=config['add_skip_connection'],
        bias=config['bias'],
        dropout=config['dropout'],
        layer_type=config['layer_type'],
        log_attention_weights=False  # no need to store attentions, used only in playground.py for visualizations
    ).to(device)

    # Step 3: Prepare other training related utilities (loss & optimizer and decorator function)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    main_loop = get_main_loop(
        config,
        gat,
        loss_fn,
        optimizer,
        config['patience_period'],
        time.time())

    BEST_VAL_PERF, BEST_VAL_LOSS, PATIENCE_CNT = [0, float('inf'), 0]  # reset vars used for early stopping (val loss starts at +inf so the first epoch always improves on it)

    # Step 4: Start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN, data_loader=data_loader_train, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL, data_loader=data_loader_val, epoch=epoch)
            except Exception as e:  # "patience has run out" exception :O
                print(str(e))
                break  # break out from the training loop

    # Step 5: Potentially test your model
    # Don't overfit to the test dataset - only when you've fine-tuned your model on the validation dataset should you
    # report your final loss and micro-F1 on the test dataset. Friends don't let friends overfit to the test data. <3
    if config['should_test']:
        micro_f1 = main_loop(phase=LoopPhase.TEST, data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')
    else:
        config['test_perf'] = -1

    # Save the latest GAT in the binaries directory
    torch.save(
        utils.get_training_state(config, gat),
        os.path.join(BINARIES_PATH, utils.get_available_binary_name(config['dataset_name']))
    )


def get_training_args():
    parser = argparse.ArgumentParser()

    # Training related
    parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=200)
    parser.add_argument("--patience_period", type=int, help="number of epochs with no improvement on val before terminating", default=100)
    parser.add_argument("--lr", type=float, help="model learning rate", default=5e-3)
    parser.add_argument("--weight_decay", type=float, help="L2 regularization on model weights", default=0)
    parser.add_argument("--should_test", action='store_true', help='should test the model on the test dataset? (no by default)')
    parser.add_argument("--force_cpu", action='store_true', help='use CPU if your GPU is too small (no by default)')

    # Dataset related (note: we need the dataset name for metadata and related stuff, and not for picking the dataset)
    parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='dataset to use for training', default=DatasetType.PPI.name)
    parser.add_argument("--batch_size", type=int, help='number of graphs in a batch', default=2)
    parser.add_argument("--should_visualize", action='store_true', help='should visualize the dataset? (no by default)')

    # Logging/debugging/checkpoint related (helps a lot with experimentation)
    parser.add_argument("--enable_tensorboard", action='store_true', help="enable tensorboard logging (no by default)")
    parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq (None for no logging)", default=10)
    parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq (None for no logging)", default=5)
    args = parser.parse_args()

    # I'm leaving the hyperparam values as reported in the paper, but I experimented a bit and the comments suggest
    # how you can make GAT achieve an even higher micro-F1 or make it smaller
    gat_config = {
        # GNNs, contrary to CNNs, are often shallow (it ultimately depends on the graph properties)
        "num_of_layers": 3,  # PPI has got 42% of nodes with all 0 features - that's why 3 layers are useful
        "num_heads_per_layer": [4, 4, 6],  # other values may give even better results from the reported ones
        "num_features_per_layer": [PPI_NUM_INPUT_FEATURES, 256, 256, PPI_NUM_CLASSES],  # 64 would also give ~0.975 uF1!
        "add_skip_connection": True,  # skip connection is very important! (keep it otherwise micro-F1 is almost 0)
        "bias": True,  # bias doesn't matter that much
        "dropout": 0.0,  # dropout hurts the performance (best to keep it at 0)
        "layer_type": LayerType.IMP3  # the only implementation that supports the inductive setting
    }

    # Wrapping training configuration into a dictionary
    training_config = dict()
    for arg in vars(args):
        training_config[arg] = getattr(args, arg)
    training_config['ppi_load_test_only'] = False  # load both train/val/test data loaders (don't change it)

    # Add additional config information
    training_config.update(gat_config)

    return training_config


if __name__ == '__main__':

    # Train the graph attention network (GAT)
    train_gat_ppi(get_training_args())
Code example #10
    def train_val_loop(is_train, token_ids_loader, epoch):
        global num_of_trg_tokens_processed, global_train_step, global_val_step, writer

        if is_train:
            baseline_transformer.train()
        else:
            baseline_transformer.eval()

        device = next(baseline_transformer.parameters()).device

        #
        # Main loop - start of the CORE PART
        #
        for batch_idx, token_ids_batch in enumerate(token_ids_loader):
            src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(
                token_ids_batch)
            src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(
                src_token_ids_batch, trg_token_ids_batch_input, pad_token_id,
                device)

            # log because the KL loss expects log probabilities (just an implementation detail)
            predicted_log_distributions = baseline_transformer(
                src_token_ids_batch, trg_token_ids_batch_input, src_mask,
                trg_mask)
            smooth_target_distributions = label_smoothing(
                trg_token_ids_batch_gt)  # these are regular probabilities

            if is_train:
                custom_lr_optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph

            loss = kl_div_loss(predicted_log_distributions,
                               smooth_target_distributions)

            if is_train:
                loss.backward()  # compute the gradients for every trainable weight in the computational graph
                custom_lr_optimizer.step()  # apply the gradients to weights

            # End of CORE PART

            #
            # Logging and metrics
            #

            if is_train:
                global_train_step += 1
                num_of_trg_tokens_processed += num_trg_tokens

                if training_config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(),
                                      global_train_step)

                if training_config[
                        'console_log_freq'] is not None and batch_idx % training_config[
                            'console_log_freq'] == 0:
                    print(
                        f'Transformer training: time elapsed= {(time.time() - time_start):.2f} [s] '
                        f'| epoch={epoch + 1} | batch= {batch_idx + 1} '
                        f'| target tokens/batch= {num_of_trg_tokens_processed / training_config["console_log_freq"]}'
                    )

                    num_of_trg_tokens_processed = 0

                # Save model checkpoint
                if training_config['checkpoint_freq'] is not None and (
                        epoch + 1
                ) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f"transformer_ckpt_epoch_{epoch + 1}.pth"
                    torch.save(
                        utils.get_training_state(training_config,
                                                 baseline_transformer),
                        os.path.join(CHECKPOINTS_PATH, ckpt_model_name))
            else:
                global_val_step += 1

                if training_config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_val_step)
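
LabelSmoothingDistribution (see the next snippet) is not shown here; the usual construction puts most of the probability mass on the true token and spreads the rest uniformly. A minimal sketch of that idea, ignoring the pad-token handling the real class presumably also does (smooth_one_hot is a hypothetical stand-in, not the repo's class):

import torch

def smooth_one_hot(target_ids, vocab_size, smoothing=0.1):
    # Each row gets (1 - smoothing) on the true token and the remaining
    # `smoothing` mass spread uniformly over the other vocab entries
    confidence = 1.0 - smoothing
    dist = torch.full((target_ids.size(0), vocab_size), smoothing / (vocab_size - 1))
    dist.scatter_(1, target_ids.unsqueeze(1), confidence)
    return dist

# e.g. smooth_one_hot(torch.tensor([2]), vocab_size=4) -> [[0.033, 0.033, 0.9, 0.033]]
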
Code example #11
def train_transformer(training_config):
    device = torch.device("cuda" if torch.cuda.is_available() else
                          "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1: Prepare data loaders
    train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(
        training_config['dataset_path'], training_config['language_direction'],
        training_config['dataset_name'], training_config['batch_size'], device)

    pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN]  # pad token id is the same for the target as well
    src_vocab_size = len(src_field_processor.vocab)
    trg_vocab_size = len(trg_field_processor.vocab)

    # Step 2: Prepare the model (original transformer) and push to GPU
    baseline_transformer = Transformer(
        model_dimension=BASELINE_MODEL_DIMENSION,
        src_vocab_size=src_vocab_size,
        trg_vocab_size=trg_vocab_size,
        number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
        number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
        dropout_probability=BASELINE_MODEL_DROPOUT_PROB).to(device)

    # Step 3: Prepare other training related utilities
    kl_div_loss = nn.KLDivLoss(
        reduction='batchmean')  # gives better BLEU score than "mean"

    # Makes smooth target distributions as opposed to conventional one-hot distributions
    # My feeling is that this is a rather crude and arbitrary heuristic, but time will tell.
    label_smoothing = LabelSmoothingDistribution(
        BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size,
        device)

    # Check out playground.py for an intuitive visualization of how the LR changes with time/training steps, easy stuff.
    custom_lr_optimizer = CustomLRAdamOptimizer(
        Adam(baseline_transformer.parameters(), betas=(0.9, 0.98), eps=1e-9),
        BASELINE_MODEL_DIMENSION, training_config['num_warmup_steps'])

    # The decorator function makes things cleaner since there is a lot of redundancy between the train and val loops
    train_val_loop = get_train_val_loop(baseline_transformer,
                                        custom_lr_optimizer, kl_div_loss,
                                        label_smoothing, pad_token_id,
                                        time.time())

    # Step 4: Start the training
    for epoch in range(training_config['num_of_epochs']):
        # Training loop
        train_val_loop(is_train=True,
                       token_ids_loader=train_token_ids_loader,
                       epoch=epoch)

        # Validation loop
        with torch.no_grad():
            train_val_loop(is_train=False,
                           token_ids_loader=val_token_ids_loader,
                           epoch=epoch)

            bleu_score = utils.calculate_bleu_score(baseline_transformer,
                                                    val_token_ids_loader,
                                                    trg_field_processor)
            if training_config['enable_tensorboard']:
                writer.add_scalar('bleu_score', bleu_score, epoch)

    # Save the latest transformer in the binaries directory
    torch.save(utils.get_training_state(training_config, baseline_transformer),
               os.path.join(BINARIES_PATH, utils.get_available_binary_name()))
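
CustomLRAdamOptimizer is not shown in this snippet, but the learning-rate schedule from the original transformer paper, which it presumably implements, is simple to state. A sketch, not the repo's exact code:

def transformer_lr(step, model_dimension, num_warmup_steps):
    # Linear warmup followed by inverse-square-root decay ("Attention Is All You Need")
    step = max(step, 1)  # guard against 0 ** -0.5 on the very first step
    return model_dimension ** (-0.5) * min(step ** (-0.5), step * num_warmup_steps ** (-1.5))
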
Code example #12
def train_gat_ppi(config):

    # Global state: the best validation micro-F1 and the best validation loss
    global BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT

    device = torch.device("cuda" if torch.cuda.is_available()
                          and not config['force_cpu'] else "cpu")

    # Step 1: load the data
    data_loader_train, data_loader_val, data_loader_test = load_graph_data(
        config, device)

    # Step 2: prepare the model
    gat = GAT_ppi(num_of_layers=config['num_of_layers'],
                  num_heads_per_layer=config['num_heads_per_layer'],
                  num_features_per_layer=config['num_features_per_layer'],
                  add_skip_connection=config['add_skip_connection'],
                  bias=config['bias'],
                  dropout=config['dropout'],
                  log_attention_weights=False).to(device)

    # Step 3: prepare the training utilities (loss & optimizer)
    loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = Adam(gat.parameters(),
                     lr=config['lr'],
                     weight_decay=config['weight_decay'])

    # Return the main loop function - this improves code reuse
    main_loop = get_main_loop(config=config,
                              gat=gat,
                              sigmoid_cross_entropy_loss=loss_fn,
                              optimizer=optimizer,
                              patience_period=config['patience_period'],
                              time_start=time.time())

    BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT = [0, float('inf'), 0]  # reset vars used for early stopping (val loss starts at +inf so the first epoch always improves on it)

    # Step 4: start the training procedure
    for epoch in range(config['num_of_epochs']):
        # Training loop
        main_loop(phase=LoopPhase.TRAIN,
                  data_loader=data_loader_train,
                  epoch=epoch)

        # Validation loop
        with torch.no_grad():
            try:
                main_loop(phase=LoopPhase.VAL,
                          data_loader=data_loader_val,
                          epoch=epoch)
            except Exception as e:
                print(str(e))
                break

    # Step 5: optionally test the model
    if config['should_test']:
        micro_f1 = main_loop(phase=LoopPhase.TEST,
                             data_loader=data_loader_test)
        config['test_perf'] = micro_f1

        print('*' * 50)
        print(f'Test micro-F1 = {micro_f1}')

    else:
        config['test_perf'] = -1

    # Save the latest GAT model to the binaries directory
    torch.save(
        utils.get_training_state(config, gat),
        os.path.join(BINARIES_PATH,
                     utils.get_available_binary_name(config['dataset_name'])))
Code example #13
    def main_loop(phase, data_loader, epoch=0):
        global BEST_VAL_MICRO_F1, BEST_VAL_LOSS, PATIENCE_CNT, writer

        if phase == LoopPhase.TRAIN:
            gat.train()
        else:
            gat.eval()

        for batch_idx, (node_features, gt_node_labels,
                        edge_index) in enumerate(data_loader):
            """迭代一批图形数据,原论文是2张图,这里将2张图合为一张图,相当于一张图2个连通分量"""

            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            gt_node_labels = gt_node_labels.to(device)

            graph_data = (node_features, edge_index)  # pack the data into a tuple

            nodes_unnormalized_scores = gat(graph_data)[0]  # raw output scores, before the sigmoid - each label is a binary (0/1) problem, hence the sigmoid

            loss = sigmoid_cross_entropy_loss(nodes_unnormalized_scores,
                                              gt_node_labels)

            if phase == LoopPhase.TRAIN:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Compute micro-F1
            pred = (nodes_unnormalized_scores > 0).float().cpu().numpy()  # a score above 0 gives sigmoid > 0.5, so we predict 1
            gt = gt_node_labels.cpu().numpy()
            micro_f1 = f1_score(gt, pred, average='micro')

            # Logging

            global_step = len(data_loader) * epoch + batch_idx
            if phase == LoopPhase.TRAIN:
                # Log metrics
                if config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(),
                                      global_step)
                    writer.add_scalar('training_micro_f1', micro_f1,
                                      global_step)

                # Log to the console every console_log_freq batches
                if config[
                        'console_log_freq'] is not None and batch_idx % config[
                            'console_log_freq'] == 0:
                    print(
                        f'GAT training: time elapsed= {(time.time() - time_start):.2f} [s] |'
                        f' epoch={epoch + 1} | batch={batch_idx + 1} | train micro-F1={micro_f1}.'
                    )

                # Save a checkpoint
                if config['checkpoint_freq'] is not None and (
                        epoch +
                        1) % config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f'gat_{config["dataset_name"]}_ckpt_epoch_{epoch + 1}.pth'
                    config['test_perf'] = -1  # test performance not measured yet
                    torch.save(utils.get_training_state(config, gat),
                               os.path.join(CHECKPOINTS_PATH, ckpt_model_name))

            elif phase == LoopPhase.VAL:
                if config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_step)
                    writer.add_scalar('val_micro_f1', micro_f1, global_step)

                if config[
                        'console_log_freq'] is not None and batch_idx % config[
                            'console_log_freq'] == 0:
                    print(
                        f'GAT validation: time elapsed= {(time.time() - time_start):.2f} [s] |'
                        f' epoch={epoch + 1} | batch={batch_idx + 1} | val micro-F1={micro_f1}'
                    )

                # Early-stopping logic - track the best validation performance
                if micro_f1 > BEST_VAL_MICRO_F1 or loss.item() < BEST_VAL_LOSS:
                    BEST_VAL_MICRO_F1 = max(micro_f1, BEST_VAL_MICRO_F1)
                    BEST_VAL_LOSS = min(loss.item(), BEST_VAL_LOSS)
                    PATIENCE_CNT = 0
                else:
                    PATIENCE_CNT += 1

                if PATIENCE_CNT >= patience_period:
                    raise Exception(
                        'Stopping the training, the universe has no more patience for this training.'
                    )
            else:
                return micro_f1  # in the test phase we just report back the micro-F1
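
The "score above 0 means sigmoid above 0.5" shortcut used for the predictions above holds by sigmoid's definition and is trivial to confirm:

import torch

scores = torch.tensor([-2.0, -0.1, 0.0, 0.1, 2.0])
assert torch.equal(scores > 0, torch.sigmoid(scores) > 0.5)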