Example #1
def train(inv_model, src_data, tgt_data, loss_func, inv_opt, args):
    inv_model.train()
    src_data.shuffle()

    new_size = len(tgt_data) / args.batch_size * args.src_batch_size
    new_size = int(new_size)

    src_pos_data = [d for d in src_data if d.targets[0] == 1]
    src_neg_data = [d for d in src_data if d.targets[0] == 0]
    print(len(tgt_data))
    print(len(src_pos_data), len(src_neg_data), new_size)
    src_data = MoleculeDataset(src_pos_data + src_neg_data[:new_size])

    src_data.shuffle()
    tgt_data.shuffle()

    src_iter = range(0, len(src_data), args.src_batch_size)
    tgt_iter = range(0, len(tgt_data), args.batch_size)

    for i, j in zip(src_iter, tgt_iter):
        inv_model.zero_grad()
        src_batch = src_data[i:i + args.src_batch_size]
        src_batch = MoleculeDataset(src_batch)
        src_loss = forward(inv_model, src_batch, loss_func, is_source=True)

        tgt_batch = tgt_data[j:j + args.batch_size]
        tgt_batch = MoleculeDataset(tgt_batch)
        tgt_loss = forward(inv_model, tgt_batch, loss_func, is_source=False)

        loss = (src_loss + tgt_loss) / 2
        loss.backward()
        inv_opt[0].step()
        inv_opt[1].step()

        lr = inv_opt[1].get_lr()[0]
        ignorm = compute_gnorm(inv_model)
        print(f'lr: {lr:.5f}, loss: {loss:.4f}, gnorm: {ignorm:.4f}')
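
The pattern above draws one source batch and one target batch per step with zip and averages the two losses before a single backward pass. A minimal, self-contained sketch of that joint update, with a toy linear model and random tensors standing in for the project's MoleculeDataset objects:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the source/target data and the model
model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

src_x, src_y = torch.randn(64, 8), torch.randint(0, 2, (64, 1)).float()
tgt_x, tgt_y = torch.randn(32, 8), torch.randint(0, 2, (32, 1)).float()
src_bs, tgt_bs = 16, 8

# Iterate the two domains in lockstep, as in the zip(src_iter, tgt_iter) loop above
for i, j in zip(range(0, len(src_x), src_bs), range(0, len(tgt_x), tgt_bs)):
    model.zero_grad()
    src_loss = loss_fn(model(src_x[i:i + src_bs]), src_y[i:i + src_bs])
    tgt_loss = loss_fn(model(tgt_x[j:j + tgt_bs]), tgt_y[j:j + tgt_bs])
    loss = (src_loss + tgt_loss) / 2  # equal weight to source and target
    loss.backward()
    opt.step()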
Example #2
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum, iter_count = 0, 0

    for batch in tqdm(data_loader, total=len(data_loader)):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch = \
            batch.batch_graph(), batch.features(), batch.targets()
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch)

        # Move tensors to correct device
        mask = mask.to(preds.device)
        targets = targets.to(preds.device)
        class_weights = torch.ones(targets.shape, device=preds.device)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
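
The loss.sum() / mask.sum() step above averages the loss only over targets that are actually present. A self-contained sketch of the same masking on dummy tensors (no chemprop objects; BCEWithLogitsLoss stands in for whatever loss_func is configured):

import torch
import torch.nn as nn

loss_func = nn.BCEWithLogitsLoss(reduction='none')  # keep per-element losses

target_batch = [[1.0, None], [0.0, 1.0], [None, 0.0]]   # missing labels stored as None
mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

preds = torch.randn(3, 2)                 # hypothetical model logits
loss = loss_func(preds, targets) * mask   # zero out contributions from missing targets
loss = loss.sum() / mask.sum()            # average over observed targets only
print(loss.item())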
Example #3
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: Number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: Total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break

        mol_batch = MoleculeDataset(data[i:i + args.batch_size])

        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

        mask = torch.Tensor([[not np.isnan(x) for x in tb]
                             for tb in target_batch])

        targets = torch.Tensor([[0 if np.isnan(x) else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(smiles_batch, features_batch)

        # todo: change the loss function for property prediction tasks

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :],
                                        targets[:, target_index]).unsqueeze(1)
                              for target_index in range(preds.size(1))],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask

        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(
                f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(
                f'\nLoss = {loss_avg:.4e}, PNorm = {pnorm:.4f},'
                f' GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)

                for idx, learn_rate in enumerate(lrs):
                    writer.add_scalar(
                        f'learning_rate_{idx}', learn_rate, n_iter)

    return n_iter
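
Note that the scheduler is stepped once per batch, inside the loop, rather than once per epoch. A generic sketch of per-batch scheduler stepping, using PyTorch's built-in OneCycleLR in place of NoamLR (the toy model and data are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epochs, steps_per_epoch = 3, 10
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, epochs=epochs, steps_per_epoch=steps_per_epoch)

for epoch in range(epochs):
    for step in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        scheduler.step()          # stepped every batch, as with NoamLR above
    print(epoch, scheduler.get_last_lr())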
Example #4
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          chunk_names: bool = False,
          val_smiles: List[str] = None,
          test_smiles: List[str] = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param chunk_names: Whether to train on the data in chunks. In this case,
    data must be a list of paths to the data chunks.
    :param val_smiles: Validation smiles strings without targets.
    :param test_smiles: Test smiles strings without targets, used for adversarial setting.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    if args.dataset_type == 'bert_pretraining':
        features_loss = nn.MSELoss()

    if chunk_names:
        for path, memo_path in tqdm(data, total=len(data)):
            featurization.SMILES_TO_FEATURES = dict()
            if os.path.isfile(memo_path):
                found_memo = True
                with open(memo_path, 'rb') as f:
                    featurization.SMILES_TO_FEATURES = pickle.load(f)
            else:
                found_memo = False
            with open(path, 'rb') as f:
                chunk = pickle.load(f)
            if args.moe:
                for source in chunk:
                    source.shuffle()
            else:
                chunk.shuffle()
            n_iter = train(model=model,
                           data=chunk,
                           loss_func=loss_func,
                           optimizer=optimizer,
                           scheduler=scheduler,
                           args=args,
                           n_iter=n_iter,
                           logger=logger,
                           writer=writer,
                           chunk_names=False,
                           val_smiles=val_smiles,
                           test_smiles=test_smiles)
            if not found_memo:
                with open(memo_path, 'wb') as f:
                    pickle.dump(featurization.SMILES_TO_GRAPH,
                                f,
                                protocol=pickle.HIGHEST_PROTOCOL)
        return n_iter

    if not args.moe:
        data.shuffle()

    loss_sum, iter_count = 0, 0
    if args.adversarial:
        if args.moe:
            train_smiles = []
            for d in data:
                train_smiles += d.smiles()
        else:
            train_smiles = data.smiles()
        train_val_smiles = train_smiles + val_smiles
        d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0

    if args.moe:
        test_smiles = list(test_smiles)
        random.shuffle(test_smiles)
        train_smiles = []
        for d in data:
            d.shuffle()
            train_smiles.append(d.smiles())
        num_iters = min(len(test_smiles), min([len(d) for d in data]))
    elif args.maml:
        num_iters = args.maml_batches_per_epoch * args.maml_batch_size
        model.zero_grad()
        maml_sum_loss = 0
    else:
        num_iters = len(data) if args.last_batch \
            else len(data) // args.batch_size * args.batch_size

    if args.parallel_featurization:
        batch_queue = Queue(args.batch_queue_max_size)
        exit_queue = Queue(1)
        batch_process = Process(target=async_mol2graph,
                                args=(batch_queue, data, args, num_iters,
                                      args.batch_size, exit_queue,
                                      args.last_batch))
        batch_process.start()
        currently_loaded_batches = []

    iter_size = 1 if args.maml else args.batch_size

    for i in trange(0, num_iters, iter_size):
        if args.moe:
            if not args.batch_domain_encs:
                model.compute_domain_encs(train_smiles)  # want to recompute every batch
            mol_batch = [
                MoleculeDataset(d[i:i + args.batch_size]) for d in data
            ]
            train_batch, train_targets = [], []
            for b in mol_batch:
                tb, tt = b.smiles(), b.targets()
                train_batch.append(tb)
                train_targets.append(tt)
            test_batch = test_smiles[i:i + args.batch_size]
            loss = model.compute_loss(train_batch, train_targets, test_batch)
            model.zero_grad()

            loss_sum += loss.item()
            iter_count += len(mol_batch)
        elif args.maml:
            task_train_data, task_test_data, task_idx = data.sample_maml_task(args)
            mol_batch = task_test_data
            smiles_batch, features_batch, target_batch = \
                task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor(target_batch).unsqueeze(1)
            if next(model.parameters()).is_cuda:
                targets = targets.cuda()
            preds = model(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            grad = torch.autograd.grad(
                loss, [p for p in model.parameters() if p.requires_grad])
            theta = [
                p for p in model.named_parameters() if p[1].requires_grad
            ]  # comes in same order as grad
            theta_prime = {
                p[0]: p[1] - args.maml_lr * grad[i]
                for i, p in enumerate(theta)
            }
            for name, nongrad_param in [
                    p for p in model.named_parameters()
                    if not p[1].requires_grad
            ]:
                theta_prime[name] = nongrad_param + torch.zeros(
                    nongrad_param.size()).to(nongrad_param)
        else:
            # Prepare batch
            if args.parallel_featurization:
                if len(currently_loaded_batches) == 0:
                    currently_loaded_batches = batch_queue.get()
                mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
            else:
                if not args.last_batch and i + args.batch_size > len(data):
                    break
                mol_batch = MoleculeDataset(data[i:i + args.batch_size])
            smiles_batch, features_batch, target_batch = \
                mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

            if args.dataset_type == 'bert_pretraining':
                batch = mol2graph(smiles_batch, args)
                mask = mol_batch.mask()
                batch.bert_mask(mask)
                mask = 1 - torch.FloatTensor(mask)  # num_atoms
                # num_molecules x features_size
                features_targets = torch.FloatTensor(target_batch['features']) \
                    if target_batch['features'] is not None else None
                targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
                if args.bert_vocab_func == 'feature_vector':
                    mask = mask.reshape(-1, 1)
                else:
                    targets = targets.long()
            else:
                batch = smiles_batch
                mask = torch.Tensor([[x is not None for x in tb]
                                     for tb in target_batch])
                targets = torch.Tensor([[0 if x is None else x for x in tb]
                                        for tb in target_batch])

            if next(model.parameters()).is_cuda:
                mask, targets = mask.cuda(), targets.cuda()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    features_targets = features_targets.cuda()

            if args.class_balance:
                class_weights = []
                for task_num in range(data.num_tasks()):
                    class_weights.append(
                        args.class_weights[task_num][targets[:, task_num].long()])
                class_weights = torch.stack(class_weights).t()  # num_molecules x num_tasks
            else:
                class_weights = torch.ones(targets.shape)

            if args.cuda:
                class_weights = class_weights.cuda()

            # Run model
            model.zero_grad()
            if args.parallel_featurization:
                previous_graph_input_mode = model.encoder.graph_input
                model.encoder.graph_input = True  # force model to accept already processed input
                preds = model(featurized_mol_batch, features_batch)
                model.encoder.graph_input = previous_graph_input_mode
            else:
                preds = model(batch, features_batch)
            if args.dataset_type == 'regression_with_binning':
                preds = preds.view(targets.size(0), targets.size(1), -1)
                targets = targets.long()
                loss = 0
                for task in range(targets.size(1)):
                    # for some reason cross entropy doesn't support multi target
                    loss += loss_func(preds[:, task, :], targets[:, task]) \
                        * class_weights[:, task] * mask[:, task]
                loss = loss.sum() / mask.sum()
            else:
                if args.dataset_type == 'unsupervised':
                    targets = targets.long().reshape(-1)

                if args.dataset_type == 'bert_pretraining':
                    features_preds, preds = preds['features'], preds['vocab']

                if args.dataset_type == 'kernel':
                    preds = preds.view(int(preds.size(0) / 2), 2,
                                       preds.size(1))
                    preds = model.kernel_output_layer(preds)

                loss = loss_func(preds, targets) * class_weights * mask
                if args.predict_features_and_task:
                    loss = (loss.sum() + loss[:, :-args.features_size].sum() * (args.task_weight-1)) \
                                / (mask.sum() + mask[:, :-args.features_size].sum() * (args.task_weight-1))
                else:
                    loss = loss.sum() / mask.sum()

                if args.dataset_type == 'bert_pretraining' and features_targets is not None:
                    loss += features_loss(features_preds, features_targets)

            loss_sum += loss.item()
            iter_count += len(mol_batch)

        if args.maml:
            model_prime = build_model(args=args, params=theta_prime)
            smiles_batch, features_batch = task_test_data.smiles(), task_test_data.features()
            target_batch = [t[task_idx] for t in task_test_data.targets()]
            # no mask since we only picked data points that have the desired target
            targets = torch.Tensor([[t] for t in target_batch])
            if next(model_prime.parameters()).is_cuda:
                targets = targets.cuda()
            model_prime.zero_grad()
            preds = model_prime(smiles_batch, features_batch)
            loss = loss_func(preds, targets)
            loss = loss.sum() / len(smiles_batch)
            loss_sum += loss.item()
            # TODO check that this makes sense, but it's just for display
            iter_count += len(smiles_batch)
            maml_sum_loss += loss
            if i % args.maml_batch_size == args.maml_batch_size - 1:
                maml_sum_loss.backward()
                optimizer.step()
                model.zero_grad()
                maml_sum_loss = 0
        else:
            loss.backward()
            if args.max_grad_norm is not None:
                clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()

        if args.adjust_weight_decay:
            current_pnorm = compute_pnorm(model)
            if current_pnorm < args.pnorm_target:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i]['weight_decay'] = max(
                        0, optimizer.param_groups[i]['weight_decay'] -
                        args.adjust_weight_decay_step)
            else:
                for i in range(len(optimizer.param_groups)):
                    optimizer.param_groups[i][
                        'weight_decay'] += args.adjust_weight_decay_step

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        if args.adversarial:
            for _ in range(args.gan_d_per_g):
                train_val_smiles_batch = random.sample(train_val_smiles,
                                                       args.batch_size)
                test_smiles_batch = random.sample(test_smiles, args.batch_size)
                d_loss, gp_norm = model.train_D(train_val_smiles_batch,
                                                test_smiles_batch)
            train_val_smiles_batch = random.sample(train_val_smiles,
                                                   args.batch_size)
            test_smiles_batch = random.sample(test_smiles, args.batch_size)
            g_loss = model.train_G(train_val_smiles_batch, test_smiles_batch)

            # we probably only care about the g_loss honestly
            d_loss_sum += d_loss * args.batch_size
            gp_norm_sum += gp_norm * args.batch_size
            g_loss_sum += g_loss * args.batch_size

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            if args.adversarial:
                d_loss_avg, g_loss_avg, gp_norm_avg = d_loss_sum / iter_count, g_loss_sum / iter_count, gp_norm_sum / iter_count
                d_loss_sum, g_loss_sum, gp_norm_sum = 0, 0, 0
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join('lr_{} = {:.4e}'.format(i, lr)
                                for i, lr in enumerate(lrs))
            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, {}".format(
                loss_avg, pnorm, gnorm, lrs_str))
            if args.adversarial:
                debug(
                    "D Loss = {:.4e}, G Loss = {:.4e}, GP Norm = {:.4}".format(
                        d_loss_avg, g_loss_avg, gp_norm_avg))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar('learning_rate_{}'.format(i), lr, n_iter)

    if args.parallel_featurization:
        exit_queue.put(0)  # dummy var to get the subprocess to know that we're done
        batch_process.join()

    return n_iter
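
The maml branch above performs a manual inner-loop update: it computes gradients with torch.autograd.grad, builds adapted parameters theta_prime = theta - lr * grad, and then evaluates the query (task test) data with those adapted weights. A minimal sketch of that first-order inner step on a toy linear model (all names and data here are illustrative, not the repository's):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(4, 1)
x_support, y_support = torch.randn(8, 4), torch.randn(8, 1)
x_query, y_query = torch.randn(8, 4), torch.randn(8, 1)
inner_lr = 0.1

# Inner step: gradients of the support loss, without calling backward()
support_loss = ((model(x_support) - y_support) ** 2).mean()
grads = torch.autograd.grad(support_loss,
                            [p for p in model.parameters() if p.requires_grad])

# Adapted parameters theta' = theta - lr * grad, keyed by parameter name
theta_prime = {name: p - inner_lr * g
               for (name, p), g in zip(model.named_parameters(), grads)}

# Evaluate the query set with the adapted weights (functional forward pass)
query_pred = F.linear(x_query, theta_prime['weight'], theta_prime['bias'])
query_loss = ((query_pred - y_query) ** 2).mean()
query_loss.backward()   # first-order MAML: gradients flow back to the original parameters
print(query_loss.item())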
Example #5
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    iter_size = args.batch_size

    if args.class_balance:
        # Reconstruct data so that each batch has equal number of positives and negatives
        # (will leave out a different random sample of negatives each epoch)
        assert len(data[0].targets) == 1  # only works for single class classification
        pos = [d for d in data if d.targets[0] == 1]
        neg = [d for d in data if d.targets[0] == 0]

        new_data = []
        pos_size = iter_size // 2
        pos_index = neg_index = 0
        while True:
            new_pos = pos[pos_index:pos_index + pos_size]
            new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]

            if len(new_pos) == 0 or len(new_neg) == 0:
                break

            if len(new_pos) + len(new_neg) < iter_size:
                new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]

            new_data += new_pos + new_neg

            pos_index += len(new_pos)
            neg_index += len(new_neg)

        data = new_data

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch, weight_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets(), mol_batch.weights()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])
        weights = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in weight_batch])
        #        print (weight_batch)
        #        print (weights)

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        if args.enable_weight:
            class_weights = weights
        else:
            class_weights = torch.ones(targets.shape)
#        print(class_weights)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
#        print ("loss")
#        print (loss)
#        print (class_weights)

        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
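
The class_balance branch above rebuilds the epoch's data so that every batch holds roughly half positives and half negatives, stopping once either class runs out. The same logic, sketched as a standalone function over plain lists:

def class_balanced_order(pos, neg, iter_size):
    """Interleave positives and negatives so each batch of iter_size is ~half/half."""
    new_data = []
    pos_size = iter_size // 2
    pos_index = neg_index = 0
    while True:
        new_pos = pos[pos_index:pos_index + pos_size]
        new_neg = neg[neg_index:neg_index + iter_size - len(new_pos)]
        if len(new_pos) == 0 or len(new_neg) == 0:
            break
        if len(new_pos) + len(new_neg) < iter_size:
            # pad with extra positives when negatives are running out
            new_pos = pos[pos_index:pos_index + iter_size - len(new_neg)]
        new_data += new_pos + new_neg
        pos_index += len(new_pos)
        neg_index += len(new_neg)
    return new_data

# Toy usage: 5 positives, 12 negatives, batches of 4
print(class_balanced_order(['p'] * 5, ['n'] * 12, iter_size=4))

Example #6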
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data = deepcopy(data)

    data.shuffle()

    if args.uncertainty == 'bootstrap':
        data.sample(int(4 * len(data) / args.ensemble_size))

    loss_sum, iter_count = 0, 0

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if model.uncertainty:
            pred_targets = preds[:, [j for j in range(len(preds[0])) if j % 2 == 0]]
            pred_var = preds[:, [j for j in range(len(preds[0])) if j % 2 == 1]]
            loss = loss_func(pred_targets, pred_var, targets)
            # sigma = ((pred_targets - targets) ** 2).detach()
            # loss = loss_func(pred_targets, targets) * class_weights * mask
            # loss += nn.MSELoss(reduction='none')(pred_sigma, sigma) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
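
In this uncertainty-aware variant the model's output interleaves a predicted mean and a predicted variance per task, split by even/odd column before a heteroscedastic loss. A small sketch with dummy tensors, using PyTorch's built-in GaussianNLLLoss in place of whatever loss_func the surrounding code supplies:

import torch
import torch.nn as nn
import torch.nn.functional as F

num_tasks = 3
preds = torch.randn(16, 2 * num_tasks)     # hypothetical model output
targets = torch.randn(16, num_tasks)

pred_means = preds[:, 0::2]                # even columns: predicted means
pred_vars = F.softplus(preds[:, 1::2])     # odd columns: variances (kept positive)

loss_func = nn.GaussianNLLLoss()
loss = loss_func(pred_means, targets, pred_vars)
print(loss.item())

Example #7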
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, iter_count = 0, 0

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb]
                             for tb in target_batch])
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        class_weights = torch.ones(targets.shape)
        #print('class_weight',class_weights.size(),class_weights)
        #print('mask',mask.size(),mask)

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :],
                          targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ],
                             dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        ############ add L1 regularization ############

        ffn_d0_L1_reg_loss = 0
        ffn_d1_L1_reg_loss = 0
        ffn_d2_L1_reg_loss = 0
        ffn_final_L1_reg_loss = 0
        ffn_mol_L1_reg_loss = 0

        lamda_ffn_d0 = 0
        lamda_ffn_d1 = 0
        lamda_ffn_d2 = 0
        lamda_ffn_final = 0
        lamda_ffn_mol = 0

        for param in model.ffn_d0.parameters():
            ffn_d0_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_final.parameters():
            ffn_final_L1_reg_loss += torch.sum(torch.abs(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L1_reg_loss += torch.sum(torch.abs(param))

        loss += (lamda_ffn_d0 * ffn_d0_L1_reg_loss
                 + lamda_ffn_d1 * ffn_d1_L1_reg_loss
                 + lamda_ffn_d2 * ffn_d2_L1_reg_loss
                 + lamda_ffn_final * ffn_final_L1_reg_loss
                 + lamda_ffn_mol * ffn_mol_L1_reg_loss)

        ############ add L1 regularization ############

        ############ add L2 regularization ############
        '''
        ffn_d0_L2_reg_loss = 0
        ffn_d1_L2_reg_loss = 0
        ffn_d2_L2_reg_loss = 0
        ffn_final_L2_reg_loss = 0
        ffn_mol_L2_reg_loss = 0

        lamda_ffn_d0 = 1e-6
        lamda_ffn_d1 = 1e-6
        lamda_ffn_d2 = 1e-5
        lamda_ffn_final = 1e-4
        lamda_ffn_mol = 1e-3 

        for param in model.ffn_d0.parameters():
            ffn_d0_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d1.parameters():
            ffn_d1_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_d2.parameters():
            ffn_d2_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_final.parameters():
            ffn_final_L2_reg_loss += torch.sum(torch.square(param))
        for param in model.ffn_mol.parameters():
            ffn_mol_L2_reg_loss += torch.sum(torch.square(param))

        loss += lamda_ffn_d0 * ffn_d0_L2_reg_loss + lamda_ffn_d1 * ffn_d1_L2_reg_loss + lamda_ffn_d2 * ffn_d2_L2_reg_loss + lamda_ffn_final * ffn_final_L2_reg_loss + lamda_ffn_mol * ffn_mol_L2_reg_loss
        '''
        ############ add L2 regularization ############

        loss_sum += loss.item()
        iter_count += len(mol_batch)

        #loss.backward(retain_graph=True)  # wei, retain_graph=True
        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)
    #print(model)
    return n_iter
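
The block above adds an L1 penalty over several feed-forward sub-networks (all lamda_* coefficients are set to 0, so the penalty is a no-op as written; the commented-out block shows the L2 equivalent). A compact, generic way to add a per-module L1 term, sketched with hypothetical module names and coefficients:

import torch
import torch.nn as nn

model = nn.ModuleDict({
    'ffn_d0': nn.Linear(16, 8),    # hypothetical sub-networks
    'ffn_final': nn.Linear(8, 1),
})
l1_lambdas = {'ffn_d0': 1e-6, 'ffn_final': 1e-4}   # illustrative coefficients

preds = model['ffn_final'](model['ffn_d0'](torch.randn(4, 16)))
loss = preds.pow(2).mean()                          # stand-in for the task loss

for name, lam in l1_lambdas.items():
    loss = loss + lam * sum(p.abs().sum() for p in model[name].parameters())

loss.backward()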
Example #8
def train(model: nn.Module,
          data: Union[MoleculeDataset, List[MoleculeDataset]],
          loss_func: Callable,
          metric_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
    :param loss_func: Loss function.
    :param metric_func: Metric function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    data.shuffle()

    loss_sum, metric_sum, iter_count = [0]*(len(args.atom_targets) + len(args.bond_targets)), \
                                       [0]*(len(args.atom_targets) + len(args.bond_targets)), 0

    loss_weights = args.loss_weights

    # don't use the last batch if it's small, for stability
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        #mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])

        # FIXME assign 0 to None in target
        # targets = [[0 if x is None else x for x in tb] for tb in target_batch]

        targets = [torch.Tensor(np.concatenate(x)) for x in zip(*target_batch)]
        if next(model.parameters()).is_cuda:
            #   mask, targets = mask.cuda(), targets.cuda()
            targets = [x.cuda() for x in targets]
        # FIXME
        #class_weights = torch.ones(targets.shape)

        #if args.cuda:
        #   class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        targets = [x.reshape([-1, 1]) for x in targets]

        #FIXME mutlticlass
        '''
        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1) for target_index in range(preds.size(1))], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        '''
        loss_multi_task = []
        metric_multi_task = []
        for target, pred, lw in zip(targets, preds, loss_weights):
            loss = loss_func(pred, target)
            loss = loss.sum() / target.shape[0]
            loss_multi_task.append(loss * lw)
            if args.cuda:
                metric = metric_func(pred.data.cpu().numpy(),
                                     target.data.cpu().numpy())
            else:
                metric = metric_func(pred.data.numpy(), target.data.numpy())
            metric_multi_task.append(metric)

        loss_sum = [x + y for x, y in zip(loss_sum, loss_multi_task)]
        iter_count += 1

        sum(loss_multi_task).backward()
        optimizer.step()

        metric_sum = [x + y for x, y in zip(metric_sum, metric_multi_task)]

        if isinstance(scheduler, NoamLR) or isinstance(scheduler, SinexpLR):
            scheduler.step()

        n_iter += len(mol_batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = [x / iter_count for x in loss_sum]
            metric_avg = [x / iter_count for x in metric_sum]
            loss_sum, iter_count, metric_sum = [0]*(len(args.atom_targets) + len(args.bond_targets)), \
                                               0, \
                                               [0]*(len(args.atom_targets) + len(args.bond_targets))

            loss_str = ', '.join(f'lss_{i} = {lss:.4e}'
                                 for i, lss in enumerate(loss_avg))
            metric_str = ', '.join(f'mc_{i} = {mc:.4e}'
                                   for i, mc in enumerate(metric_avg))
            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}'
                                for i, lr in enumerate(lrs))
            debug(
                f'{loss_str}, {metric_str}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}'
            )

            if writer is not None:
                for i, lss in enumerate(loss_avg):
                    writer.add_scalar(f'train_loss_{i}', lss, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
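
This variant keeps one loss per atom/bond target, scales each by its entry in loss_weights, and backpropagates their sum. A stripped-down sketch of that weighted multi-task sum with plain tensors (the MSE loss and weights are placeholders):

import torch
import torch.nn as nn

loss_func = nn.MSELoss(reduction='sum')
loss_weights = [1.0, 0.5]                      # illustrative per-target weights

preds = [torch.randn(10, 1, requires_grad=True) for _ in loss_weights]
targets = [torch.randn(10, 1) for _ in loss_weights]

loss_multi_task = []
for target, pred, lw in zip(targets, preds, loss_weights):
    loss = loss_func(pred, target) / target.shape[0]   # mean over the batch
    loss_multi_task.append(loss * lw)

sum(loss_multi_task).backward()   # one backward pass over the weighted sum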
Example #9
def train(model: nn.Module,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None,
          gp_switch: bool = False,
          likelihood = None,
          bbp_switch = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data_loader: A MoleculeDataLoader.
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :param gp_switch: If True, the loss is the negative of loss_func (GP-style training).
    :param likelihood: An optional likelihood module, also put into train mode alongside the model.
    :param bbp_switch: Selects the Bayesian loss variant (1/2 for Bayes-by-backprop without/with sampling, 3/4 for DUN); None uses the standard loss.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    if likelihood is not None:
        likelihood.train()
    
    loss_sum = 0
    if bbp_switch is not None:
        data_loss_sum = 0
        kl_loss_sum = 0
        kl_loss_depth_sum = 0

    #for batch in tqdm(data_loader, total=len(data_loader)):
    for batch in data_loader:
        # Prepare batch
        batch: MoleculeDataset

        # .batch_graph() returns BatchMolGraph
        # .features() returns None if no additional features
        # .targets() returns list of lists of floats containing the targets
        mol_batch, features_batch, target_batch = batch.batch_graph(), batch.features(), batch.targets()
        
        # mask is 1 where targets are not None
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
        # where targets are None, replace with 0
        targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])

        # Move tensors to correct device
        mask = mask.to(args.device)
        targets = targets.to(args.device)
        class_weights = torch.ones(targets.shape, device=args.device)
        
        # zero gradients
        model.zero_grad()
        optimizer.zero_grad()
        
        
        ##### FORWARD PASS AND LOSS COMPUTATION #####
        
        
        if bbp_switch is None:
        
            # forward pass
            preds = model(mol_batch, features_batch)
    
            # compute loss
            if gp_switch:
                loss = -loss_func(preds, targets)
            else:
                loss = loss_func(preds, targets, torch.exp(model.log_noise))
        
        
        ### bbp non sample option
        if bbp_switch == 1:    
            preds, kl_loss = model(mol_batch, features_batch, sample = False)
            data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
            kl_loss /= args.train_data_size
            loss = data_loss + kl_loss  
            
        ### bbp sample option
        if bbp_switch == 2:

            if args.samples_bbp == 1:
                preds, kl_loss = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds, targets, torch.exp(model.log_noise))
                kl_loss /= args.train_data_size
        
            elif args.samples_bbp > 1:
                data_loss_cum = 0
                kl_loss_cum = 0
        
                for i in range(args.samples_bbp):
                    preds, kl_loss_i = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds, targets, torch.exp(model.log_noise))                    
                    kl_loss_i /= args.train_data_size                    
                    
                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
        
                data_loss = data_loss_cum / args.samples_bbp
                kl_loss = kl_loss_cum / args.samples_bbp
            
            loss = data_loss + kl_loss

        ### DUN non sample option
        if bbp_switch == 3:
            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))    
            _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=False)
            data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
            kl_loss /= args.train_data_size
            kl_loss_depth /= args.train_data_size
            loss = data_loss + kl_loss + kl_loss_depth
            #print('-----')
            #print(data_loss)
            #print(kl_loss)
            #print(cat)

        ### DUN sample option
        if bbp_switch == 4:

            cat = torch.exp(model.log_cat) / torch.sum(torch.exp(model.log_cat))

            if args.samples_dun == 1:
                _, preds_list, kl_loss, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                data_loss = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                kl_loss /= args.train_data_size
                kl_loss_depth /= args.train_data_size
        
            elif args.samples_dun > 1:
                data_loss_cum = 0
                kl_loss_cum = 0

                for i in range(args.samples_dun):
                    _, preds_list, kl_loss_i, kl_loss_depth = model(mol_batch, features_batch, sample=True)
                    data_loss_i = loss_func(preds_list, targets, torch.exp(model.log_noise), cat)
                    kl_loss_i /= args.train_data_size                    
                    kl_loss_depth /= args.train_data_size

                    data_loss_cum += data_loss_i
                    kl_loss_cum += kl_loss_i
        
                data_loss = data_loss_cum / args.samples_dun
                kl_loss = kl_loss_cum / args.samples_dun
            
            loss = data_loss + kl_loss + kl_loss_depth

            #print('-----')
            #print(data_loss)
            #print(kl_loss)
            #print(cat)
            
        #############################################
        
        
        # backward pass; update weights
        loss.backward()
        optimizer.step()
        
        
        #for name, parameter in model.named_parameters():
            #print(name)#, parameter.grad)
            #print(np.sum(np.array(parameter.grad)))

        # add to loss_sum and iter_count
        loss_sum += loss.item() * len(batch)
        if bbp_switch is not None:
            data_loss_sum += data_loss.item() * len(batch)
            kl_loss_sum += kl_loss.item() * len(batch)
            if bbp_switch > 2:
                kl_loss_depth_sum += kl_loss_depth * len(batch)

        # update learning rate by taking a step
        if isinstance(scheduler, NoamLR) or isinstance(scheduler, OneCycleLR):
            scheduler.step()

        # increment n_iter (total number of examples across epochs)
        n_iter += len(batch)

        

            
        ########### per epoch REPORTING
        if n_iter % args.train_data_size == 0:
            lrs = scheduler.get_last_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            
            loss_avg = loss_sum / args.train_data_size

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')
            
            if bbp_switch is not None:
                data_loss_avg = data_loss_sum / args.train_data_size
                kl_loss_avg = kl_loss_sum / args.train_data_size
                wandb.log({"Total loss": loss_avg}, commit=False)
                wandb.log({"Likelihood cost": data_loss_avg}, commit=False)
                wandb.log({"KL cost": kl_loss_avg}, commit=False)

                if bbp_switch > 2:
                    kl_loss_depth_avg = kl_loss_depth_sum / args.train_data_size
                    wandb.log({"KL cost DEPTH": kl_loss_depth_avg}, commit=False)

                    # log variational categorical distribution
                    wandb.log({"d_1": cat.detach().cpu().numpy()[0]}, commit=False)
                    wandb.log({"d_2": cat.detach().cpu().numpy()[1]}, commit=False)
                    wandb.log({"d_3": cat.detach().cpu().numpy()[2]}, commit=False)
                    wandb.log({"d_4": cat.detach().cpu().numpy()[3]}, commit=False)
                    wandb.log({"d_5": cat.detach().cpu().numpy()[4]}, commit=False)

            else:
                if gp_switch:
                    wandb.log({"Negative ELBO": loss_avg}, commit=False)
                else:
                    wandb.log({"Negative log likelihood (scaled)": loss_avg}, commit=False)
            
            if args.pdts:
                wandb.log({"Learning rate": lrs[0]}, commit=True)
            else:
                wandb.log({"Learning rate": lrs[0]}, commit=False)
            
    if args.pdts and args.swag:
        return loss_avg, n_iter
    else:
        return n_iter
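
For the Bayes-by-backprop branches above, the per-batch objective is the data loss plus the KL term divided by the training set size, optionally averaged over several stochastic forward passes. A schematic sketch of that bookkeeping with a placeholder sampling function (no Bayesian layers involved):

import torch

train_data_size = 50000
num_samples = 4                    # number of stochastic forward passes

def stochastic_forward():
    # Placeholder for one sampled forward pass of a Bayesian model:
    # returns (data_loss, kl_divergence) as scalar tensors.
    return torch.rand(()), 1000.0 * torch.rand(())

data_loss_cum, kl_loss_cum = 0.0, 0.0
for _ in range(num_samples):
    data_loss_i, kl_loss_i = stochastic_forward()
    data_loss_cum += data_loss_i
    kl_loss_cum += kl_loss_i / train_data_size   # scale KL to a per-example contribution

data_loss = data_loss_cum / num_samples
kl_loss = kl_loss_cum / num_samples
loss = data_loss + kl_loss
print(loss.item())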
Example #10
File: train.py Project: ks8/glassML
def train(model: nn.Module,
          data: DataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: NoamLR,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.
    :param model: Model.
    :param data: A DataLoader.
    :param loss_func: Loss function.
    :param optimizer: Optimizer.
    :param scheduler: A NoamLR learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    loss_sum, iter_count = 0, 0
    for batch in tqdm(data, total=len(data)):
        if args.cuda:
            targets = batch.y.float().unsqueeze(1).cuda()
        else:
            targets = batch.y.float().unsqueeze(1)
        batch = GlassBatchMolGraph(batch)  # TODO: Apply a check for connectivity of graph

        # Run model
        model.zero_grad()
        preds = model(batch)
        loss = loss_func(preds, targets)
        loss = loss.sum() / loss.size(0)

        loss_sum += loss.item()
        iter_count += len(batch)

        loss.backward()
        if args.max_grad_norm is not None:
            clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lr = scheduler.get_lr()[0]
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            debug("Loss = {:.4e}, PNorm = {:.4f}, GNorm = {:.4f}, lr = {:.4e}".
                  format(loss_avg, pnorm, gnorm, lr))

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                writer.add_scalar('learning_rate', lr, n_iter)

    return n_iter
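A minimal sketch of how this glassML variant might be driven across epochs. The loss and optimizer choices, and the pre-built data, scheduler, logger, and writer objects below are illustrative assumptions, not taken from the project:

# Hypothetical driver for the train() above; names and settings are illustrative assumptions.
loss_func = torch.nn.BCEWithLogitsLoss(reduction='none')  # elementwise loss; train() averages it
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr)

n_iter = 0
for epoch in range(args.epochs):
    n_iter = train(model, data, loss_func, optimizer, scheduler, args,
                   n_iter=n_iter, logger=logger, writer=writer)
    # validation and checkpointing would typically follow here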
Example #11
0
def train(model: MoleculeModel,
          data_loader: MoleculeDataLoader,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: TrainArgs,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: A :class:`~chemprop.models.model.MoleculeModel`.
    :param data_loader: A :class:`~chemprop.data.data.MoleculeDataLoader`.
    :param loss_func: Loss function.
    :param optimizer: An optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: A :class:`~chemprop.args.TrainArgs` object containing arguments for training the model.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for recording output.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()
    loss_sum = iter_count = 0

    for batch in tqdm(data_loader, total=len(data_loader), leave=False):
        # Prepare batch
        batch: MoleculeDataset
        mol_batch, features_batch, target_batch, mask_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch, data_weights_batch = \
            batch.batch_graph(), batch.features(), batch.targets(), batch.mask(), batch.atom_descriptors(), \
            batch.atom_features(), batch.bond_features(), batch.data_weights()

        mask = torch.tensor(mask_batch, dtype=torch.bool) # shape(batch, tasks)
        targets = torch.tensor([[0 if x is None else x for x in tb] for tb in target_batch]) # shape(batch, tasks)

        if args.target_weights is not None:
            target_weights = torch.tensor(args.target_weights).unsqueeze(0) # shape(1,tasks)
        else:
            target_weights = torch.ones(targets.shape[1]).unsqueeze(0)
        data_weights = torch.tensor(data_weights_batch).unsqueeze(1) # shape(batch,1)

        if args.loss_function == 'bounded_mse':
            lt_target_batch = batch.lt_targets() # shape(batch, tasks)
            gt_target_batch = batch.gt_targets() # shape(batch, tasks)
            lt_target_batch = torch.tensor(lt_target_batch)
            gt_target_batch = torch.tensor(gt_target_batch)

        # Run model
        model.zero_grad()
        preds = model(mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch)

        # Move tensors to correct device
        torch_device = preds.device
        mask = mask.to(torch_device)
        targets = targets.to(torch_device)
        target_weights = target_weights.to(torch_device)
        data_weights = data_weights.to(torch_device)
        if args.loss_function == 'bounded_mse':
            lt_target_batch = lt_target_batch.to(torch_device)
            gt_target_batch = gt_target_batch.to(torch_device)

        # Calculate losses
        if args.loss_function == 'mcc' and args.dataset_type == 'classification':
            loss = loss_func(preds, targets, data_weights, mask) * target_weights.squeeze(0)
        elif args.loss_function == 'mcc': # multiclass dataset type
            targets = targets.long()
            target_losses = []
            for target_index in range(preds.size(1)):
                target_loss = loss_func(preds[:, target_index, :], targets[:, target_index], data_weights, mask[:, target_index]).unsqueeze(0)
                target_losses.append(target_loss)
            loss = torch.cat(target_losses).to(torch_device) * target_weights.squeeze(0)
        elif args.dataset_type == 'multiclass':
            targets = targets.long()
            if args.loss_function == 'dirichlet':
                loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
            else:
                target_losses = []
                for target_index in range(preds.size(1)):
                    target_loss = loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                    target_losses.append(target_loss)
                loss = torch.cat(target_losses, dim=1).to(torch_device) * target_weights * data_weights * mask
        elif args.dataset_type == 'spectra':
            loss = loss_func(preds, targets, mask) * target_weights * data_weights * mask
        elif args.loss_function == 'bounded_mse':
            loss = loss_func(preds, targets, lt_target_batch, gt_target_batch) * target_weights * data_weights * mask
        elif args.loss_function == 'evidential':
            loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
        elif args.loss_function == 'dirichlet': # classification
            loss = loss_func(preds, targets, args.evidential_regularization) * target_weights * data_weights * mask
        else:
            loss = loss_func(preds, targets) * target_weights * data_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += len(batch)

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum = iter_count = 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
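The reduction at the end of the loop, loss.sum() / mask.sum(), averages the weighted loss only over entries whose targets are actually present. A standalone toy illustration of that masked averaging (the numbers are made up):

import torch

# Toy illustration of the masked loss averaging used above (made-up values).
raw_loss = torch.tensor([[0.4, 0.2],
                         [0.6, 0.8]])     # per-(molecule, task) losses
mask = torch.tensor([[True, False],
                     [True, True]])       # task 2 is unlabelled for molecule 1

masked = raw_loss * mask                  # zero out contributions from missing labels
loss = masked.sum() / mask.sum()          # average over the 3 labelled entries
print(loss)                               # tensor(0.6000)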
Example #12
0
def train(model: nn.Module,
          data: MolPairDataset,
          loss_func: Callable,
          optimizer: Optimizer,
          scheduler: _LRScheduler,
          args: Namespace,
          n_iter: int = 0,
          logger: logging.Logger = None,
          writer: SummaryWriter = None) -> int:
    """
    Trains a model for an epoch.

    :param model: Model.
    :param data: A MolPairDataset (or a list of MolPairDatasets if using moe).
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results.
    :param writer: A tensorboardX SummaryWriter.
    :return: The total number of iterations (training examples) trained on so far.
    """
    debug = logger.debug if logger is not None else print

    model.train()

    # Very important: shuffle before conversion to maintain randomness in the contrastive dataset.
    data.shuffle()

    loss_sum, iter_count = 0, 0

    if args.loss_func == 'contrastive':
        data = convert2contrast(data)
    # Don't use the last batch if it's small, for stability.
    num_iters = len(data) // args.batch_size * args.batch_size

    iter_size = args.batch_size

    for i in trange(0, num_iters, iter_size):
        # Prepare batch
        if i + args.batch_size > len(data):
            break
        mol_batch = MolPairDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch, target_batch = \
            mol_batch.smiles(), mol_batch.features(), mol_batch.targets()
        batch = smiles_batch
        targets = torch.Tensor([[0 if x is None else x for x in tb]
                                for tb in target_batch])
        if args.loss_func == 'contrastive':
            mask = targets
        else:
            mask = torch.Tensor([[x is not None for x in tb]
                                 for tb in target_batch])

        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()

        if args.dataset_type == 'regression':
            class_weights = torch.ones(targets.shape)
        else:
            class_weights = targets * (args.class_weights - 1) + 1

        if args.cuda:
            class_weights = class_weights.cuda()

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)

        if args.dataset_type == 'multiclass':
            targets = targets.long()
            loss = torch.cat([
                loss_func(preds[:, target_index, :], targets[:, target_index]).unsqueeze(1)
                for target_index in range(preds.size(1))
            ], dim=1) * class_weights * mask
        else:
            loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        loss_sum += loss.item()
        iter_count += 1

        loss.backward()
        if args.grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += args.batch_size

        # Log and/or add to tensorboard
        if (n_iter // args.batch_size) % args.log_frequency == 0:
            lrs = scheduler.get_lr()
            pnorm = compute_pnorm(model)
            gnorm = compute_gnorm(model)
            loss_avg = loss_sum / iter_count
            loss_sum, iter_count = 0, 0

            lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
            debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')

            if writer is not None:
                writer.add_scalar('train_loss', loss_avg, n_iter)
                writer.add_scalar('param_norm', pnorm, n_iter)
                writer.add_scalar('gradient_norm', gnorm, n_iter)
                for i, lr in enumerate(lrs):
                    writer.add_scalar(f'learning_rate_{i}', lr, n_iter)

    return n_iter
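In the classification path above, class_weights = targets * (args.class_weights - 1) + 1 gives positive labels a weight of args.class_weights and leaves negatives at 1. A small standalone check with made-up values:

import torch

# Toy check of the positive-class weighting used above (made-up values).
targets = torch.tensor([[1., 0., 1.]])
class_weight = 5.0                        # stand-in for args.class_weights
weights = targets * (class_weight - 1) + 1
print(weights)                            # tensor([[5., 1., 5.]])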