Exemplo n.º 1
0
def validation(model, criterion1, criterion2, lamb, loader, device,
               log_callback, epoch=0):
    """Run one evaluation pass over ``loader`` and log the final losses.

    Args:
        model: network mapping (c, v, t) inputs to (image, mask) outputs.
        criterion1: loss applied to the reconstructed image.
        criterion2: loss applied to the predicted mask.
        lamb: weight of the mask loss in the combined objective.
        loader: dataloader yielding (image, mask, c, v, t) batches.
        device: device the batch tensors are moved to.
        log_callback: logging function; called with a message or no args.
        epoch: epoch number reported in the log. (The original body read an
            undeclared name ``epoch`` here, which raised NameError unless a
            module-level global happened to exist.)

    Returns:
        float: combined loss of the last validated batch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    end = time.time()
    model.eval()

    # Sentinels so an empty loader is detected instead of raising
    # UnboundLocalError on ``loss`` below.
    loss = loss1 = loss2 = None
    batch_idx = 0

    # No gradients needed during evaluation.
    with torch.no_grad():
        # The dataloader yields (image, mask, c, v, t) per batch.
        for batch_idx, data in enumerate(loader):
            target_image, target_mask, input_c, input_v, input_t = data
            target_image = target_image.to(device, non_blocking=True)
            target_mask = target_mask.to(device, non_blocking=True)
            input_c = input_c.to(device, non_blocking=True)
            input_v = input_v.to(device, non_blocking=True)
            input_t = input_t.to(device, non_blocking=True)

            # Measure data load/transfer time; the original logged
            # data_time but never updated it during validation.
            data_time.update(time.time() - end)

            # Forward pass.
            out_image, out_mask = model(input_c, input_v, input_t)

            # Combined objective: image loss + lamb * mask loss.
            # NOTE(review): depending on criterion2, the mask target may
            # need .long().squeeze() (see commented-out variant in history).
            loss1 = criterion1(out_image, target_image)
            loss2 = criterion2(out_mask, target_mask)
            loss = loss1 + lamb * loss2

            batch_time.update(time.time() - end)
            end = time.time()

        if loss is None:
            raise ValueError('validation loader produced no batches')

        # Record essential information into the log file.
        # (Format spec fixed: ':3f' meant width 3, not 3 decimals.)
        log_callback(
            'epoch: {0}\t'
            'Time {batch_time.sum:.3f}s / {1} epochs, ({batch_time.avg:.3f})\t'
            'Data load {data_time.sum:.3f}s / {1} epochs, ({data_time.avg:.3f})\n'
            'Loss = {loss:.8f}\n'.format(epoch,
                                         batch_idx,
                                         batch_time=batch_time,
                                         data_time=data_time,
                                         loss=loss.item()))

        log_callback()

        log_callback('Loss{0} = {loss1:.8f}\t'.format(1, loss1=loss1.item()))

        log_callback('Loss{0} = {loss1:.8f}\t'.format(2, loss1=loss2.item()))

        log_callback(Timer.timeString())

        # Reset both meters so the next call starts clean.
        batch_time.reset()
        data_time.reset()

        return loss.item()
Exemplo n.º 2
0
def train(epoch, model, optimizer, criterion1, criterion2, lamb, loader,
          device, log_callback):
    """Run one training epoch over ``loader`` and checkpoint the model.

    Args:
        epoch: current epoch number (used for logging and the checkpoint name).
        model: network mapping (c, v, t) inputs to (image, mask) outputs.
        optimizer: optimizer updating ``model``'s parameters.
        criterion1: loss applied to the reconstructed image.
        criterion2: loss applied to the predicted mask.
        lamb: weight of the mask loss in the combined objective.
        loader: dataloader yielding (image, mask, c, v, t) batches.
        device: device the batch tensors are moved to.
        log_callback: logging function; called with a message or no args.
    """
    end = time.time()
    model.train()

    # Report the learning rate of the (last) parameter group in the log.
    for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']

    # The dataloader yields (image, mask, c, v, t) per batch.
    for batch_idx, data in enumerate(loader):
        target_image, target_mask, input_c, input_v, input_t = data
        target_image = target_image.to(device, non_blocking=True)
        target_mask = target_mask.to(device, non_blocking=True)
        input_c = input_c.to(device, non_blocking=True)
        input_v = input_v.to(device, non_blocking=True)
        input_t = input_t.to(device, non_blocking=True)

        # Data load/transfer time for this batch.
        data_time.update(time.time() - end)

        # Forward pass.
        out_image, out_mask = model(input_c, input_v, input_t)

        # Combined objective from the paper: image loss + lamb * mask loss.
        loss1 = criterion1(out_image, target_image)
        # NOTE(review): depending on criterion2, the mask target may need
        # .long().squeeze() (see the commented-out variant in history).
        loss2 = criterion2(out_mask, target_mask)
        loss = loss1 + lamb * loss2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        # Record essential information into the log file.
        if batch_idx % args.log_interval == 0:
            log_callback(
                'Epoch: {0}\t'
                'Time {batch_time.sum:.3f}s / {1} batches, ({batch_time.avg:.3f})\t'
                'Data load {data_time.sum:.3f}s / {1} batches, ({data_time.avg:.3f})\n'
                'Learning rate = {2}\n'.format(epoch,
                                               args.log_interval,
                                               learning_rate,
                                               batch_time=batch_time,
                                               data_time=data_time))

            # BUGFIX: ``len(data)`` was the length of the 5-tuple (always 5),
            # not the batch size; use the actual batch size for the
            # samples-processed counter.
            log_callback(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * target_image.size(0),
                    len(loader.dataset),
                    100. * batch_idx / len(loader), loss.item()))
            log_callback()

            log_callback('Loss{0} = {loss1:.8f}\t'.format(1,
                                                          loss1=loss1.item()))

            log_callback('Loss{0} = {loss1:.8f}\t'.format(2,
                                                          loss1=loss2.item()))

            log_callback()
            log_callback("current time: " + Timer.timeString())

            batch_time.reset()
            data_time.reset()

    # Persist a checkpoint for this epoch.
    torch_utils.save(folderPath + 'ChairCNN_' + str(epoch) + '.cpkt', epoch,
                     model, optimizer, scheduler)