Exemplo n.º 1
0
def train(train_Loader,model,criterion,optimizer,ith_epoch):
    """Run one training epoch and return the epoch-averaged loss/precision.

    For every mini-batch: move ``data['image']`` and ``data['label']`` to the
    GPU, run the forward pass, compute the loss with ``criterion``, record loss
    and precision via ``utility_Func.accuracy``, then back-propagate and step
    ``optimizer``. Running averages are printed every ``args.print_freq``
    batches (``args`` is a module-level global).

    Args:
        train_Loader: iterable of dicts with ``'image'`` and ``'label'``
            entries; ``len(train_Loader)`` is used only for logging.
        model: network being trained; switched to train mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch. A step counter attribute
            ``n_iters`` is maintained on it.
        ith_epoch: epoch index, used only for logging.

    Returns:
        ``(losses.avg(), prec1.avg(), prec3.avg())`` from the epoch's meters.
    """
    data_time = Meter()   # average batch loading time (includes the host->GPU copy)
    batch_time = Meter()  # average batch compute time: forward + backward + step
    losses = Meter()      # average of the per-batch mean losses over the epoch
    prec1 = Meter()
    prec3 = Meter()

    model.train()
    end = time.time()
    for ith_batch, data in enumerate(train_Loader):

        images, label = data['image'].cuda(), data['label'].cuda()
        data_time.update(time.time() - end)
        end = time.time()

        # Forward pass. The target is wrapped too: pre-0.4 PyTorch loss
        # functions expect a Variable target when the prediction is a
        # Variable (on newer versions Variable(...) is a no-op).
        input_var, label_var = Variable(images), Variable(label)
        output = model(input_var)
        loss = criterion(output, label_var)  # average loss within a mini-batch

        # Measure accuracy and record loss.
        # NOTE(review): topk=(0, 2) is assumed to mean top-1/top-3 with 0-based
        # k inside utility_Func.accuracy -- confirm against that helper.
        res, cls1, cls3 = utility_Func.accuracy(output.data, label, topk=(0, 2))
        # `loss.data[0]` raises IndexError on 0-dim losses in modern PyTorch;
        # `.item()` is the portable scalar accessor where it exists.
        losses.update(loss.item() if hasattr(loss, 'item') else loss.data[0])
        prec1.update(res[0])
        prec3.update(res[1])

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Step counter kept on the optimizer: 0 after the first step, then +1
        # on every subsequent step (same behavior as before, written plainly).
        optimizer.n_iters = getattr(optimizer, 'n_iters', -1) + 1

        batch_time.update(time.time() - end)
        end = time.time()

        if ith_batch % args.print_freq == 0:
            # Averages are only needed for the log line, so compute them here.
            bt_avg, dt_avg = batch_time.avg(), data_time.avg()
            loss_avg, prec1_avg, prec3_avg = losses.avg(), prec1.avg(), prec3.avg()
            print('Train : ith_batch, batches, ith_epoch : %s %s %s\n' %(ith_batch,len(train_Loader),ith_epoch),
                  'Averaged Batch-computing Time : %s \n' % bt_avg,
                  'Averaged Batch-loading Time : %s \n' % dt_avg,
                  'Averaged Batch-Loss : %s \n' % loss_avg,
                  'Averaged Batch-prec1 : %s \n' % prec1_avg,
                  'Averaged Batch-prec3 : %s \n' % prec3_avg)

    return losses.avg(),prec1.avg(),prec3.avg()
Exemplo n.º 2
0
def validate(val_Loader,model,criterion,ith_epoch,num_crops=10,num_classes=80):
    """Evaluate ``model`` with multi-crop averaging and return epoch metrics.

    For every mini-batch the model is run on each of the ``num_crops`` crop
    batches ``data['image'][k]`` (presumably a list of per-crop tensors built by
    the loader's collate -- confirm). The crop outputs are averaged, and the
    averaged scores are used for the loss and for overall and per-class
    precision. Running averages are printed every ``args.print_freq`` batches
    (``args`` is a module-level global).

    Args:
        val_Loader: iterable of dicts with ``'image'`` (indexable by crop) and
            ``'label'`` entries; ``len(val_Loader)`` is used only for logging.
        model: network to evaluate; switched to eval mode here.
        criterion: loss callable invoked as ``criterion(output, target)``.
        ith_epoch: epoch index, used only for logging.
        num_crops: number of crops per sample to average (default 10, as before).
        num_classes: number of per-class meters to create, keyed
            ``0 .. num_classes-1`` (default 80, as before).

    Returns:
        ``(losses.avg(), top1.avg(), top3.avg(), cls_top1, cls_top3)``, where
        ``cls_top1`` / ``cls_top3`` map integer class id -> ``Meter``.
    """
    batch_time = Meter()  # average batch processing time (forward + metrics)
    losses = Meter()      # average of the per-batch mean losses over the epoch
    top1 = Meter()        # average top-1 precision across mini-batches
    top3 = Meter()        # average top-3 precision across mini-batches
    cls_top1 = {c: Meter() for c in range(num_classes)}
    cls_top3 = {c: Meter() for c in range(num_classes)}

    model.eval()
    end = time.time()
    for ith_batch, data in enumerate(val_Loader):
        label_cpu = data['label']
        label = label_cpu.cuda()  # copy once; reused for the loss and accuracy

        # Forward pass over every crop, keeping only the detached outputs.
        # NOTE(review): autograd is still enabled here; wrapping this in
        # torch.no_grad() (PyTorch >= 0.4) would cut activation memory.
        crop_outputs = [model(Variable(data['image'][k].cuda())).data
                        for k in range(num_crops)]
        # Average the crop scores: summing in crop order and then dividing
        # reproduces the previous per-row accumulation exactly, without a
        # Python loop over rows.
        final_output = sum(crop_outputs) / float(num_crops)
        loss = criterion(Variable(final_output), Variable(label))  # mean loss within a mini-batch

        # Measure accuracy and record loss.
        # NOTE(review): topk=(0, 2) is assumed to mean top-1/top-3 with 0-based
        # k inside utility_Func.accuracy -- confirm against that helper.
        res, cls1, cls3 = utility_Func.accuracy(final_output, label, topk=(0, 2))
        # `loss.data[0]` raises IndexError on 0-dim losses in modern PyTorch.
        losses.update(loss.item() if hasattr(loss, 'item') else loss.data[0])
        top1.update(res[0])
        top3.update(res[1])
        # Per-class precision. int() is required: on PyTorch >= 0.4 iterating
        # a LongTensor yields 0-dim tensors, which hash by identity and would
        # never match the integer dict keys.
        for i, cls_id in enumerate(label_cpu):
            cls_top1[int(cls_id)].update(cls1[i])
            cls_top3[int(cls_id)].update(cls3[i])

        batch_time.update(time.time() - end)
        end = time.time()

        if ith_batch % args.print_freq == 0:
            # Averages are only needed for the log line, so compute them here.
            bt_avg, loss_avg = batch_time.avg(), losses.avg()
            top1_avg, top3_avg = top1.avg(), top3.avg()
            print('Validate : ith_batch, batches, ith_epoch : %s %s %s \n' % (ith_batch, len(val_Loader), ith_epoch),
                  'Averaged Batch-computing Time : %s \n' % bt_avg,
                  'Averaged Batch-Loss : %s \n' % loss_avg,
                  'Averaged Batch-Prec@1 : %s \n' % top1_avg,
                  'Averaged Batch-Prec@3 : %s \n' % top3_avg)

    return losses.avg(),top1.avg(),top3.avg(),cls_top1,cls_top3