Example #1
def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader,
                        optimizer, epoch_number):
    model.train()
    g_loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    g_loss_meter = utils.AverageMeter(recent=100)
    d_loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        # Forward pass, backward pass, and update parameters.
        outputs = model(images, before=True)
        output, soft_label, soft_no_softmax = outputs
        g_loss_output = g_loss((output, soft_label), labels)
        d_loss_value = discriminator_loss([output], [soft_no_softmax])

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(g_loss_output, tuple):
            g_loss_value, outputs = g_loss_output
        else:
            g_loss_value = g_loss_output

        loss_value = g_loss_value + d_loss_value

        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        g_loss_meter.update(g_loss_value.item(), batch_size)
        d_loss_meter.update(d_loss_value.item(), batch_size)

        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        if i % 20 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}}    '
                'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(epoch=epoch_number,
                                     batch=i + 1,
                                     epoch_size=len(train_loader),
                                     batch_time=batch_time_meter,
                                     data_time=data_time_meter,
                                     g_loss=g_loss_meter,
                                     d_loss=d_loss_meter,
                                     top1=top1_meter,
                                     top5=top5_meter,
                                     lr=_get_learning_rate(optimizer)))
            summary_dd = collections.defaultdict(dict)
            summary_dd['batch_time']['batch_time'] = batch_time_meter.value
            summary_dd['data_time']['data_time'] = data_time_meter.value
            summary_dd['GAN']['G_Loss'] = g_loss_meter.average_recent
            summary_dd['GAN']['D_Loss'] = d_loss_meter.average_recent
            summary_dd['top1']['top1'] = top1_meter.average_recent
            summary_dd['top5']['top5'] = top5_meter.average_recent
            summary_dd['LR']['LR'] = _get_learning_rate(optimizer)
            summary_defaultdict2txtfig(
                default_dict=summary_dd,
                prefix='train',
                step=(epoch_number - 1) * len(train_loader) + i,
                textlogger=global_textlogger)
        if getattr(global_cfg, 'train_dummy', False):
            break

    # Log the overall train stats
    logging.info('Epoch: [{epoch}] -- TRAINING SUMMARY\t'
                 'Time {batch_time.sum:.2f}   '
                 'Data {data_time.sum:.2f}   '
                 'G_Loss {g_loss.average:.3f}     '
                 'D_Loss {d_loss.average:.3f}     '
                 'Top-1 {top1.average:.2f}    '
                 'Top-5 {top5.average:.2f}    '.format(
                     epoch=epoch_number,
                     batch_time=batch_time_meter,
                     data_time=data_time_meter,
                     g_loss=g_loss_meter,
                     d_loss=d_loss_meter,
                     top1=top1_meter,
                     top5=top5_meter))
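
These loops only exercise a small interface on the meters: update(value, n), plus the value, average, average_recent, and sum attributes read by the log strings. A minimal sketch of a compatible utils.AverageMeter (the repos' actual helper may differ in detail) follows.

import collections

class AverageMeter:
    """Tracks the last value, a running average, and an average over the
    most recent `recent` updates (when requested)."""

    def __init__(self, recent=None):
        self.recent = recent
        self.value = 0.0   # most recent value passed to update()
        self.sum = 0.0     # weighted sum of all values
        self.count = 0     # total weight seen so far
        if recent is not None:
            self._recent_sums = collections.deque(maxlen=recent)
            self._recent_counts = collections.deque(maxlen=recent)

    def update(self, value, n=1):
        self.value = value
        self.sum += value * n
        self.count += n
        if self.recent is not None:
            self._recent_sums.append(value * n)
            self._recent_counts.append(n)

    @property
    def average(self):
        return self.sum / max(self.count, 1)

    @property
    def average_recent(self):
        return sum(self._recent_sums) / max(sum(self._recent_counts), 1)
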
Example #2
def test_for_one_epoch(model, loss, test_loader, epoch_number):
    model.eval()
    loss.eval()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(test_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        # Forward pass without computing gradients.
        with torch.no_grad():
            outputs = model(images)
            loss_output = loss(outputs, labels)

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(loss_output, tuple):
            loss_value, outputs = loss_output
        else:
            loss_value = loss_output

        # Record loss and model accuracy.
        loss_meter.update(loss_value.item(), batch_size)
        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        logging.info(
            'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
            'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
            'Data {data_time.value:.2f} ({data_time.average:.2f})   '
            'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}}    '
            'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
            'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
            .format(epoch=epoch_number,
                    batch=i + 1,
                    epoch_size=len(test_loader),
                    batch_time=batch_time_meter,
                    data_time=data_time_meter,
                    loss=loss_meter,
                    top1=top1_meter,
                    top5=top5_meter))
    # Log the overall test stats
    logging.info('Epoch: [{epoch}] -- TESTING SUMMARY\t'
                 'Time {batch_time.sum:.2f}   '
                 'Data {data_time.sum:.2f}   '
                 'Loss {loss.average:.3f}     '
                 'Top-1 {top1.average:.2f}    '
                 'Top-5 {top5.average:.2f}    '.format(
                     epoch=epoch_number,
                     batch_time=batch_time_meter,
                     data_time=data_time_meter,
                     loss=loss_meter,
                     top1=top1_meter,
                     top5=top5_meter))
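
utils.topk_accuracy is likewise external to these snippets; judging from its call sites it returns one accuracy figure, in percent, per requested k. A sketch consistent with that usage:

import torch

def topk_accuracy(outputs, labels, recalls=(1, 5)):
    """Return the top-k accuracy, in percent, for each k in `recalls`."""
    maxk = max(recalls)
    _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)  # (N, maxk)
    correct = pred.eq(labels.view(-1, 1))                           # (N, maxk)
    return [correct[:, :k].any(dim=1).float().mean().item() * 100.0
            for k in recalls]
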
Example #3
def train_for_one_epoch(model, loss, train_loader, optimizer, epoch_number):
    model.train()
    loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        # Forward pass, backward pass, and update parameters.
        outputs = model(images)
        loss_output = loss(outputs, labels)

        ############# Pruning ###############
        # weights = [ p for n,p in model.named_parameters() if 'weight' in n and 'se' not in n and 'conv' in n and len(p.size())==4]
        # lamb = 0.001
        # reg_loss = gpls(weights, lamb)
        ####################################

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(loss_output, tuple):
            loss_value, outputs = loss_output
        else:
            loss_value = loss_output

        ########### Pruning ################
        # loss_value += reg_loss
        ####################################

        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        loss_meter.update(loss_value.item(), batch_size)
        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()
        if i % 100 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(epoch=epoch_number,
                                     batch=i + 1,
                                     epoch_size=len(train_loader),
                                     batch_time=batch_time_meter,
                                     data_time=data_time_meter,
                                     loss=loss_meter,
                                     top1=top1_meter,
                                     top5=top5_meter,
                                     lr=_get_learning_rate(optimizer)))

    ############# Pruning ##############
    # weights = [ p for n,p in model.named_parameters() if 'weight' in n and 'se' not in n and 'conv' in n and len(p.size())==4]
    # for wt in weights:
    #     norm_ch = wt.pow(2).sum(dim=[0,2,3]).pow(1/2.)
    #     for i in range(len(norm_ch)):
    #         if norm_ch[i]<1e-8:
    #             wt[:,i,:,:].data *= 0
    ####################################

    # Log the overall train stats
    logging.info('Epoch: [{epoch}] -- TRAINING SUMMARY\t'
                 'Time {batch_time.sum:.2f}   '
                 'Data {data_time.sum:.2f}   '
                 'Loss {loss.average:.3f}     '
                 'Top-1 {top1.average:.2f}    '
                 'Top-5 {top5.average:.2f}    '.format(
                     epoch=epoch_number,
                     batch_time=batch_time_meter,
                     data_time=data_time_meter,
                     loss=loss_meter,
                     top1=top1_meter,
                     top5=top5_meter))
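
The commented-out pruning hooks above call gpls(weights, lamb), which is not defined in this file. Since the post-epoch zeroing loop measures the L2 norm of each conv input channel (w.pow(2).sum(dim=[0,2,3]).pow(1/2.)), gpls is presumably a group-lasso penalty over those same channel groups. A hypothetical sketch, not the repo's actual implementation:

import torch

def gpls(weights, lamb):
    """Group-lasso penalty: sum of L2 norms of each conv input channel."""
    reg = 0.0
    for w in weights:  # w has shape (out_ch, in_ch, kH, kW)
        # One group per input channel, matching the zeroing loop above.
        reg = reg + w.pow(2).sum(dim=(0, 2, 3)).sqrt().sum()
    return lamb * reg
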
Example #4
def train_for_one_epoch(model, loss, train_loader, optimizer, epoch_number):
    model.train()
    loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        # Forward pass, backward pass, and update parameters.
        outputs = model(images)
        loss_output = loss(outputs, labels)

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(loss_output, tuple):
            loss_value, outputs = loss_output
        else:
            loss_value = loss_output
        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        loss_meter.update(loss_value.item(), batch_size)
        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()
        if i % 50 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'Loss {loss.value:.3f} {{{loss.average:.3f}, {loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(
                    epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader),
                    batch_time=batch_time_meter, data_time=data_time_meter,
                    loss=loss_meter, top1=top1_meter, top5=top5_meter,
                    lr=_get_learning_rate(optimizer)))

    # Log the overall train stats
    logging.info(
        'Epoch: [{epoch}] -- TRAINING SUMMARY\t'
        'Time {batch_time.sum:.2f}   '
        'Data {data_time.sum:.2f}   '
        'Loss {loss.average:.3f}     '
        'Top-1 {top1.average:.2f}    '
        'Top-5 {top5.average:.2f}    '.format(
            epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
            loss=loss_meter, top1=top1_meter, top5=top5_meter))
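
Each training loop logs the current learning rate via _get_learning_rate(optimizer), which these snippets do not show. A common implementation, assuming the first parameter group carries the relevant rate, is the one-liner below.

def _get_learning_rate(optimizer):
    # Read the learning rate off the first parameter group; setups with
    # several groups may want max() over all groups instead.
    return optimizer.param_groups[0]['lr']
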
Example #5
File: train.py  Project: szq0214/MEAL-V2
def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader, optimizer, epoch_number, args):
    model.train()
    g_loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    g_loss_meter = utils.AverageMeter(recent=100)
    d_loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels) in enumerate(train_loader):
        batch_size = images.size(0)

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        if args.w_cutmix:
            r = np.random.rand(1)
            if args.beta > 0 and r < args.cutmix_prob:
                # generate mixed sample
                lam = np.random.beta(args.beta, args.beta)
                rand_index = torch.randperm(images.size()[0]).cuda()
                target_a = labels
                target_b = labels[rand_index]
                bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
                images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
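                # Note: unlike Example #6, this snippet never uses lam,
                # target_a, or target_b after mixing the images, so the losses
                # below are still computed against the unmixed labels.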

        # Forward pass, backward pass, and update parameters.
        outputs = model(images, before=True)
        output, soft_label, soft_no_softmax = outputs
        g_loss_output = g_loss((output, soft_label), labels)
        d_loss_value = discriminator_loss([output], [soft_no_softmax])

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if isinstance(g_loss_output, tuple):
            g_loss_value, outputs = g_loss_output
        else:
            g_loss_value = g_loss_output

        loss_value = g_loss_value + d_loss_value

        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        g_loss_meter.update(g_loss_value.item(), batch_size)
        d_loss_meter.update(d_loss_value.item(), batch_size)

        top1, top5 = utils.topk_accuracy(outputs, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        if i % 20 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}}    '
                'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(
                    epoch=epoch_number, batch=i + 1, epoch_size=len(train_loader),
                    batch_time=batch_time_meter, data_time=data_time_meter,
                    g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter,
                    lr=_get_learning_rate(optimizer)))
    # Log the overall train stats
    logging.info(
        'Epoch: [{epoch}] -- TRAINING SUMMARY\t'
        'Time {batch_time.sum:.2f}   '
        'Data {data_time.sum:.2f}   '
        'G_Loss {g_loss.average:.3f}     '
        'D_Loss {d_loss.average:.3f}     '
        'Top-1 {top1.average:.2f}    '
        'Top-5 {top5.average:.2f}    '.format(
            epoch=epoch_number, batch_time=batch_time_meter, data_time=data_time_meter,
            g_loss=g_loss_meter, d_loss=d_loss_meter, top1=top1_meter, top5=top5_meter))
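
The CutMix branch above depends on utils.rand_bbox. The standard version from the CutMix reference code, where lam is the fraction of the image area to keep, is:

import numpy as np

def rand_bbox(size, lam):
    """Sample a random box covering a (1 - lam) fraction of the image area."""
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    # Uniformly sample the box center, then clip the box to the image.
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
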
Example #6
def train_for_one_epoch(model, g_loss, discriminator_loss, train_loader,
                        optimizer, epoch_number, args):
    model.train()
    g_loss.train()

    data_time_meter = utils.AverageMeter()
    batch_time_meter = utils.AverageMeter()
    g_loss_meter = utils.AverageMeter(recent=100)
    d_loss_meter = utils.AverageMeter(recent=100)
    top1_meter = utils.AverageMeter(recent=100)
    top5_meter = utils.AverageMeter(recent=100)

    timestamp = time.time()
    for i, (images, labels, soft_labels) in enumerate(train_loader):
        batch_size = args.batch_size

        # Record data time
        data_time_meter.update(time.time() - timestamp)

        images = torch.cat(images, dim=0)
        soft_labels = torch.cat(soft_labels, dim=0)
        labels = torch.cat(labels, dim=0)

        if args.soft_label_type == 'ori':
            soft_labels = soft_labels.cuda()
        else:
            soft_labels = Recover_soft_label(soft_labels, args.soft_label_type,
                                             args.num_classes)
            soft_labels = soft_labels.cuda()

        if utils.is_model_cuda(model):
            images = images.cuda()
            labels = labels.cuda()

        # Track whether CutMix actually fires on this step so that the unmixed
        # path runs when the random gate fails (the original code would reuse
        # stale target_a/target_b from an earlier iteration in that case).
        use_cutmix = False
        if args.w_cutmix:
            r = np.random.rand(1)
            if args.beta > 0 and r < args.cutmix_prob:
                use_cutmix = True
                # generate mixed sample
                lam = np.random.beta(args.beta, args.beta)
                rand_index = torch.randperm(images.size()[0]).cuda()
                target_a = soft_labels
                target_b = soft_labels[rand_index]
                bbx1, bby1, bbx2, bby2 = utils.rand_bbox(images.size(), lam)
                images[:, :, bbx1:bbx2,
                       bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                           (images.size()[-1] * images.size()[-2]))

        # Forward pass, backward pass, and update parameters.
        output = model(images)
        # output, soft_label, soft_no_softmax = outputs
        if use_cutmix:
            g_loss_output1 = g_loss((output, target_a), labels)
            g_loss_output2 = g_loss((output, target_b), labels)
        else:
            g_loss_output = g_loss((output, soft_labels), labels)
        if args.use_discriminator_loss:
            # Our stored labels are "after softmax"; this differs slightly from
            # the original MEAL V2, which fed the discriminator probabilities
            # "before softmax".
            output_softmax = nn.functional.softmax(output, dim=1)
            if use_cutmix:
                d_loss_value = discriminator_loss(
                    [output_softmax], [target_a]) * lam + discriminator_loss(
                        [output_softmax], [target_b]) * (1 - lam)
            else:
                d_loss_value = discriminator_loss([output_softmax],
                                                  [soft_labels])

        # Sometimes loss function returns a modified version of the output,
        # which must be used to compute the model accuracy.
        if use_cutmix:
            if isinstance(g_loss_output1, tuple):
                g_loss_value1, output1 = g_loss_output1
                g_loss_value2, output2 = g_loss_output2
                g_loss_value = g_loss_value1 * lam + g_loss_value2 * (1 - lam)
            else:
                g_loss_value = g_loss_output1 * lam + g_loss_output2 * (1 - lam)
        else:
            if isinstance(g_loss_output, tuple):
                g_loss_value, output = g_loss_output
            else:
                g_loss_value = g_loss_output

        if args.use_discriminator_loss:
            loss_value = g_loss_value + d_loss_value
        else:
            loss_value = g_loss_value

        loss_value.backward()

        # Update parameters and reset gradients.
        optimizer.step()
        optimizer.zero_grad()

        # Record loss and model accuracy.
        g_loss_meter.update(g_loss_value.item(), batch_size)
        if args.use_discriminator_loss:
            # d_loss_value exists only when the discriminator loss is enabled.
            d_loss_meter.update(d_loss_value.item(), batch_size)

        top1, top5 = utils.topk_accuracy(output, labels, recalls=(1, 5))
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

        # Record batch time
        batch_time_meter.update(time.time() - timestamp)
        timestamp = time.time()

        if i % 20 == 0:
            logging.info(
                'Epoch: [{epoch}][{batch}/{epoch_size}]\t'
                'Time {batch_time.value:.2f} ({batch_time.average:.2f})   '
                'Data {data_time.value:.2f} ({data_time.average:.2f})   '
                'G_Loss {g_loss.value:.3f} {{{g_loss.average:.3f}, {g_loss.average_recent:.3f}}}    '
                'D_Loss {d_loss.value:.3f} {{{d_loss.average:.3f}, {d_loss.average_recent:.3f}}}    '
                'Top-1 {top1.value:.2f} {{{top1.average:.2f}, {top1.average_recent:.2f}}}    '
                'Top-5 {top5.value:.2f} {{{top5.average:.2f}, {top5.average_recent:.2f}}}    '
                'LR {lr:.5f}'.format(epoch=epoch_number,
                                     batch=i + 1,
                                     epoch_size=len(train_loader),
                                     batch_time=batch_time_meter,
                                     data_time=data_time_meter,
                                     g_loss=g_loss_meter,
                                     d_loss=d_loss_meter,
                                     top1=top1_meter,
                                     top5=top5_meter,
                                     lr=_get_learning_rate(optimizer)))
    # Log the overall train stats
    logging.info('Epoch: [{epoch}] -- TRAINING SUMMARY\t'
                 'Time {batch_time.sum:.2f}   '
                 'Data {data_time.sum:.2f}   '
                 'G_Loss {g_loss.average:.3f}     '
                 'D_Loss {d_loss.average:.3f}     '
                 'Top-1 {top1.average:.2f}    '
                 'Top-5 {top5.average:.2f}    '.format(
                     epoch=epoch_number,
                     batch_time=batch_time_meter,
                     data_time=data_time_meter,
                     g_loss=g_loss_meter,
                     d_loss=d_loss_meter,
                     top1=top1_meter,
                     top5=top5_meter))
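
Example #6 additionally calls Recover_soft_label for compressed soft-label formats. Its real behavior depends on how the labels were stored, which this snippet does not reveal; purely as an illustration, assuming the non-'ori' types store per-sample class ids to be expanded into dense distributions, a hypothetical sketch could be:

import torch

def Recover_soft_label(labels, soft_label_type, num_classes):
    """Hypothetical expansion of stored labels into dense (N, C) distributions."""
    if soft_label_type == 'hard':
        return torch.nn.functional.one_hot(labels.long(), num_classes).float()
    if soft_label_type == 'smooth':
        eps = 0.1  # assumed smoothing factor
        dense = torch.full((labels.size(0), num_classes), eps / num_classes)
        dense.scatter_(1, labels.long().view(-1, 1), 1.0 - eps + eps / num_classes)
        return dense
    raise ValueError('unknown soft_label_type: %s' % soft_label_type)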