Example 1
import numpy as np
from apex import amp  # NVIDIA apex mixed precision (assumes amp.initialize was called)

def train_epoch(dl_train, args, alpha=0):
    # args: (model, optimizer, criterion, scheduler); alpha > 0 enables mixup
    model, optimizer, criterion, scheduler = args
    model.train()
    train_loss = []
    for batch_idx, (x, y) in enumerate(dl_train):
        x = x.cuda()
        y = y.cuda()
        if alpha != 0:
            xm, ya, yb, lam = mixup_data(x, y, alpha)
            yhat = model(xm)
            loss = mixup_criterion(criterion, yhat, ya, yb, lam)
        else:
            yhat = model(x)
            loss = criterion(yhat, y)
        # apex.amp loss scaling + backward
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss.append(loss.detach().cpu().numpy())
        scheduler.step()
    return np.nanmean(train_loss)
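
All four examples call mixup_data and mixup_criterion without defining them. A minimal sketch in the style of the reference mixup implementation (Zhang et al., 2018); the exact signatures in the original repositories may differ (Example 4, for instance, passes extra arguments):

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    # Draw the mixing coefficient from Beta(alpha, alpha); alpha <= 0 disables mixing
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Convex combination of the losses against both label sets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)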
Example 2
scheduler = get_lr_scheduler(optimizer, LRTAG)
warmup = WarmUpLR(optimizer, iters * WARM)

print('[*] train start !!!!!!!!!!!')
best_acc = 0   # moved out of the epoch loop so they are not reset every epoch
best_epoch = 0
criterion = torch.nn.CrossEntropyLoss()  # instantiate once, not per batch
for epoch in range(EPOCHS):
    net.train()
    train_loss = 0
    total = 0
    for i, data in enumerate(trainloader):
        img, label = data[0].cuda(), data[1].cuda()
        batch_size = img.size(0)
        optimizer.zero_grad()
        if MIXUP:
            img, labela, labelb, lam = mixup_data(img, label)
            pre = net(img)
            loss = mixup_criterion(criterion, pre, labela, labelb, lam)
        else:
            pre = net(img)
            loss = criterion(pre, label)
        train_loss += loss.item() * batch_size  # .item() avoids retaining the autograd graph
        total += batch_size
        loss.backward()
        optimizer.step()
        if epoch < WARM:
            # WarmUpLR was built for iters * WARM steps, so it must step per iteration
            warmup.step()
        progress_bar(i, len(trainloader), 'train')
    if epoch >= WARM:
        scheduler.step()
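
WarmUpLR is not shown in the excerpt. A minimal per-iteration linear warmup scheduler consistent with the WarmUpLR(optimizer, iters * WARM) construction above (an assumption; the original class may differ):

from torch.optim.lr_scheduler import _LRScheduler

class WarmUpLR(_LRScheduler):
    # Linearly ramps the learning rate from ~0 up to the base LR over total_iters steps
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # last_epoch counts calls to step(), i.e. iterations here
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8)
                for base_lr in self.base_lrs]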
Example 3
def train(writer,
          train_loader,
          val_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          log_file,
          alpha=None,
          is_mixed_precision=False,
          loss_freq=10,
          val_num_steps=None,
          best_acc=0,
          fine_grain=False,
          decay=0.999):
    # Define validation and loss value print frequency
    if len(train_loader) > loss_freq:
        loss_num_steps = int(len(train_loader) / loss_freq)
    else:  # For extremely small sets
        loss_num_steps = len(train_loader)
    if val_num_steps is None:
        val_num_steps = len(train_loader)

    net.train()

    # Use EMA to report final performance instead of selecting the best checkpoint with valtiny
    ema = EMA(net=net, decay=decay)

    epoch = 0

    # Training
    running_loss = 0.0
    while epoch < num_epochs:
        train_correct = 0
        train_all = 0
        time_now = time.time()
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            train_all += labels.shape[0]

            # mixup data within the batch
            if alpha is not None:
                inputs, labels_a, labels_b, lam = mixup_data(x=inputs,
                                                             y=labels,
                                                             alpha=alpha)

            outputs = net(inputs)

            if alpha is not None:
                # Mixup loss and a lam-weighted ("pseudo") training accuracy
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b,
                                       lam)
                predicted = outputs.argmax(1)
                train_correct += (
                    lam * (predicted == labels_a).sum().float().item() +
                    (1 - lam) * (predicted == labels_b).sum().float().item())
            else:
                train_correct += (labels == outputs.argmax(1)).sum().item()
                loss = criterion(outputs, labels)

            if is_mixed_precision:
                # Steps 2/3 and 3/3 of apex.amp mixed-precision training (scale loss, then backward)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            # EMA update
            ema.update(net=net)

            # Logging
            running_loss += loss.item()
            current_step_num = int(epoch * len(train_loader) + i + 1)
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar('training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
               current_step_num == num_epochs * len(train_loader) - 1:
                # Possible Apex bug: https://github.com/NVIDIA/apex/issues/706
                test_acc = test(loader=val_loader,
                                device=device,
                                net=net,
                                fine_grain=fine_grain)
                writer.add_scalar('test accuracy', test_acc, current_step_num)
                net.train()

                # Record the best model (straight to disk)
                if test_acc > best_acc:
                    best_acc = test_acc
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision,
                                    filename=log_file + '_temp.pt')

        # Evaluate training accuracy (same metric as validation, computed on-the-fly to save time)
        train_acc = train_correct / train_all * 100
        print('Train accuracy: %.4f' % train_acc)

        writer.add_scalar('train accuracy', train_acc, epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    ema.fill_in_bn(state_dict=net.state_dict())
    save_checkpoint(net=ema,
                    optimizer=None,
                    lr_scheduler=None,
                    is_mixed_precision=False,
                    filename=log_file + '_temp-ema.pt')
    return best_acc
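
Examples 3 and 4 rely on an EMA helper with update and fill_in_bn methods that is not included. A plausible minimal sketch (the method names match the calls above; the body is an assumption): parameters are averaged after every optimizer step, and fill_in_bn copies BatchNorm running statistics from the live model, since buffers are not covered by the average.

import copy
import torch

class EMA(torch.nn.Module):
    # Exponential moving average of a model's parameters
    def __init__(self, net, decay=0.999):
        super().__init__()
        self.decay = decay
        self.net = copy.deepcopy(net).eval()
        for p in self.net.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, net):
        for ema_p, p in zip(self.net.parameters(), net.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    @torch.no_grad()
    def fill_in_bn(self, state_dict):
        # BatchNorm running stats are buffers, not parameters, so they are not
        # averaged above; copy them from the live model instead
        own = self.net.state_dict()
        for k, v in state_dict.items():
            if 'running_mean' in k or 'running_var' in k:
                own[k].copy_(v)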
Example 4
def train(writer,
          labeled_loader,
          pseudo_labeled_loader,
          val_loader,
          device,
          criterion,
          net,
          optimizer,
          lr_scheduler,
          num_epochs,
          tensorboard_prefix,
          gamma1,
          gamma2,
          labeled_weight,
          start_at,
          num_classes,
          decay=0.999,
          alpha=-1,
          is_mixed_precision=False,
          loss_freq=10,
          val_num_steps=None,
          best_acc=0,
          fine_grain=False):
    # Define validation and loss value print frequency
    # The pseudo-labeled loader defines an epoch
    min_len = len(pseudo_labeled_loader)
    if min_len > loss_freq:
        loss_num_steps = int(min_len / loss_freq)
    else:  # For extremely small sets
        loss_num_steps = min_len
    if val_num_steps is None:
        val_num_steps = min_len

    if is_mixed_precision:
        scaler = GradScaler()

    net.train()

    # Use EMA to report final performance instead of selecting the best checkpoint with valtiny
    ema = EMA(net=net, decay=decay)

    epoch = 0

    # Training
    running_loss = 0.0
    running_stats = {
        'disagree': -1,
        'current_win': -1,
        'avg_weights': 1.0,
        'gamma1': 0,
        'gamma2': 0
    }
    iter_labeled = iter(labeled_loader)
    while epoch < num_epochs:
        train_correct = 0
        train_all = 0
        time_now = time.time()
        for i, data in enumerate(pseudo_labeled_loader, 0):
            # Pseudo labeled data
            inputs_pseudo, labels_pseudo = data
            inputs_pseudo, labels_pseudo = inputs_pseudo.to(
                device), labels_pseudo.to(device)

            # Soft labels -> hard labels
            probs_pseudo = labels_pseudo.clone().detach()
            labels_pseudo = labels_pseudo.argmax(-1)  # int64 class indices

            # Labeled data (restart the iterator when it is exhausted)
            try:
                inputs_labeled, labels_labeled = next(iter_labeled)
            except StopIteration:
                iter_labeled = iter(labeled_loader)
                inputs_labeled, labels_labeled = next(iter_labeled)
            inputs_labeled, labels_labeled = inputs_labeled.to(
                device), labels_labeled.to(device)

            # To probabilities (in fact, just one-hot)
            probs_labeled = torch.nn.functional.one_hot(labels_labeled.clone().detach(), num_classes=num_classes) \
                .float()

            # Combine
            inputs = torch.cat([inputs_pseudo, inputs_labeled])
            labels = torch.cat([labels_pseudo, labels_labeled])
            probs = torch.cat([probs_pseudo, probs_labeled])
            optimizer.zero_grad()
            train_all += labels.shape[0]

            # mixup data within the batch
            if alpha != -1:
                dynamic_weights, stats = criterion.dynamic_weights_calc(
                    net=net,
                    inputs=inputs,
                    targets=probs,
                    split_index=inputs_pseudo.shape[0],
                    labeled_weight=labeled_weight)
                inputs, dynamic_weights, labels_a, labels_b, lam = mixup_data(
                    x=inputs,
                    w=dynamic_weights,
                    y=labels,
                    alpha=alpha,
                    keep_max=True)
            with autocast(is_mixed_precision):
                outputs = net(inputs)

            if alpha != -1:
                # lam-weighted ("pseudo") training accuracy and mixup loss
                predicted = outputs.argmax(1)
                train_correct += (
                    lam * (predicted == labels_a).sum().float().item() +
                    (1 - lam) * (predicted == labels_b).sum().float().item())
                loss, true_loss = criterion(pred=outputs,
                                            y_a=labels_a,
                                            y_b=labels_b,
                                            lam=lam,
                                            dynamic_weights=dynamic_weights)
            else:
                train_correct += (labels == outputs.argmax(1)).sum().item()
                loss, true_loss, stats = criterion(
                    inputs=outputs,
                    targets=probs,
                    split_index=inputs_pseudo.shape[0],
                    gamma1=gamma1,
                    gamma2=gamma2)

            if is_mixed_precision:
                accelerator.backward(scaler.scale(loss))
                scaler.step(optimizer)
                scaler.update()
            else:
                accelerator.backward(loss)
                optimizer.step()
            criterion.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            # EMA update
            ema.update(net=net)

            # Logging
            running_loss += true_loss
            for key in stats.keys():
                running_stats[key] += stats[key]
            current_step_num = int(epoch * len(pseudo_labeled_loader) + i + 1)
            if current_step_num % loss_num_steps == (loss_num_steps - 1):
                print('[%d, %d] loss: %.4f' %
                      (epoch + 1, i + 1, running_loss / loss_num_steps))
                writer.add_scalar(tensorboard_prefix + 'training loss',
                                  running_loss / loss_num_steps,
                                  current_step_num)
                running_loss = 0.0
                for key in stats.keys():
                    print('[%d, %d] ' % (epoch + 1, i + 1) + key + ' : %.4f' %
                          (running_stats[key] / loss_num_steps))
                    writer.add_scalar(tensorboard_prefix + key,
                                      running_stats[key] / loss_num_steps,
                                      current_step_num)
                    running_stats[key] = 0.0

            # Validate and find the best snapshot
            if current_step_num % val_num_steps == (val_num_steps - 1) or \
               current_step_num == num_epochs * len(pseudo_labeled_loader) - 1:
                # Apex bug https://github.com/NVIDIA/apex/issues/706; fixed in PyTorch 1.6, kept here for backward compatibility
                test_acc = test(loader=val_loader,
                                device=device,
                                net=net,
                                fine_grain=fine_grain,
                                is_mixed_precision=is_mixed_precision)
                writer.add_scalar(tensorboard_prefix + 'test accuracy',
                                  test_acc, current_step_num)
                net.train()

                # Record the best model (straight to disk)
                if test_acc >= best_acc:
                    best_acc = test_acc
                    save_checkpoint(net=net,
                                    optimizer=optimizer,
                                    lr_scheduler=lr_scheduler,
                                    is_mixed_precision=is_mixed_precision)

        # Evaluate training accuracy (same metric as validation, computed on-the-fly to save time)
        train_acc = train_correct / train_all * 100
        print('Train accuracy: %.4f' % train_acc)

        writer.add_scalar(tensorboard_prefix + 'train accuracy', train_acc,
                          epoch + 1)

        epoch += 1
        print('Epoch time: %.2fs' % (time.time() - time_now))

    ema.fill_in_bn(state_dict=net.state_dict())
    save_checkpoint(net=ema,
                    optimizer=None,
                    lr_scheduler=None,
                    is_mixed_precision=False,
                    filename='temp-ema.pt')
    return best_acc
def get_input(images, labels, opts, device, cur_iter):
    if 'ACE2P' in opts.model:
        edges = generate_edge_tensor(labels)
        edges = edges.type(torch.cuda.LongTensor)
    elif 'edge' in opts.model:
        edges = labels[1]
        edges = edges.to(device, dtype=torch.float32)
        labels = labels[0]
    else:
        edges = None
    images = images.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.long)

    if opts.use_mixup:
        if edges is not None:
            labels = [labels, edges]
            has_edge = True
        else:
            has_edge = False
        if opts.use_mixup_mwh:
            stage1, stage2 = (np.array(opts.mwh_stages) *
                              opts.total_itrs).astype(int)
            mask = random.random()  # random draw compared against the stage threshold
            if cur_iter >= stage2:
                # threshold = math.cos( math.pi * (epoch - 150) / ((200 - 150) * 2))
                threshold = (opts.total_itrs - cur_iter) / (opts.total_itrs -
                                                            stage2)
                # threshold = 1.0 - math.cos( math.pi * (200 - epoch) / ((200 - 150) * 2))
                if mask < threshold:
                    images, labels_a, labels_b, lam = mixup_data(
                        images,
                        labels,
                        opts.mixup_alpha,
                        device,
                        has_edge=has_edge)
                else:
                    # Skip mixup this step: mimic mixup_data's return structure
                    # with an identity mix (lam = 1.0)
                    images = images, [images, images]
                    labels_a, labels_b = labels, labels
                    lam = 1.0
            elif cur_iter >= stage1:
                # In the paper this alternated every epoch or mini-batch; here it
                # alternates every val_interval iterations
                if (cur_iter // opts.val_interval) % 2 == 0:
                    images, labels_a, labels_b, lam = mixup_data(
                        images,
                        labels,
                        opts.mixup_alpha,
                        device,
                        has_edge=has_edge)
                else:
                    # Skip mixup: same identity-mix structure as above
                    images = images, [images, images]
                    labels_a, labels_b = labels, labels
                    lam = 1.0
            else:
                images, labels_a, labels_b, lam = mixup_data(images,
                                                             labels,
                                                             opts.mixup_alpha,
                                                             device,
                                                             has_edge=has_edge)
        else:
            images, labels_a, labels_b, lam = mixup_data(images,
                                                         labels,
                                                         opts.mixup_alpha,
                                                         device,
                                                         has_edge=has_edge)
            # torch.autograd.Variable has been a no-op since PyTorch 0.4; kept as in the source
            if has_edge:
                images[0], images[1][0], images[1][1], labels_a[
                    0], labels_a[1], labels_b[0], labels_b[1] = map(
                        Variable,
                        (images[0], images[1][0], images[1][1], labels_a[0],
                         labels_a[1], labels_b[0], labels_b[1]))
            else:
                images[0], images[1][
                    0], images[1][1], labels_a, labels_b = map(
                        Variable, (images[0], images[1][0], images[1][1],
                                   labels_a, labels_b))

        return images, [labels_a, labels_b, lam]
    else:
        if 'ACE2P' in opts.model or 'edgev1' in opts.model:
            return images, [labels, edges]
        else:
            return images, labels
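
A hypothetical call site for get_input, showing how the [labels_a, labels_b, lam] targets would feed a mixup loss downstream (simplified: the edge-label variants and the tuple-structured images are ignored, and all names here are illustrative, not from the source):

images, targets = get_input(images, labels, opts, device, cur_iter)
outputs = model(images)
if opts.use_mixup:
    labels_a, labels_b, lam = targets
    # Same convex combination as mixup_criterion in Example 1
    loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
else:
    loss = criterion(outputs, targets)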