def validation(model, criterion1, criterion2, lamb, loader, device, log_callback, epoch=0):
    """Run one full validation pass over *loader* and return the mean total loss.

    The total loss per batch is ``criterion1(image) + lamb * criterion2(mask)``.

    Args:
        model: network mapping (c, v, t) input vectors to (image, mask) outputs.
        criterion1: reconstruction loss applied to the predicted image.
        criterion2: loss applied to the predicted mask.
        lamb: scalar weight on the mask loss.
        loader: dataloader yielding (image, mask, c, v, t) tuples.
        device: torch device the tensors are moved to.
        log_callback: logging function; called with a message or with no args
            for a blank line.
        epoch: epoch number used only for logging (default 0; added because the
            original body referenced an undefined ``epoch`` name, which raised
            NameError at the first log call).

    Returns:
        float: average total loss over all validation batches
        (previously only the last batch's loss was returned, which is not a
        meaningful validation metric).
    """
    end = time.time()
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        # the output of the dataloader is (image, mask, c, v, t)
        for batch_idx, data in enumerate(loader):
            target_image, target_mask, input_c, input_v, input_t = data
            target_image = target_image.to(device, non_blocking=True)
            target_mask = target_mask.to(device, non_blocking=True)
            input_c = input_c.to(device, non_blocking=True)
            input_v = input_v.to(device, non_blocking=True)
            input_t = input_t.to(device, non_blocking=True)

            # compute the output
            out_image, out_mask = model(input_c, input_v, input_t)

            # compute the loss
            loss1 = criterion1(out_image, target_image)
            #loss2 = criterion2(out_mask, target_mask.long().squeeze())
            loss2 = criterion2(out_mask, target_mask)
            loss = loss1 + lamb * loss2

            total_loss += loss.item()
            num_batches += 1

            batch_time.update(time.time() - end)
            end = time.time()

            # records essential information into log file.
            # NOTE(review): data_time is a module-level meter that is never
            # updated in this function; its values come from the last train()
            # call — confirm this is intended.
            log_callback(
                'epoch: {0}\t'
                'Time {batch_time.sum:.3f}s / {1} epochs, ({batch_time.avg:.3f})\t'
                'Data load {data_time.sum:.3f}s / {1} epochs, ({data_time.avg:.3f})\n'
                'Loss = {loss:.8f}\n'.format(epoch, batch_idx,
                                             batch_time=batch_time,
                                             data_time=data_time,
                                             loss=loss.item()))
            log_callback()
            log_callback('Loss{0} = {loss1:.8f}\t'.format(1, loss1=loss1.item()))
            log_callback('Loss{0} = {loss1:.8f}\t'.format(2, loss1=loss2.item()))
            log_callback(Timer.timeString())
            batch_time.reset()

    # Average over all batches; guard against an empty loader.
    return total_loss / max(num_batches, 1)
def train(epoch, model, optimizer, criterion1, criterion2, lamb, loader, device, log_callback):
    """Train *model* for one epoch and checkpoint it afterwards.

    Per batch: forward the (c, v, t) inputs, compute
    ``criterion1(image) + lamb * criterion2(mask)``, backprop, and step the
    optimizer. Every ``args.log_interval`` batches, timing and loss stats are
    written through *log_callback*. At the end of the epoch the model is saved
    to ``folderPath + 'ChairCNN_<epoch>.cpkt'``.

    Args:
        epoch: current epoch number (used for logging and the checkpoint name).
        model: network mapping (c, v, t) input vectors to (image, mask) outputs.
        optimizer: optimizer whose current learning rate is read for logging.
        criterion1: reconstruction loss applied to the predicted image.
        criterion2: loss applied to the predicted mask.
        lamb: scalar weight on the mask loss.
        loader: dataloader yielding (image, mask, c, v, t) tuples.
        device: torch device the tensors are moved to.
        log_callback: logging function; called with a message or with no args
            for a blank line.
    """
    end = time.time()
    model.train()

    # Read the current learning rate (last param group wins) for logging only.
    for param_group in optimizer.param_groups:
        learning_rate = param_group['lr']

    # the output of the dataloader is (image, mask, c, v, t)
    for batch_idx, data in enumerate(loader):
        target_image, target_mask, input_c, input_v, input_t = data
        target_image = target_image.to(device, non_blocking=True)
        target_mask = target_mask.to(device, non_blocking=True)
        input_c = input_c.to(device, non_blocking=True)
        input_v = input_v.to(device, non_blocking=True)
        input_t = input_t.to(device, non_blocking=True)
        data_time.update(time.time() - end)

        # input all the input vectors into the model
        out_image, out_mask = model(input_c, input_v, input_t)

        # compute the loss according to the paper
        loss1 = criterion1(out_image, target_image)
        # Note that the target should remove the channel size as stated in document
        #loss2 = criterion2(out_mask, target_mask.long().squeeze())
        loss2 = criterion2(out_mask, target_mask)
        loss = loss1 + lamb * loss2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # batch_time includes the data-loading time above, since `end` is
        # only advanced here.
        batch_time.update(time.time() - end)
        end = time.time()

        # record essential informations into log file.
        if batch_idx % args.log_interval == 0:
            log_callback(
                'Epoch: {0}\t'
                'Time {batch_time.sum:.3f}s / {1} batches, ({batch_time.avg:.3f})\t'
                'Data load {data_time.sum:.3f}s / {1} batches, ({data_time.avg:.3f})\n'
                'Learning rate = {2}\n'.format(epoch, args.log_interval,
                                               learning_rate,
                                               batch_time=batch_time,
                                               data_time=data_time))
            # Progress counter: multiply by the actual batch size
            # (target_image.size(0)), not len(data) — `data` is the 5-tuple
            # yielded by the loader, so len(data) was always 5.
            log_callback(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * target_image.size(0),
                    len(loader.dataset),
                    100. * batch_idx / len(loader), loss.item()))
            log_callback()
            log_callback('Loss{0} = {loss1:.8f}\t'.format(1, loss1=loss1.item()))
            log_callback('Loss{0} = {loss1:.8f}\t'.format(2, loss1=loss2.item()))
            log_callback()
            log_callback("current time: " + Timer.timeString())
            batch_time.reset()
            data_time.reset()

    # Checkpoint once per epoch.
    torch_utils.save(folderPath + 'ChairCNN_' + str(epoch) + '.cpkt',
                     epoch, model, optimizer, scheduler)