Example #1
def ewc_train(model: nn.Module, optimizer: torch.optim,
              data_loader: torch.utils.data.DataLoader, fisher_info: dict,
              importance: float):
    model.train()
    epoch_loss = 0
    # frozen snapshot of the parameters before this task's updates
    params = {n: p.detach().clone() for n, p in model.named_parameters()}
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        xent_loss = F.cross_entropy(output, target)
        ewc_loss = importance * ewc_penalty(params, model, fisher_info)
        loss = xent_loss + ewc_loss
        print('cls loss {}, ewc loss {}'.format(xent_loss.item(), ewc_loss.item()))
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
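
The ewc_penalty helper is not part of this snippet. Below is a minimal sketch of what such a helper typically computes, assuming the standard elastic weight consolidation penalty; the name and signature are taken from the call above and the real implementation may differ.

def ewc_penalty(old_params: dict, model: nn.Module, fisher_info: dict):
    # Assumed helper: EWC quadratic penalty, sum_i F_i * (theta_i - theta_i_old)^2,
    # where F_i is the Fisher information of parameter i from the previous task.
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in fisher_info:
            penalty = penalty + (fisher_info[name] * (param - old_params[name]) ** 2).sum()
    return penalty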
Example #2
def train_attr(model, optimizer_attr: torch.optim, data: torch_geometric.data.Data):
    model.train()
    optimizer_attr.zero_grad()

    labels = data.y.to(model.device)
    x, pos_edge_index = data.x, data.train_pos_edge_index

    _edge_index, _ = remove_self_loops(pos_edge_index)
    pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index,
                                                       num_nodes=x.size(0))

    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index_with_self_loops, num_nodes=x.size(0),
        num_neg_samples=pos_edge_index.size(1))

    F.nll_loss(model(pos_edge_index, neg_edge_index)[1][data.train_mask], labels[data.train_mask]).backward()
    optimizer_attr.step()
    model.eval()
Example #3
def normal_train(model: nn.Module, opt: torch.optim, loss_func: torch.nn,
                 data_loader: torch.utils.data.DataLoader, device):
    epoch_loss = 0

    for i, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(device).long()
        labels = labels.to(device).float()

        opt.zero_grad()

        output = model(inputs)
        loss = loss_func(output.view(-1), labels)
        epoch_loss += loss.item()
        loss.backward()
        opt.step()

    return epoch_loss / len(data_loader)
Example #4
def normal_train(model: nn.Module, labels: list, optimizer: torch.optim,
                 data_loader: torch.utils.data.DataLoader, gpu: torch.device):
    model.train()
    model.apply(set_bn_eval)  # freeze BatchNorm layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
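
The set_bn_eval function passed to model.apply above is not shown. A plausible sketch, assuming it simply switches BatchNorm modules to eval mode so their running statistics stay frozen while the rest of the model keeps training:

def set_bn_eval(module: nn.Module):
    # Assumed helper: put BatchNorm layers in eval mode so their running
    # mean/variance are not updated during fine-tuning.
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.eval()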
Example #5
def update_weights(
    optimizer: torch.optim,
    network: Network,
    data_loader,
):

    optimizer.zero_grad()
    p_loss, v_loss = 0, 0

    for image, actions, target_values, target_rewards, target_policies in data_loader:
        image = image.to(device)
        # Initial step, from the real observation.
        net_output = network.initial_inference(image)
        predictions = [(1.0, net_output.value, net_output.reward,
                        net_output.policy_logits)]
        hidden_state = net_output.hidden_state

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            action = action.to(device)

            net_output = network.recurrent_inference(hidden_state, action)

            predictions.append((1.0 / len(actions), net_output.value,
                                net_output.reward, net_output.policy_logits))
            hidden_state = net_output.hidden_state

        for prediction, target_value, target_reward, target_policy in zip(
                predictions, target_values, target_rewards, target_policies):
            target_value, target_reward, target_policy = target_value.to(
                device), target_reward.to(device), target_policy.to(device)

            _, value, reward, policy_logits = prediction

            p_loss += torch.mean(
                torch.sum(-target_policy * torch.log(policy_logits), dim=1))
            v_loss += torch.mean(torch.sum((target_value - value)**2, dim=1))

    total_loss = (p_loss + v_loss)
    total_loss.backward()
    optimizer.step()
    print('step %d: p_loss %f v_loss %f' %
          (network.steps % config.checkpoint_interval, p_loss, v_loss))
    network.steps += 1
Example #6
def find_lr(dataloader,
            model,
            optimizer: torch.optim,
            criterion,
            device,
            num_steps,
            lr_min: float = 1e-7,
            lr_max: float = 10,
            beta: float = 0.98):
    model.to(device)
    optim_dict = optimizer.state_dict().copy()
    optimizer.param_groups[0]['lr'] = lr_min
    #     num_steps = len(dataloader) - 1
    scheduler = LrSchedulerFinder(optimizer, lr_min, lr_max, num_steps)
    model_dict = model.state_dict().copy()
    losses = list()
    lrs = list()
    avg_loss = 0
    best_loss = 0
    for idx_batch, (data, label) in tqdm(enumerate(dataloader, 1),
                                         total=num_steps):
        if idx_batch == num_steps:
            break
        y, kl = model(data.to(device))
        loss = criterion(y, label, kl, 0)
        if np.isnan(loss.item()):
            print(loss.item())
        avg_loss = beta * avg_loss + (1 - beta) * loss.item()
        smooth_loss = avg_loss / (1 - beta**idx_batch)
        if idx_batch > 1 and smooth_loss > 4 * best_loss:
            break
        if smooth_loss < best_loss or idx_batch == 1:
            best_loss = smooth_loss
        losses.append(smooth_loss)
        lrs.append(scheduler.get_lr()[0])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    model.load_state_dict(model_dict)
    optimizer.load_state_dict(optim_dict)
    return np.array(lrs), np.array(losses)
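
The LrSchedulerFinder class used above is not included. A minimal sketch under the assumption that it performs the usual learning-rate range test, growing the learning rate exponentially from lr_min to lr_max over num_steps steps; the base class and exact formula are assumptions.

class LrSchedulerFinder(torch.optim.lr_scheduler._LRScheduler):
    # Assumed scheduler: multiply the learning rate by a constant factor each
    # step so it grows exponentially from lr_min to lr_max over num_steps.
    def __init__(self, optimizer, lr_min: float, lr_max: float, num_steps: int):
        self.gamma = (lr_max / lr_min) ** (1.0 / num_steps)
        super().__init__(optimizer)

    def get_lr(self):
        return [base_lr * self.gamma ** self.last_epoch for base_lr in self.base_lrs]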
Example #7
def train_on_batch(
        model: Tree2Seq, criterion: nn.modules.loss, optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
        graph: dgl.DGLGraph, labels: torch.Tensor, clip_norm: int
) -> Dict:
    model.train()

    # Model step
    model.zero_grad()
    loss, prediction, batch_info = _forward_pass(model, graph, labels, criterion)
    batch_info['learning_rate'] = scheduler.get_last_lr()[0]
    loss.backward()
    nn.utils.clip_grad_value_(model.parameters(), clip_norm)
    optimizer.step()
    scheduler.step()
    del loss
    del prediction
    torch.cuda.empty_cache()

    return batch_info
Example #8
def train(data: DataLoader, optimizer: torch.optim, model: torch.nn.Module,
          device: torch.device, epoch: int, print_freq: int) -> None:
    """Trains the model

    Parameters
    ----------
    data: dataloader
        the dataloader used for training

    optimizer: torch.optim
        The optimizer used during training for gradient descent

    model: subclass of torch.nn.Module
        The neural network model

    device: torch.device
        Location for where to put the data

    epoch: int
        Current epoch number

    print_freq: int
        Determines after how many training iterations notifications are printed
        to stdout
    """
    model.train()
    model.to(device)
    logger = get_logging()
    header = 'Epoch: [{}]'.format(epoch)
    lr_scheduler = learning_rate_scheduler(optimizer, epoch, len(data))
    for images, targets in logger.log_every(data, print_freq, header):
        optimizer.zero_grad()
        images = list(i.to(device) for i in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        check_losses(losses, loss_dict, targets)
        losses.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()
        logger.update(loss=losses, **loss_dict)
        logger.update(lr=optimizer.param_groups[0]["lr"])
Example #9
def train_step(
    model: nn.Module,
    loader: DataLoader,
    writer: SummaryWriter,
    batch_size: int,
    optimizer: torch.optim,
    ideal_dcg: float,
    epoch: int,
    k: int = 10,
):
    grad_batch, y_pred_batch = [], []
    model.train()
    pbar = tqdm(total=len(loader))
    for i, (X, Y) in enumerate(loader):
        X, Y = X.squeeze().cuda(), Y.squeeze().numpy()
        N = 1.0 / ideal_dcg.maxDCG(Y)
        y_pred = model(X)
        y_pred_batch.append(y_pred)
        # compute the rank order of each document
        rank_df = pd.DataFrame({"Y": Y, "doc": np.arange(Y.shape[0])})
        rank_df = rank_df.sort_values("Y").reset_index(drop=True)
        rank_order = rank_df.sort_values("doc").index.values + 1
        with torch.no_grad():
            Y = torch.tensor(Y).view(-1, 1).cuda()
            lambda_update = model.get_lambda(Y, y_pred, rank_order, N)
            assert lambda_update.shape == y_pred.shape
            check_grad = torch.sum(lambda_update, (0, 1)).item()
            grad_batch.append(lambda_update)
        pbar.update(1)
        if i % batch_size == 0:
            for grad, y_pred in zip(grad_batch, y_pred_batch):
                y_pred.backward(grad / batch_size, retain_graph=True)
            optimizer.step()
            model.zero_grad()
            grad_batch, y_pred_batch = [], []

    to_write = {
        'NDCG Train': eval_ndcg_at_k(model, loader, k, epoch),
        'Loss Train': eval_cross_entropy_loss(model, loader, epoch),
        'MAP Train': eval_map(model, loader),
    }
    writer.log(to_write)
    pbar.close()
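
The ideal_dcg argument is annotated as a float but is used as an object with a maxDCG method. A plausible sketch of such an object, assuming maxDCG returns the DCG of the ideally ordered relevance labels truncated at position k; the class name and details are assumptions.

class IdealDCG:
    # Assumed helper object matching the ideal_dcg.maxDCG(Y) call above.
    def __init__(self, k: int = 10):
        self.k = k

    def maxDCG(self, Y: np.ndarray) -> float:
        # DCG of the labels sorted in descending relevance order (ideal ranking).
        rels = np.sort(Y)[::-1][:self.k]
        gains = 2.0 ** rels - 1.0
        discounts = np.log2(np.arange(2, rels.shape[0] + 2))
        return float(np.sum(gains / discounts))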
Example #10
def train_l0(model: nn.Module, optimizer: torch.optim,
             data_loader: torch.utils.data.DataLoader):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        if input.dim() == 3:
            input.unsqueeze_(0)
            target.unsqueeze_(0)
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        xent_loss = F.cross_entropy(output, target)
        l0_loss = model.regularization()
        loss = xent_loss + l0_loss
        print("Training of l0, xent loss {}, l0 loss {}".format(
            xent_loss.item(), l0_loss.item()))
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
Example #11
def l0_train_trans(model: nn.Module, optimizer: torch.optim,
                   data_loader: torch.utils.data.DataLoader, l0_scores: dict,
                   params: dict, importance: float):
    model.train()
    epoch_loss = 0
    for input, target in data_loader:
        input, target = variable(input), variable(target)
        optimizer.zero_grad()
        output = model(input)
        xent_loss = F.cross_entropy(output, target)
        l0_trans_loss = importance * l0_weighted_penalty(
            model, params, l0_scores)
        l0_reg_loss = model.regularization()
        loss = xent_loss + l0_trans_loss + l0_reg_loss
        print("xent loss {}, l0 trans loss {}, l0 reg loss {}".format(
            xent_loss.item(), l0_trans_loss.item(), l0_reg_loss.item()))
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
Example #12
def train_drug_target(
    device: torch.device,
    drug_target_net: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    max_num_batches: int,
    optimizer: torch.optim,
):

    drug_target_net.train()

    for batch_idx, (drug_feature, target) in enumerate(data_loader):

        if batch_idx >= max_num_batches:
            break

        drug_feature, target = drug_feature.to(device), target.to(device)

        drug_target_net.zero_grad()
        out_target = drug_target_net(drug_feature)
        F.nll_loss(input=out_target, target=target).backward()
        optimizer.step()
Example #13
def ewc_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader, ewcs: list, lam: float, gpu: torch.device):
    model.train()
    model.apply(set_bn_eval)  # freeze BatchNorm layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        # for idx in range(output.size(1)):
        #     if idx not in labels:
        #         output[range(len(output)), idx] = 0
        # criterion = nn.CrossEntropyLoss()
        # loss = criterion(output, target) 
        loss = myloss(output, target, labels)        
        # print('loss:', loss.item())
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
            # print('ewc loss:', loss.item())
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)
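
The myloss helper is not defined here. Based on the commented-out lines above, a plausible sketch masks the logits of classes outside the current label set before applying cross-entropy; this is an assumption about the author's helper, not its actual code.

def myloss(output: torch.Tensor, target: torch.Tensor, labels: list) -> torch.Tensor:
    # Assumed helper, mirroring the commented-out block above: zero out logits
    # of classes that are not in the current task's label set, then apply
    # cross-entropy on the masked logits.
    masked = output.clone()
    for idx in range(masked.size(1)):
        if idx not in labels:
            masked[:, idx] = 0
    return nn.CrossEntropyLoss()(masked, target)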
Example #14
def Epoch(nn: Resnet, data_loader: DataLoader, value_criterion,
          policy_criterion, optim: Optimizer) -> float:
    running_loss = 0
    count = 0
    for batch in data_loader:
        optim.zero_grad()
        states, target_values, target_policies = batch
        states, target_policies = AugmentData(states, target_policies)

        states = states.to(DEVICE)
        target_values = target_values.to(DEVICE)
        target_policies = target_policies.to(DEVICE)
        nn_values, nn_policies = nn(states)
        value_loss = value_criterion(nn_values, target_values)
        policy_loss = policy_criterion(target_policies, nn_policies)
        loss = value_loss + policy_loss
        loss.backward()
        optim.step()

        running_loss += loss.item()
        count += 1

    return running_loss / count
Example #15
File: VAE.py Project: genEM3/genEM3
def train(epoch: int = None,
          model: torch.nn.Module = None,
          train_loader: torch.utils.data.DataLoader = None,
          optimizer: torch.optim = None,
          args: argparse.Namespace = None,
          device: torch.device = torch.device('cpu')):
    """training loop on a batch of data"""
    model.train()
    train_loss = 0
    detailedLoss = {'Recon': 0.0, 'KLD': 0.0, 'weighted_KLD': 0.0}
    for batch_idx, data in tqdm(enumerate(train_loader), total=len(train_loader), desc='train'):
        data = data['input'].to(device)

        optimizer.zero_grad()
        recon_batch = model(data)

        loss, curDetLoss = loss_function(recon_batch,
                                         data,
                                         model.cur_mu,
                                         model.cur_logvar,
                                         model.weight_KLD)
        train_loss += (loss.item() / NUM_FACTOR)
        # Separate loss
        for key in curDetLoss:
            detailedLoss[key] += (curDetLoss.get(key) / NUM_FACTOR)
        # Backprop
        loss.backward()
        optimizer.step()
    num_data_points = len(train_loader.dataset.data_train_inds)
    train_loss /= num_data_points
    train_loss *= NUM_FACTOR

    for key in detailedLoss:
        detailedLoss[key] /= num_data_points
        detailedLoss[key] *= NUM_FACTOR

    return train_loss, detailedLoss
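
The loss_function called here is not shown. A minimal sketch consistent with its call signature and with the 'Recon', 'KLD' and 'weighted_KLD' keys accumulated above, assuming the usual VAE objective with a weighted KL term; the reconstruction loss and its reduction are assumptions.

def loss_function(recon_x, x, mu, logvar, weight_KLD):
    # Assumed VAE loss: reconstruction error plus a weighted KL divergence,
    # returning both the total loss and the per-term breakdown used above.
    recon = F.mse_loss(recon_x, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total = recon + weight_KLD * kld
    detailed = {'Recon': recon.item(),
                'KLD': kld.item(),
                'weighted_KLD': (weight_KLD * kld).item()}
    return total, detailed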
Example #16
def train_cl_clf(device: torch.device,

                 category_clf_net: nn.Module,
                 site_clf_net: nn.Module,
                 type_clf_net: nn.Module,
                 data_loader: torch.utils.data.DataLoader,

                 max_num_batches: int,
                 optimizer: torch.optim, ):

    category_clf_net.train()
    site_clf_net.train()
    type_clf_net.train()

    for batch_idx, (rnaseq, data_src, cl_site, cl_type, cl_category) \
            in enumerate(data_loader):

        if batch_idx >= max_num_batches:
            break

        rnaseq, data_src, cl_site, cl_type, cl_category = \
            rnaseq.to(device), data_src.to(device), cl_site.to(device), \
            cl_type.to(device), cl_category.to(device)

        category_clf_net.zero_grad()
        site_clf_net.zero_grad()
        type_clf_net.zero_grad()

        out_category = category_clf_net(rnaseq, data_src)
        out_site = site_clf_net(rnaseq, data_src)
        out_type = type_clf_net(rnaseq, data_src)

        F.nll_loss(input=out_category, target=cl_category).backward()
        F.nll_loss(input=out_site, target=cl_site).backward()
        F.nll_loss(input=out_type, target=cl_type).backward()

        optimizer.step()
Example #17
def train(data: TokensDataSet, model: torch.nn.Module, device: torch.device,
          loss: torch.nn.modules.loss, optimizer: torch.optim, batch_size: int,
          epochs: int) -> None:
    model = model.to(device)
    for epoch in range(epochs):
        order = np.random.permutation(data.X_train.shape[0])
        total_loss = 0
        for start_index in range(0, data.X_train.shape[0], batch_size):
            if start_index + batch_size > data.X_train.shape[0]:
                break
            optimizer.zero_grad()
            model.train()

            batch_idxs = order[start_index:(start_index + batch_size)]
            X_batch = data.X_train[batch_idxs].to(device)
            X_batch_lengths = data.X_train_lengths[batch_idxs].to(device)
            y_batch = data.y_train[batch_idxs].to(device)

            preds = model.forward(X_batch, X_batch_lengths)
            loss_val = loss(preds, y_batch)
            loss_val.backward()
            total_loss += loss_val.item()
            optimizer.step()
        print(f'epoch: {epoch} | loss: {total_loss}')
Example #18
def clf_train(net, tloader, opti: torch.optim, crit: nn.Module, **kwargs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)
    net.train()
    tcorr = 0
    tloss = 0
    try:
        # crit may be passed as a class or an already-constructed instance
        crit = crit()
    except TypeError:
        pass
    for ii, (data, labl) in enumerate(tqdm(tloader)):
        data, labl = data.to(device), labl.to(device)
        out = net(data)
        loss = crit(out, labl)
        opti.zero_grad()
        loss.backward()
        opti.step()
        with torch.no_grad():
            tloss += loss.item()
            tcorr += (out.argmax(dim=1) == labl).float().sum()

    tloss /= len(tloader)
    tacc = accuracy(tcorr, len(tloader) * tloader.batch_size)
    return tacc, tloss
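
The accuracy helper is assumed. A minimal sketch matching the call above, which passes a correct-prediction count (possibly a tensor) and the total number of samples:

def accuracy(correct, total: int) -> float:
    # Assumed helper: fraction of correct predictions as a plain float.
    correct = correct.item() if torch.is_tensor(correct) else correct
    return correct / total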
Example #19
def train(model: nn.Module, optimizer: optim, train_loader: DataLoader,
          test_loader: DataLoader, args, epoch: int) -> float:
    model.train()
    loss_epoch = np.zeros(len(train_loader))
    for i, (train_batch, labels_batch) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        batch_size = train_batch.shape[0]

        train_batch = train_batch.permute(1, 0, 2).to(torch.float32).to(
            args.device)  # not scaled
        labels_batch = labels_batch.permute(1, 0).to(torch.float32).to(
            args.device)  # not scaled

        loss = torch.zeros(1, device=args.device)
        hidden = model.init_hidden(batch_size, args.device)
        cell = model.init_cell(batch_size, args.device)

        for t in range(args.window_size):
            # if z_t is missing, replace it by output mu from the last time step
            loss_, y, hidden, cell = model(
                train_batch[t].unsqueeze_(0).clone(), hidden, cell,
                labels_batch[t])
            loss += loss_

        loss.backward()
        optimizer.step()
        loss = loss.item() / args.window_size  # loss per timestep
        loss_epoch[i] = loss
        # test_metrics = evaluate(model, test_loader, args, epoch)
        if i % 1000 == 0:
            test_metrics = evaluate(model, test_loader, args, epoch)
            model.train()
            logger.info(f'train_loss: {loss}')
        if i == 0:
            logger.info(f'train_loss: {loss}')
    return loss_epoch
Example #20
    def train(self,
              epoch_s: int,
              epoch_e: int,
              data: Data,
              n_samples: int,
              optimizer: torch.optim,
              device: torch.device,
              strategy: str = 'max',
              mode: bool = True,
              batch_size: int = 256) -> None:

        loader = DataLoader(torch.arange(data.num_nodes),
                            batch_size=batch_size,
                            shuffle=True)

        train_time = time.time()
        prefix_sav = f'./model_save/WNNode2vec_{train_time}'
        loss_list = []

        for epoch in range(epoch_s, epoch_e):
            super().train()
            total_loss = 0
            print(f'epoch: {epoch}')
            for subset in tqdm(loader):
                optimizer.zero_grad()
                loss = self.loss(data.edge_index, subset=subset.to(device))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            rls_loss = total_loss / len(loader)
            loss_list.append(rls_loss)
            sr_rls = sr_test(device, self.forward, strategy)
            oup = self.forward(torch.arange(
                0, data.num_nodes, device=data.edge_index.device)).data
            save_model(epoch, self, optimizer, loss_list,
                       prefix_sav, oup=oup, sr=sr_rls)
Example #21
def train(
    model: DRIVEMODEL,
    optimizer_name: str,
    optimizer: torch.optim,
    scheduler: torch.optim.lr_scheduler,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    num_load_files_training: int,
    fp16: bool = True,
    amp_opt_level=None,
    save_checkpoints: bool = True,
    eval_every: int = 5,
    save_every: int = 20,
    save_best: bool = True,
):
    """
    Train a model

    Input:
    - model: DRIVEMODEL model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]
    - optimizer: Optimizer (torch.optim)
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epochs: Number of epochs to do
    - max_acc: Accuracy in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability for removing the minimap (put a black square)
       from a training example (0<=hide_map_prob<=1)
    - dropout_images_prob List of 5 floats or None, probability for removing each input image during training
     (black image) from a training example (0<=dropout_images_prob<=1)
    - fp16: Use FP16 for training
    - amp_opt_level: If FP16 training Nvidia apex opt level
    - save_checkpoints: save a checkpoint each epoch (Each checkpoint will rewrite the previous one)
    - save_best: save the model that achieves the higher accuracy in the development set

    Output:
     - float: Accuracy in the development test of the best model
    """
    writer: SummaryWriter = SummaryWriter()

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss()
    print("Loading dev set")
    X_dev, y_dev = load_dataset(dev_dir, fp=16 if fp16 else 32)
    X_dev = torch.from_numpy(X_dev)
    print("Loading test set")
    X_test, y_test = load_dataset(test_dir, fp=16 if fp16 else 32)
    X_test = torch.from_numpy(X_test)
    total_training_examples: int = 0
    model.zero_grad()

    printTrace("Training...")
    for epoch in range(num_epoch):
        step_no: int = 0
        iteration_no: int = 0
        num_used_files: int = 0
        data_loader = DataLoader_AutoDrive(
            dataset_dir=train_dir,
            nfiles2load=num_load_files_training,
            hide_map_prob=hide_map_prob,
            dropout_images_prob=dropout_images_prob,
            fp=16 if fp16 else 32,
        )

        data = data_loader.get_next()
        # Get files in batches, all files will be loaded and data will be shuffled
        while data:
            X, y = data
            model.train()
            start_time: float = time.time()
            total_training_examples += len(y)
            running_loss: float = 0.0
            num_batchs: int = 0
            acc_dev: float = 0.0

            for X_batch, y_batch in nn_batchs(X, y, batch_size):
                X_batch, y_batch = (
                    torch.from_numpy(X_batch).to(device),
                    torch.from_numpy(y_batch).long().to(device),
                )

                outputs = model.forward(X_batch)
                loss = criterion(outputs, y_batch) / accumulation_steps
                running_loss += loss.item()

                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), 1.0)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                if (step_no + 1) % accumulation_steps == 0 or (
                        num_used_files + 1 >
                        len(data_loader) - num_load_files_training
                        and num_batchs == math.ceil(len(y) / batch_size) - 1
                ):  # If we are in the last batch of the epoch we also want to perform gradient descent
                    optimizer.step()
                    model.zero_grad()

                num_batchs += 1
                step_no += 1

            num_used_files += num_load_files_training

            # Print Statistics
            printTrace(
                f"EPOCH: {initial_epoch+epoch}. Iteration {iteration_no}. "
                f"{num_used_files} of {len(data_loader)} files. "
                f"Total examples used for training {total_training_exampels}. "
                f"Iteration time: {round(time.time() - start_time,2)} secs.")
            printTrace(
                f"Loss: {-1 if num_batchs == 0 else running_loss / num_batchs}. "
                f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
            )
            writer.add_scalar("Loss/train", running_loss / num_batchs,
                              iteration_no)

            scheduler.step(running_loss / num_batchs)

            if (iteration_no + 1) % eval_every == 0:
                start_time_eval: float = time.time()
                if len(X) > 0 and len(y) > 0:
                    acc_train: float = evaluate(
                        model=model,
                        X=torch.from_numpy(X),
                        golds=y,
                        device=device,
                        batch_size=batch_size,
                    )
                else:
                    acc_train = -1.0

                acc_dev: float = evaluate(
                    model=model,
                    X=X_dev,
                    golds=y_dev,
                    device=device,
                    batch_size=batch_size,
                )

                acc_test: float = evaluate(
                    model=model,
                    X=X_test,
                    golds=y_test,
                    device=device,
                    batch_size=batch_size,
                )

                printTrace(
                    f"Acc training set: {round(acc_train,2)}. "
                    f"Acc dev set: {round(acc_dev,2)}. "
                    f"Acc test set: {round(acc_test,2)}.  "
                    f"Eval time: {round(time.time() - start_time_eval,2)} secs."
                )

                if 0.0 < acc_dev > max_acc and save_best:
                    max_acc = acc_dev
                    printTrace(
                        f"New max acc in dev set {round(max_acc,2)}. Saving model..."
                    )
                    save_model(
                        model=model,
                        save_dir=output_dir,
                        fp16=fp16,
                        amp_opt_level=amp_opt_level,
                    )
                if acc_train > -1:
                    writer.add_scalar("Accuracy/train", acc_train,
                                      iteration_no)
                writer.add_scalar("Accuracy/dev", acc_dev, iteration_no)
                writer.add_scalar("Accuracy/test", acc_test, iteration_no)

            if save_checkpoints and (iteration_no + 1) % save_every == 0:
                printTrace("Saving checkpoint...")
                save_checkpoint(
                    path=os.path.join(output_dir, "checkpoint.pt"),
                    model=model,
                    optimizer_name=optimizer_name,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    acc_dev=acc_dev,
                    epoch=initial_epoch + epoch,
                    fp16=fp16,
                    opt_level=amp_opt_level,
                )

            iteration_no += 1
            data = data_loader.get_next()

        data_loader.close()

    return max_acc
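
The nn_batchs generator used in the inner loop above is not part of the snippet. A plausible sketch, assuming it simply yields consecutive minibatches from the already-shuffled numpy arrays; the name and behavior are inferred from the call site.

def nn_batchs(X, y, batch_size: int):
    # Assumed helper: yield successive (X, y) minibatches of size batch_size.
    for start in range(0, len(y), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]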
Example #22
def model_train \
    ( data_trn:torch.utils.data.Dataset
    , modl:torch.nn.Module
    , crit:torch.nn
    , optm:torch.optim
    , batch_size:int=100
    , hidden_shapes:list=[20,30,40]
    , hidden_acti:str="relu"
    , final_shape:int=1
    , final_acti:str="sigmoid"
    , device:torch.device=get_device()
    , scheduler:torch.optim.lr_scheduler=None
    ):

    # Set to train
    modl.train()
    loss_trn = 0.0
    accu_trn = 0.0

    # Set data generator
    load_trn = DataLoader(data_trn, batch_size=batch_size, shuffle=True, num_workers=0)

    # Loop over each batch
    for batch, data in enumerate(load_trn):
        
        # Extract data
        inputs, labels = data

        # Push data to device (.to() is not in-place, so reassign)
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero out the parameter gradients
        optm.zero_grad()

        # Feed forward
        output = modl \
            ( feat=inputs
            , hidden_shapes=hidden_shapes
            , hidden_acti=hidden_acti
            , final_shape=final_shape
            , final_acti=final_acti
            )

        # Calc loss
        loss = crit(output, labels.unsqueeze(1))

        # Global metrics
        loss_trn += loss.item()
        accu_trn += (output.argmax(1) == labels).sum().item()

        # Feed backward
        loss.backward()

        # Optimise
        optm.step()

    # Adjust scheduler
    if scheduler:
        scheduler.step()
    
    return loss_trn/len(data_trn), accu_trn/len(data_trn)
Example #23
def our_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader, ewcs: list, lam: float, gpu: torch.device, cut_idx, if_freeze):
    
    # still need a loss-based check here; if true: freeze
    #---------------------freeze
    # if if_freeze == 1 :
    #     for idx, param in enumerate(model.parameters()):
    #         if idx >= cut_idx:
    #             continue
    #         param.requires_grad = False

    #     optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args_lr) # no need to add: lr=0.1?
    #----------------------
    model.train()
    model.apply(set_bn_eval)  # freeze BatchNorm layers and their running statistics
    epoch_loss = 0
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target) 
        # print('loss:', loss.item())
        for ewc in ewcs:
            loss += (lam / 2) * ewc.penalty(model)
            # print('ewc loss:', loss.item())
        epoch_loss += loss.item()

        loss.backward()
        countskip = 0
        countall = 0
        #------ decide whether to freeze the server-side layers based on if_freeze ------
        if if_freeze == 1 :
            #---------------- manual re-implementation of optimizer.step() ---------------------
            for group in optimizer.param_groups:
                for idx, p in enumerate(group['params']):
                    countall += 1
                    if idx >= cut_idx: # freeze the server-side layers, i.e. skip cut_idx ~ end
                        countskip += 1
                        #print('skip_server_layer')
                        continue                    
                    if p.grad is None:
                        continue
                    d_p = p.grad
                    #p.add_(d_p, alpha=-group['lr'])
                    p.data = p.data - d_p*group['lr']
            print("countskip:",countskip,"countall:",countall)
        else:
            optimizer.step()
        #----------------------------------------------------
        #optimizer.step()   # optimizer.param_groups : 'params' : .grad ==> gradients
                           # line 91
                           # for n, p in model.named_parameters():
                           #     p.grad.data ==> gradient data of the current layer?
    #----------------------------- unfreeze
    # if if_freeze == 1 :    
    #     for idx, param in enumerate(model.parameters()):
    #         if idx >= cut_idx:
    #             continue
    #         param.requires_grad = True

    #     optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args_lr) # no need to add: lr=0.1?
    #-------------------------------
    return epoch_loss / len(data_loader)
Example #24
def ours_first_train(model: nn.Module, labels: list, optimizer: torch.optim, data_loader: torch.utils.data.DataLoader, gpu: torch.device, cut_idx, freeze_stat, last_loss):
    model.train()
    model.apply(set_bn_eval)  # freeze BatchNorm layers and their running statistics
    # epoch_loss = 0
    # for data, target in data_loader:
    #     data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
    #     optimizer.zero_grad()
    #     output = model(data)
    #     for idx in range(output.size(1)):
    #         if idx not in labels:
    #             output[range(len(output)), idx] = 0
    #     criterion = nn.CrossEntropyLoss()
    #     loss = criterion(output, target)
    #     epoch_loss += loss.item()
    #     #loss.backward()
    #     #optimizer.step()

    # average_epoch_loss = (epoch_loss / len(data_loader))

    epoch_loss = 0    
    for data, target in data_loader:
        data, target = Variable(data).cuda(gpu), Variable(target).cuda(gpu)
        optimizer.zero_grad()
        output = model(data)
        for idx in range(output.size(1)):
            if idx not in labels:
                output[range(len(output)), idx] = 0
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, target)
        epoch_loss += loss.item()
        loss.backward()
        countskip = 0
        countall = 0
        #-----------------------------------------------------------
        if last_loss > 1.5 :
            optimizer.step()
            print('pretrain',last_loss)
        elif freeze_stat == 0 :
            #---------------- manual re-implementation of optimizer.step() ---------------------
            #optimizer.step()           
            for group in optimizer.param_groups:
                for idx, p in enumerate(group['params']):
                    countall += 1
                    if idx >= cut_idx:  # skip cut_idx ~ end, freeze the server-side parameters
                        countskip += 1
                        #print('skip_server_layer')
                        #p.grad = p.grad*0
                        continue                    
                    if p.grad is None:
                        continue
                    d_p = p.grad
                    #p.add_(d_p, alpha=-group['lr'])
                    p.data = p.data - d_p*group['lr']
            print("servercountskip:",countskip,"countall:",countall)
        else:
            #---------------- manual re-implementation of optimizer.step() ---------------------
            #print('freeze_stat = 1')
            #optimizer.step()           
            for group in optimizer.param_groups:
                for idx, p in enumerate(group['params']):
                    countall += 1
                    if idx < cut_idx:  # skip 0 ~ cut_idx-1, freeze the device-side parameters
                        countskip += 1
                        #print('skip_server_layer')
                        #p.grad = p.grad*0
                        continue                    
                    if p.grad is None:
                        continue
                    d_p = p.grad
                    #p.add_(d_p, alpha=-group['lr'])
                    p.data = p.data - d_p*group['lr']
            print("devicecountskip:",countskip,"countall:",countall)

    return epoch_loss / len(data_loader)
Example #25
def train(train_loader: DataLoader, validation_loader: DataLoader,
          num_epochs: int, total_training_batches: int, model: Module,
          criterion: loss, optimizer: optim, batch_size: int,
          learning_rate: float):
    """Train network."""

    writer.add_text(
        'Experiment summary',
        'Batch size: {}, Learning rate {}'.format(batch_size, learning_rate))

    batch_number = 0
    step_number = 0
    previous_running_loss = float('inf')

    for epoch in range(num_epochs):
        train_running_loss = 0
        train_accuracy = 0
        for images, labels in train_loader:
            if batch_number % 10 == 0:
                logging.info('Batch number {}/{}...'.format(
                    batch_number, total_training_batches))

            batch_number += 1
            step_number += 1

            # Pass this computations to selected device
            images = images.cuda()
            labels = labels.cuda()

            # Clear the gradients, do this because gradients are accumulated
            optimizer.zero_grad()

            # Forwards pass, then backward pass, then update weights
            probabilities = model.forward(images)
            model_loss = criterion(probabilities, labels)
            model_loss.backward()
            optimizer.step()

            # Get the class probabilities
            ps = torch.nn.functional.softmax(probabilities, dim=1)

            # Get top probabilities
            top_probability, top_class = ps.topk(1, dim=1)

            # Comparing one element in each row of top_class with
            # each of the labels, and return True/False
            equals = top_class == labels.view(*top_class.shape)

            # Number of correct predictions
            train_accuracy += torch.sum(equals.type(torch.FloatTensor)).item()
            train_running_loss += model_loss.item()
        else:
            validation_running_loss = 0
            validation_accuracy = 0
            # Turn off gradients for testing
            with torch.no_grad():
                # set model to evaluation mode
                model.eval()
                for images, labels in validation_loader:
                    # Pass this computations to selected device
                    images = images.cuda()
                    labels = labels.cuda()

                    probabilities = model.forward(images)
                    validation_running_loss += criterion(probabilities, labels).item()

                    # Get the class probabilities
                    ps = torch.nn.functional.softmax(probabilities, dim=1)

                    # Get top probabilities
                    top_probability, top_class = ps.topk(1, dim=1)

                    # Comparing one element in each row of top_class with
                    # each of the labels, and return True/False
                    equals = top_class == labels.view(*top_class.shape)

                    # Number of correct predictions
                    validation_accuracy += torch.sum(
                        equals.type(torch.FloatTensor)).item()

            if validation_running_loss <= previous_running_loss:
                logging.info(
                    'Validation loss decreased {:.5f} -> {:.5f}. Saving model.'
                    .format(previous_running_loss, validation_running_loss))
                torch.save(model.state_dict(),
                           'model_{}.pt'.format(batch_size))

            previous_running_loss = validation_running_loss

            # Set model to train mode
            model.train()

            # Calculating accuracy
            validation_accuracy = (validation_accuracy /
                                   validation_loader.sampler.num_samples * 100)
            train_accuracy = (train_accuracy /
                              train_loader.batch_sampler.sampler.num_samples *
                              100)

            # Saving losses and accuracy
            writer.add_scalar('loss/train_loss', train_running_loss, epoch)
            writer.add_scalar('accuracy/train_accuracy', train_accuracy, epoch)
            writer.add_scalar('loss/validation_loss', validation_running_loss,
                              epoch)
            writer.add_scalar('accuracy/validation_accuracy',
                              validation_accuracy, epoch)

            logging.info("Epoch: {}/{}.. ".format(epoch + 1, num_epochs))
            logging.info("Training Loss: {:.3f}.. ".format(train_running_loss))
            logging.info("Training Accuracy: {:.3f}%".format(train_accuracy))
            logging.info(
                "Validation Loss: {:.3f}.. ".format(validation_running_loss))
            logging.info(
                "Validation Accuracy: {:.3f}%".format(validation_accuracy))

            batch_number = 0
Example #26
def train(model, optimizer: torch.optim, optimizer_attack: torch.optim, data: torch_geometric.data.Data,
          switch: bool = True, use_ws_loss: bool = True):
    """
        trains the model for one epoch

        Parameters
        ----------
        model: Model
        optimizer: torch.optim
        optimizer_attack: torch.optim
        data: torch_geometric.data.Data
        switch: bool
        use_ws_loss: bool
    """

    model.train()

    labels = data.y.to(model.device)
    x, pos_edge_index = data.x, data.train_pos_edge_index

    _edge_index, _ = remove_self_loops(pos_edge_index)
    pos_edge_index_with_self_loops, _ = add_self_loops(_edge_index,
                                                       num_nodes=x.size(0))

    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index_with_self_loops, num_nodes=x.size(0),
        num_neg_samples=pos_edge_index.size(1))

    link_logits, attr_prediction, attack_prediction, _ = model(pos_edge_index, neg_edge_index)
    link_labels = _get_link_labels(pos_edge_index, neg_edge_index).to(link_logits.device)

    # same from here to the end
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)

    # loss 2
    if use_ws_loss:  # wasserstein distance VS total variation
        one_hot = torch.cuda.FloatTensor(attack_prediction.size(0), attack_prediction.size(1)).zero_()
        mask = one_hot.scatter_(1, labels.view(-1, 1), 1)

        nonzero = mask * attack_prediction
        avg = torch.mean(nonzero, dim=0)
        loss2 = torch.abs(torch.max(avg) - torch.min(avg))
    else:
        loss2 = F.nll_loss(attack_prediction, labels)

    link_logits = link_logits.detach().cpu().numpy()
    link_labels = link_labels.detach().cpu().numpy()

    train_acc = roc_auc_score(link_labels, link_logits)

    if switch:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    else:
        optimizer_attack.zero_grad()
        loss2.backward()
        optimizer_attack.step()

        for p in model.attk.parameters():
            p.data.clamp_(-1, 1)

    model.eval()
    return train_acc
Example #27
def train_helper(model: torchvision.models.resnet.ResNet,
                 dataloaders: Dict[str, torch.utils.data.DataLoader],
                 dataset_sizes: Dict[str,
                                     int], criterion: torch.nn.modules.loss,
                 optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
                 num_epochs: int, writer: IO, train_order_writer: IO,
                 device: torch.device, start_epoch: int, batch_size: int,
                 save_interval: int, checkpoints_folder: Path, num_layers: int,
                 classes: List[str], num_classes: int) -> None:
    """
    Function for training ResNet.
    Args:
        model: ResNet model for training.
        dataloaders: Dataloaders for IO pipeline.
        dataset_sizes: Sizes of the training and validation dataset.
        criterion: Metric used for calculating loss.
        optimizer: Optimizer to use for gradient descent.
        scheduler: Scheduler to use for learning rate decay.
        start_epoch: Starting epoch for training.
        writer: Writer to write logging information.
        train_order_writer: Writer to write the order of training examples.
        device: Device to use for running model.
        num_epochs: Total number of epochs to train for.
        batch_size: Mini-batch size to use for training.
        save_interval: Number of epochs between saving checkpoints.
        checkpoints_folder: Directory to save model checkpoints to.
        num_layers: Number of layers to use in the ResNet model from [18, 34, 50, 101, 152].
        classes: Names of the classes in the dataset.
        num_classes: Number of classes in the dataset.
    """
    since = time.time()

    # Initialize all the tensors to be used in training and validation.
    # Do this outside the loop since it will be written over entirely at each
    # epoch and doesn't need to be reallocated each time.
    train_all_labels = torch.empty(size=(dataset_sizes["train"], ),
                                   dtype=torch.long).cpu()
    train_all_predicts = torch.empty(size=(dataset_sizes["train"], ),
                                     dtype=torch.long).cpu()
    val_all_labels = torch.empty(size=(dataset_sizes["val"], ),
                                 dtype=torch.long).cpu()
    val_all_predicts = torch.empty(size=(dataset_sizes["val"], ),
                                   dtype=torch.long).cpu()

    global_minibatch_counter = 0

    # Train for specified number of epochs.
    for epoch in range(start_epoch, num_epochs):

        # Training phase.
        model.train(mode=True)

        train_running_loss = 0.0
        train_running_corrects = 0
        epoch_minibatch_counter = 0

        # Train over all training data.
        for idx, (inputs, labels, paths) in enumerate(dataloaders["train"]):

            train_inputs = inputs.to(device=device)
            train_labels = labels.to(device=device)
            optimizer.zero_grad()

            # Forward and backpropagation.
            with torch.set_grad_enabled(mode=True):
                train_outputs = model(train_inputs)
                __, train_preds = torch.max(train_outputs, dim=1)
                train_loss = criterion(input=train_outputs,
                                       target=train_labels)
                train_loss.backward()
                optimizer.step()

            # Update training diagnostics.
            train_running_loss += train_loss.item() * train_inputs.size(0)
            train_running_corrects += torch.sum(
                train_preds == train_labels.data, dtype=torch.double)

            start = idx * batch_size
            end = start + batch_size

            train_all_labels[start:end] = train_labels.detach().cpu()
            train_all_predicts[start:end] = train_preds.detach().cpu()

            global_minibatch_counter += 1
            epoch_minibatch_counter += 1

            # for path in paths: #write the order that the model was trained in
            #     train_order_writer.write("/".join(path.split("/")[-2:]) + "\n")

            if global_minibatch_counter % 10 == 0 or global_minibatch_counter == 5:

                calculate_confusion_matrix(
                    all_labels=train_all_labels.numpy(),
                    all_predicts=train_all_predicts.numpy(),
                    classes=classes,
                    num_classes=num_classes)

                # Store training diagnostics.
                train_loss = train_running_loss / (epoch_minibatch_counter *
                                                   batch_size)
                train_acc = train_running_corrects / (epoch_minibatch_counter *
                                                      batch_size)

                # Validation phase.
                model.train(mode=False)

                val_running_loss = 0.0
                val_running_corrects = 0

                # Feed forward over all the validation data.
                for idx, (val_inputs, val_labels,
                          paths) in enumerate(dataloaders["val"]):
                    val_inputs = val_inputs.to(device=device)
                    val_labels = val_labels.to(device=device)

                    # Feed forward.
                    with torch.set_grad_enabled(mode=False):
                        val_outputs = model(val_inputs)
                        _, val_preds = torch.max(val_outputs, dim=1)
                        val_loss = criterion(input=val_outputs,
                                             target=val_labels)

                    # Update validation diagnostics.
                    val_running_loss += val_loss.item() * val_inputs.size(0)
                    val_running_corrects += torch.sum(
                        val_preds == val_labels.data, dtype=torch.double)

                    start = idx * batch_size
                    end = start + batch_size

                    val_all_labels[start:end] = val_labels.detach().cpu()
                    val_all_predicts[start:end] = val_preds.detach().cpu()

                calculate_confusion_matrix(
                    all_labels=val_all_labels.numpy(),
                    all_predicts=val_all_predicts.numpy(),
                    classes=classes,
                    num_classes=num_classes)

                # Store validation diagnostics.
                val_loss = val_running_loss / dataset_sizes["val"]
                val_acc = val_running_corrects / dataset_sizes["val"]

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Remaining things related to training.
                if global_minibatch_counter % 10 == 0 or global_minibatch_counter == 5:
                    epoch_output_path = checkpoints_folder.joinpath(
                        f"resnet{num_layers}_e{epoch}_mb{global_minibatch_counter}_va{val_acc:.5f}.pt"
                    )

                    # Confirm the output directory exists.
                    epoch_output_path.parent.mkdir(parents=True, exist_ok=True)

                    # Save the model as a state dictionary.
                    torch.save(obj={
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "scheduler_state_dict": scheduler.state_dict(),
                        "epoch": epoch + 1
                    },
                               f=str(epoch_output_path))

                writer.write(
                    f"{epoch},{global_minibatch_counter},{train_loss:.4f},"
                    f"{train_acc:.4f},{val_loss:.4f},{val_acc:.4f}\n")

                current_lr = None
                for group in optimizer.param_groups:
                    current_lr = group["lr"]

                # Print the diagnostics for each epoch.
                print(f"Epoch {epoch} with "
                      f"mb {global_minibatch_counter} "
                      f"lr {current_lr:.15f}: "
                      f"t_loss: {train_loss:.4f} "
                      f"t_acc: {train_acc:.4f} "
                      f"v_loss: {val_loss:.4f} "
                      f"v_acc: {val_acc:.4f}\n")

        scheduler.step()

        current_lr = None
        for group in optimizer.param_groups:
            current_lr = group["lr"]

    # Print training information at the end.
    print(f"\ntraining complete in "
          f"{(time.time() - since) // 60:.2f} minutes")
Example #28
def run(data_loader: BaseDataLoader, encoder: Encoder, criterion, optimizer: optim, scheduler: optim.lr_scheduler,
          similarity_measure: Similarity, save_model,
          args:argparse.Namespace):
    encoder.to(device)
    best_accs = {"encode_acc": float('-inf'), "pivot_acc": float('-inf')}
    last_update = 0
    dev_arg_dict = {
        "use_mid": args.use_mid,
        "topk": args.val_topk,
        "trg_encoding_num": args.trg_encoding_num,
        "mid_encoding_num": args.mid_encoding_num
    }
    # lr_decay = scheduler is not None
    # if lr_decay:
    #     print("[INFO] using learning rate decay")
    for ep in range(args.max_epoch):
        encoder.train()
        train_loss = 0.0
        start_time = time.time()
        # if not args.mega:
        train_batches = data_loader.create_batches("train")
        # else:
        #     if ep <= 30:
        #         train_batches = data_loader.create_batches("train")
        #     else:
        #         train_batches = data_loader.create_megabatch(encoder)
        batch_num = 0
        t = 0
        for idx, batch in enumerate(train_batches):
            optimizer.zero_grad()
            cur_loss = calc_batch_loss(encoder, criterion, batch, args.mid_proportion, args.trg_encoding_num, args.mid_encoding_num)
            train_loss += cur_loss.item()
            cur_loss.backward()
            # optimizer.step()

            for p in list(filter(lambda p: p.grad is not None, encoder.parameters())):
                t += p.grad.data.norm(2).item()

            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=5)
            optimizer.step()

            if encoder.name == "bilstm":
                # set all but forget gate bias to 0
                reset_bias(encoder.src_lstm)
                reset_bias(encoder.trg_lstm)
                # pass
            batch_num += 1
        print("[INFO] epoch {:d}: train loss={:.8f}, time={:.2f}".format(ep, train_loss / batch_num,
                                                                         time.time()-start_time))
        # print(t)

        if (ep + 1) % EPOCH_CHECK == 0:
            with torch.no_grad():
                encoder.eval()
                # eval
                train_batches = data_loader.create_batches("train")
                dev_batches = data_loader.create_batches("dev")
                start_time = time.time()

                recall, tot = eval_data(encoder, train_batches, dev_batches, similarity_measure, dev_arg_dict)
                dev_pivot_acc = recall[0] / float(tot)
                dev_encode_acc = recall[1] / float(tot)
                if dev_encode_acc > best_accs["encode_acc"]:
                    best_accs["encode_acc"] = dev_encode_acc
                    best_accs["pivot_acc"] = dev_pivot_acc
                    last_update = ep + 1
                    save_model(encoder, ep + 1, train_loss / batch_num, optimizer, args.model_path + "_" + "best" + ".tar")
                save_model(encoder, ep + 1, train_loss / batch_num, optimizer, args.model_path + "_" + "last" + ".tar")
                print("[INFO] epoch {:d}: encoding/pivoting dev acc={:.4f}/{:.4f}, time={:.2f}".format(
                                                                                            ep, dev_encode_acc, dev_pivot_acc,
                                                                                            time.time()-start_time))
                if args.lr_decay and ep + 1 - last_update > UPDATE_PATIENT:
                    new_lr = optimizer.param_groups[0]['lr'] * args.lr_scaler
                    best_info  = torch.load(args.model_path + "_" + "best" + ".tar")
                    encoder.load_state_dict(best_info["model_state_dict"])
                    optimizer.load_state_dict(best_info["optimizer_state_dict"])
                    optimizer.param_groups[0]['lr'] = new_lr
                    print("[INFO] reload best model ..")

                if ep + 1 - last_update > PATIENT:
                    print("[FINAL] in epoch {}, the best develop encoding/pivoting accuracy = {:.4f}/{:.4f}".format(ep + 1,
                                                                                                                    best_accs["encode_acc"],
                                                                                                                    best_accs["pivot_acc"]))
                    break
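The loop above decays the learning rate only after UPDATE_PATIENT epochs without a dev improvement, reloading the best checkpoint before lowering the rate. Below is a minimal, self-contained sketch of that decay-and-reload step in isolation; the default patience and scaling values and the name maybe_decay_lr are illustrative assumptions, not part of the original code.

import torch

def maybe_decay_lr(encoder, optimizer, ep, last_update, best_path,
                   update_patient=5, lr_scaler=0.5):
    """If the dev score has not improved for `update_patient` epochs, reload the
    best checkpoint and shrink the learning rate in place."""
    if ep + 1 - last_update > update_patient:
        new_lr = optimizer.param_groups[0]["lr"] * lr_scaler
        best_info = torch.load(best_path)
        encoder.load_state_dict(best_info["model_state_dict"])
        optimizer.load_state_dict(best_info["optimizer_state_dict"])
        # loading the optimizer state restores the old learning rate,
        # so the reduced rate is set afterwards
        optimizer.param_groups[0]["lr"] = new_lr
        print("[INFO] reloaded best model, lr -> {:.6f}".format(new_lr))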
Exemplo n.º 29
0
def train_supervised(datasets: List[DataLoader],
                     optimizer: torch.optim,
                     model: torch.nn.Module,
                     device: torch.device,
                     epoch: int,
                     print_freq: int,
                     writer: Optional[SummaryWriter] = None,
                     writer_iter: Optional[int] = None) -> int:
    """Trains the model

    Parameters
    ----------
    datasets: List[dataloader]
        the dataloader used for training

    optimizer: troch.optim
        The optimizer used during training for gradient descent

    model: subclass of troch.nn.Module
        The nn

    device: troch.device
        Location for where to put the data

    epoch: int
        Current epoch number

    print_freq: int
        Determines after how many training iterations notificaitons are printed
        to stdout

    writer: Optional[SummaryWriter]
        Tensorboard usmmary writter to save experiments

    writer_iter: Optional[int]
        Species the gradient steps (location) for the writer
    """
    data = datasets[0]
    model.train()
    logger = get_logging(training=True)
    header = 'Epoch: [{}]'.format(epoch)
    lr_scheduler = learning_rate_scheduler(optimizer, epoch, len(data))
    for images, targets in logger.log_every(data, print_freq, header):
        optimizer.zero_grad()
        images = list(i.to(device, non_blocking=True) for i in images)
        targets = [{k: v.to(device, non_blocking=True)
                    for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        if writer:
            logging.log_losses(writer, loss_dict, print_freq, writer_iter)
        check_losses(losses, loss_dict, targets)
        losses.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step()
        logger.update(loss=losses, **loss_dict)
        logger.update(lr=optimizer.param_groups[0]["lr"])
        writer_iter += 1
    return writer_iter
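A hypothetical driver loop for train_supervised is sketched below; run_training, the SGD hyperparameters, and the log directory are illustrative assumptions, not part of the original code. Note that writer_iter is incremented inside the training loop, so it should be passed as an integer rather than left at its None default.

import torch
from torch.utils.tensorboard import SummaryWriter

def run_training(model, train_loader, device, num_epochs=10, print_freq=50):
    # assumed optimizer and log location, purely for illustration
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    writer = SummaryWriter(log_dir="runs/supervised")
    writer_iter = 0  # start at an int; train_supervised increments and returns it
    for epoch in range(num_epochs):
        writer_iter = train_supervised([train_loader], optimizer, model, device,
                                       epoch, print_freq,
                                       writer=writer, writer_iter=writer_iter)
    writer.close()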
def train(
    config: DictConfig,
    model: nn.Module,
    device: torch.device,
    train_loader: AudioDataLoader,
    valid_loader: AudioDataLoader,
    train_sampler: BucketingSampler,
    optimizer: optim,
    epoch: int,
    id2char: dict,
    epoch_idx: int,
) -> None:
    train_epoch_loss = list()
    train_epoch_cer = list()
    train_epoch_result = {
        'train_loss': train_epoch_loss,
        'train_cer': train_epoch_cer
    }

    total_distance = 0
    total_length = 0
    ctcloss = nn.CTCLoss(blank=config.train.blank_id,
                         reduction='mean',
                         zero_infinity=True)
    crossentropyloss = nn.CrossEntropyLoss(ignore_index=config.train.pad_id,
                                           reduction='mean')

    model.train()

    for batch_idx, data in enumerate(train_loader):
        feature, target, feature_lengths, target_lengths = data

        feature = feature.to(device)
        target = target.to(device)
        feature_lengths = feature_lengths.to(device)
        target_lengths = target_lengths.to(device)

        result = target[:, 1:]  # drop the leading target token (typically SOS) so labels align with the decoder outputs

        optimizer.zero_grad()

        if config.model.architecture == 'las':
            encoder_output_prob, encoder_output_lengths, decoder_output_prob = model(
                feature, feature_lengths, target,
                config.model.teacher_forcing_ratio)

            decoder_output_prob = decoder_output_prob.to(device)
            decoder_output_prob = decoder_output_prob[:, :result.size(1), :]
            y_hat = decoder_output_prob.max(2)[1]  # (B, T)

            if config.model.use_joint_ctc_attention:
                # joint CTC-attention objective: CTC loss on the time-major encoder
                # outputs plus cross-entropy on the decoder outputs, mixed with the
                # configured ctc_weight / cross_entropy_weight
                encoder_output_prob = encoder_output_prob.transpose(0, 1)
                ctc_loss = ctcloss(encoder_output_prob, target,
                                   encoder_output_lengths, target_lengths)

                cross_entropy_loss = crossentropyloss(
                    decoder_output_prob.contiguous().view(
                        -1, decoder_output_prob.size(2)),
                    result.contiguous().view(-1))
                loss = config.model.ctc_weight * ctc_loss + config.model.cross_entropy_weight * cross_entropy_loss

            else:
                loss = crossentropyloss(
                    decoder_output_prob.contiguous().view(
                        -1, decoder_output_prob.size(2)),
                    result.contiguous().view(-1))

        elif config.model.architecture == 'deepspeech2':
            output_prob, output_lengths = model(feature, feature_lengths)
            loss = ctcloss(output_prob.transpose(0, 1), target, output_lengths,
                           target_lengths)
            y_hat = output_prob.max(2)[1]

        loss.backward()
        optimizer.step()

        torch.cuda.empty_cache()

        result = label_to_string(config.train.eos_id, config.train.blank_id,
                                 result, id2char)
        y_hat = label_to_string(config.train.eos_id, config.train.blank_id,
                                y_hat, id2char)

        distance, length = get_distance(result, y_hat)

        total_distance += distance
        total_length += length

        cer = total_distance / total_length

        if batch_idx % config.train.train_save_epoch_interval == 0:
            train_epoch_loss.append(loss.item())  # store a plain float rather than the tensor
            train_epoch_cer.append(cer)

        if config.model.architecture == 'las' and config.model.use_joint_ctc_attention:
            if batch_idx % config.train.print_interval == 0:
                print('Epoch {epoch} : {batch_idx} / {total_idx}\t'
                      'CTC loss : {ctc_loss:.4f}\t'
                      'Cross entropy loss : {cross_entropy_loss:.4f}\t'
                      'loss : {loss:.4f}\t'
                      'cer : {cer:.2f}\t'.format(
                          epoch=epoch,
                          batch_idx=batch_idx,
                          total_idx=len(train_sampler),
                          ctc_loss=ctc_loss,
                          cross_entropy_loss=cross_entropy_loss,
                          loss=loss,
                          cer=cer))

        else:
            if batch_idx % config.train.print_interval == 0:
                print('Epoch {epoch} : {batch_idx} / {total_idx}\t'
                      'loss : {loss:.4f}\t'
                      'cer : {cer:.2f}\t'.format(epoch=epoch,
                                                 batch_idx=batch_idx,
                                                 total_idx=len(train_sampler),
                                                 loss=loss,
                                                 cer=cer))

    validation_epoch_result = validate(config, model, device, valid_loader,
                                       id2char, ctcloss, crossentropyloss)

    save_train_epoch_result(train_epoch_result, config, epoch_idx)
    save_validation_epoch_result(validation_epoch_result, config, epoch_idx)
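Both architecture branches above transpose the network output to time-major form before calling ctcloss, because nn.CTCLoss expects log-probabilities of shape (T, B, C). A minimal, self-contained shape check is sketched below; every size is an invented example, and in the real code the tensors come from the model and the dataloader.

import torch
import torch.nn as nn

ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)  # blank id assumed to be 0
T, B, C, S = 50, 4, 30, 12          # time steps, batch size, classes, max target length
log_probs = torch.randn(T, B, C).log_softmax(2)            # (T, B, C) log-probabilities
targets = torch.randint(1, C, (B, S), dtype=torch.long)    # label ids, excluding the blank
input_lengths = torch.full((B,), T, dtype=torch.long)      # per-utterance encoder lengths
target_lengths = torch.randint(5, S + 1, (B,), dtype=torch.long)
loss = ctc(log_probs, targets, input_lengths, target_lengths)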