import torch


def calc_loss(criterion, outputs, labels, opts, cycle_n=None):
    if opts.use_schp:
        # SCHP cycle: labels arrive packed as (hard labels, (soft parsing preds, soft edges))
        labels, soft_labels = labels
        soft_preds, soft_edges = soft_labels
    else:
        soft_preds = None
        soft_edges = None
    if opts.use_mixup:
        # 'edgev2' models also carry edge supervision through mixup_criterion
        edges = 'edgev2' in opts.model
        labels_a, labels_b, lam = labels
        loss = mixup_criterion(criterion,
                               outputs,
                               labels_a,
                               labels_b,
                               lam,
                               edges=edges,
                               soft_preds=soft_preds,
                               soft_edges=soft_edges,
                               cycle_n=cycle_n)
    else:
        if 'ACE2P' in opts.model:
            loss = criterion(outputs,
                             labels[0],
                             edges=labels[1],
                             soft_preds=soft_preds,
                             soft_edges=soft_edges,
                             cycle_n=cycle_n)
        elif 'edgev1' in opts.model:
            loss_fusion = criterion(outputs[0],
                                    labels[0],
                                    soft_preds=soft_preds,
                                    soft_edges=soft_edges,
                                    cycle_n=cycle_n)
            loss_class = criterion(outputs[1],
                                   labels[0],
                                   soft_preds=soft_preds,
                                   soft_edges=soft_edges,
                                   cycle_n=cycle_n)
            loss_edge = torch.nn.MSELoss()(outputs[2], labels[1])
            loss = loss_class + loss_edge + loss_fusion
        elif 'edgev2' in opts.model:
            loss = criterion(outputs,
                             labels[0],
                             edges=labels[1],
                             soft_preds=soft_preds,
                             soft_edges=soft_edges,
                             cycle_n=cycle_n)
            # loss_edge = EdgeLoss()(outputs, labels[1])
            # loss = loss_class + (opts.edge_loss_weight * loss_edge)
        else:
            loss = criterion(outputs,
                             labels,
                             soft_preds=soft_preds,
                             soft_edges=soft_edges,
                             cycle_n=cycle_n)

    return loss
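
A hypothetical call into calc_loss for the simplest configuration (no SCHP soft labels, no mixup); every name, field, and shape below is an illustrative assumption, not taken from the repo:

import torch
from types import SimpleNamespace

# Stand-in criterion that accepts the keyword arguments calc_loss forwards
def _criterion(outputs, labels, soft_preds=None, soft_edges=None, cycle_n=None):
    return torch.nn.functional.cross_entropy(outputs, labels)

opts = SimpleNamespace(use_schp=False, use_mixup=False, model='resnet101')
outputs = torch.randn(4, 20, 64, 64)            # (N, num_classes, H, W) parsing logits
labels = torch.randint(0, 20, (4, 64, 64))      # (N, H, W) hard class indices
loss = calc_loss(_criterion, outputs, labels, opts, cycle_n=None)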
Example #2
def train_epoch(dl_train, args):
    model, optimizer, criterion, scheduler = args
    model.train()
    train_loss = []
    for batch_idx, (x, y) in enumerate(dl_train):
        x = x.cuda()
        y = y.cuda()
        # `alpha` (mixup strength) is assumed to come from the enclosing module
        if alpha != 0:
            xm, ya, yb, lam = mixup_data(x, y, alpha)
            yhat = model(xm)
            loss = mixup_criterion(criterion, yhat, ya, yb, lam)
        else:
            yhat = model(x)
            loss = criterion(yhat, y)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.detach().cpu().numpy())
        scheduler.step()
    return np.nanmean(train_loss)
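
Example #2 above (and Example #3 below) call mixup_data and mixup_criterion with the plain five-argument signature. A minimal sketch of those helpers, following the common mixup reference implementation, is given here as an assumption; the repo's own versions differ (Example #1 passes extra edges/soft_preds keywords):

import numpy as np
import torch


def mixup_data(x, y, alpha=1.0):
    # Blend each sample with a randomly chosen partner from the same batch
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # The loss is the same convex combination applied to the two label sets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)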
Example #3
print('[*] train start !!!!!!!!!!!')
# best-accuracy tracking must persist across epochs
best_acc = 0
best_epoch = 0
for epoch in range(EPOCHS):
    net.train()
    train_loss = 0
    total = 0
    for i, data in enumerate(trainloader):
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        optimizer.zero_grad()
        if MIXUP:
            img, labela, labelb, lam = mixup_data(img, label)
            pre = net(img)
            criterion = torch.nn.CrossEntropyLoss()
            loss = mixup_criterion(criterion, pre, labela, labelb, lam)
        else:
            pre = net(img)
            loss = torch.nn.CrossEntropyLoss()(pre, label)
        train_loss += loss.item() * batch_size  # accumulate as a float, not a graph-attached tensor
        total += batch_size
        loss.backward()
        optimizer.step()
        progress_bar(i, len(trainloader), 'train')
    if epoch > WARM:
        scheduler.step()
    else:
        warmup.step()
    print('[*] epoch:{} - train loss: {:.3f}'.format(epoch,
                                                     train_loss / total))
    acc, recalldic, precisiondic = eval_fuse(testloader, net,
Example #4
def train(writer,
          train_loader,
          val_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          log_file,
          alpha=None,
          is_mixed_precision=False,
          loss_freq=10,
          val_num_steps=None,
          best_acc=0,
          fine_grain=False,
          decay=0.999):
    # Define validation and loss value print frequency
    if len(train_loader) > loss_freq:
        loss_num_steps = int(len(train_loader) / loss_freq)
    else:  # For extremely small sets
        loss_num_steps = len(train_loader)
    if val_num_steps is None:
        val_num_steps = len(train_loader)

    net.train()

    # Use EMA to report final performance instead of selecting the best checkpoint with valtiny
    ema = EMA(net=net, decay=decay)

    epoch = 0

    # Training
    running_loss = 0.0
    while epoch < num_epochs:
        train_correct = 0
        train_all = 0
        time_now = time.time()
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            train_all += labels.shape[0]

            # mixup data within the batch
            if alpha is not None:
                inputs, labels_a, labels_b, lam = mixup_data(x=inputs,
                                                             y=labels,
                                                             alpha=alpha)

            outputs = net(inputs)

            if alpha is not None:
                # Pseudo training accuracy & mixup loss
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b,
                                       lam)
                predicted = outputs.argmax(1)
                train_correct += (
                    lam * (predicted == labels_a).sum().float().item() +
                    (1 - lam) * (predicted == labels_b).sum().float().item())
            else:
                train_correct += (labels == outputs.argmax(1)).sum().item()
                loss = criterion(outputs, labels)

            if is_mixed_precision:
                # 2/3 & 3/3 of mixed precision training with amp
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            # EMA update
            ema.update(net=net)

            # Logging
            running_loss += loss.item()
            current_step_num = int(epoch * len(train_loader) + i + 1)
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar('training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
               current_step_num == num_epochs * len(train_loader) - 1:
                # A bug in Apex? https://github.com/NVIDIA/apex/issues/706
                test_acc = test(loader=val_loader,
                                device=device,
                                net=net,
                                fine_grain=fine_grain)
                writer.add_scalar('test accuracy', test_acc, current_step_num)
                net.train()

                # Record best model (straight to disk)
                if test_acc > best_acc:
                    best_acc = test_acc
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision,
                                    filename=log_file + '_temp.pt')

        # Evaluate training accuracies (same metric as validation, but must be on-the-fly to save time)
        train_acc = train_correct / train_all * 100
        print('Train accuracy: %.4f' % train_acc)

        writer.add_scalar('train accuracy', train_acc, epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    ema.fill_in_bn(state_dict=net.state_dict())
    save_checkpoint(net=ema,
                    optimizer=None,
                    lr_scheduler=None,
                    is_mixed_precision=False,
                    filename=log_file + '_temp-ema.pt')
    return best_acc
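
Example #4 depends on a repo-local EMA helper. A minimal sketch of such a helper, assuming a standard exponential moving average over the floating-point entries of the state dict (the repo's class almost certainly differs, and its fill_in_bn method is not sketched here):

import torch


class EMA:
    def __init__(self, net, decay=0.999):
        self.decay = decay
        # Shadow copy of the floating-point parameters and buffers
        self.shadow = {k: v.detach().clone()
                       for k, v in net.state_dict().items()
                       if v.dtype.is_floating_point}

    @torch.no_grad()
    def update(self, net):
        # shadow <- decay * shadow + (1 - decay) * current value
        for k, v in net.state_dict().items():
            if k in self.shadow:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)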