def train(train_loader, model, criterion, optimizer, epoch, args, vis=None):
    """Run one training epoch with gradient accumulation and optional plots.

    Args:
        train_loader: Iterable yielding ``(img_input, annotation_input, _)``.
            ``img_input`` is unpacked as ``(batch, frames, channels, H, W)``
            and the annotations are reshaped to ``(-1, 3, H, W)``.
        model: Network applied to the frames flattened to
            ``(batch * frames, channels, H, W)``.
        criterion: Called as ``criterion(ref, target, ref_label,
            target_label)``; must return ``(loss, prediction)``.
        optimizer: Stepped once every ``args.iter_size`` batches.
        epoch: Epoch index, used for logging and the plot x-axis.
        args: Must provide ``iter_size`` and ``log_freq``.
        vis: Optional visualiser exposing ``line`` and ``images``; skipped
            when ``None``.

    Returns:
        ``losses.avg`` -- the meter's average of the recorded losses, each
        of which has already been divided by ``args.iter_size``.

    NOTE(review): a later ``train`` definition without ``vis`` appears in
    this file; if both live at module level the later one wins -- confirm
    which is intended.
    """
    logger.info('Starting training epoch {}'.format(epoch))

    centroids = np.load("./dataset/annotation_centroids.npy")
    centroids = torch.Tensor(centroids).float().cuda()
    num_classes = centroids.shape[0]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()

    # Backward passes whose gradients have not been applied by the optimizer.
    pending_backwards = 0

    end = time.time()
    for i, (img_input, annotation_input, _) in enumerate(train_loader):
        data_time.update(time.time() - end)

        (batch_size, num_frames, num_channels, H, W) = img_input.shape

        # Downsample the RGB annotations by SCALE, then turn them into
        # per-pixel class indices (they are used as scatter indices below).
        reshaped_annotation_input = annotation_input.reshape(-1, 3, H, W).cuda()
        annotation_input_downsample = torch.nn.functional.interpolate(
            reshaped_annotation_input, scale_factor=SCALE,
            mode='bilinear', align_corners=False)
        H_d = annotation_input_downsample.shape[-2]
        W_d = annotation_input_downsample.shape[-1]

        annotation_input = rgb2class(annotation_input_downsample, centroids)
        annotation_input = annotation_input.reshape(
            batch_size, num_frames, H_d, W_d)

        img_input = img_input.reshape(-1, num_channels, H, W).cuda()

        # The reshape requires the model's output spatial size to equal the
        # downsampled annotation size (H_d, W_d).
        features = model(img_input)
        feature_dim = features.shape[1]
        features = features.reshape(
            batch_size, num_frames, feature_dim, H_d, W_d)

        # All frames but the last are references; the last is the target.
        ref = features[:, 0:num_frames - 1, :, :, :]
        target = features[:, -1, :, :, :]
        ref_label = annotation_input[:, 0:num_frames - 1, :, :]
        target_label = annotation_input[:, -1, :, :]

        # One-hot encode the reference labels along a new class axis,
        # allocating directly on the GPU instead of copying from the host.
        ref_label = torch.zeros(
            batch_size, num_frames - 1, num_classes, H_d, W_d, device='cuda',
        ).scatter_(2, ref_label.unsqueeze(2), 1)

        loss, prediction = criterion(ref, target, ref_label, target_label)
        # Scale so that gradients summed over iter_size batches average out.
        loss /= args.iter_size
        loss.backward()
        pending_backwards += 1

        losses.update(loss.item(), batch_size)

        if (i + 1) % args.iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending_backwards = 0

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.log_freq == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                            epoch, i, len(train_loader),
                            batch_time=batch_time, data_time=data_time,
                            loss=losses))

            # NOTE(review): plots are assumed to refresh on the logging
            # cadence; the original indentation was lost -- confirm.
            if vis is not None:
                global_step = epoch * len(train_loader) + i
                vis.line(Y=[[losses.val, losses.avg]], X=[global_step],
                         win='loss', update='append')
                vis.line(Y=[optimizer.param_groups[0]['lr']], X=[global_step],
                         win='lr', update='append')

                # The strided slice picks the last (target) frame per clip.
                images = dataset.davis.denormalize_images(img_input)
                vis.images(images[num_frames - 1::num_frames],
                           opts=dict(caption='input_images'), win='inputs')
                vis.images(reshaped_annotation_input[num_frames - 1::num_frames],
                           opts=dict(caption='annotations'), win='GT')

                # Display-only work: keep it out of the autograd graph.
                with torch.no_grad():
                    upsampled_pred = torch.nn.functional.interpolate(
                        prediction.view(batch_size, -1, H_d, W_d),
                        size=(H, W), mode='bilinear', align_corners=False)
                    # (B, 1, H, W); the factor 30 presumably spreads class
                    # ids over a visible intensity range -- TODO confirm.
                    pred = torch.argmax(upsampled_pred, 1, keepdim=True) * 30
                vis.images(pred, opts=dict(caption='predictions'), win='pred')

    # If the batch count is not a multiple of iter_size, apply the leftover
    # gradients now so they do not leak into the next epoch's first step.
    if pending_backwards:
        optimizer.step()
        optimizer.zero_grad()

    logger.info('Finished training epoch {}'.format(epoch))
    return losses.avg
def validate(val_loader, model, criterion, args):
    """Evaluate the model on ``val_loader`` and return the average loss.

    The forward pass and loss are computed under ``torch.no_grad()`` so no
    autograd graph is built during evaluation.

    Args:
        val_loader: Iterable yielding ``(img_input, annotation_input, _)``.
            ``img_input`` is unpacked as ``(batch, frames, channels, H, W)``
            and the annotations are reshaped to ``(-1, 3, H, W)``.
        model: Network applied to the frames flattened to
            ``(batch * frames, channels, H, W)``; switched to eval mode here.
        criterion: Called as ``criterion(ref, target, ref_label,
            target_label)``; its result is divided by ``args.iter_size``.
        args: Must provide ``iter_size``.

    Returns:
        ``losses.avg`` -- the meter's average of the recorded losses, each
        of which has already been divided by ``args.iter_size``.
    """
    logger.info('starting validation...')

    centroids = np.load("./dataset/annotation_centroids.npy")
    centroids = torch.Tensor(centroids).float().cuda()
    num_classes = centroids.shape[0]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.eval()

    end = time.time()
    # Evaluation never calls backward(), so skip graph construction.
    with torch.no_grad():
        for i, (img_input, annotation_input, _) in enumerate(val_loader):
            data_time.update(time.time() - end)

            (batch_size, num_frames, num_channels, H, W) = img_input.shape

            # Downsample the RGB annotations by SCALE, then turn them into
            # per-pixel class indices (used as scatter indices below).
            annotation_input = annotation_input.reshape(-1, 3, H, W).cuda()
            annotation_input_downsample = torch.nn.functional.interpolate(
                annotation_input, scale_factor=SCALE,
                mode='bilinear', align_corners=False)
            H_d = annotation_input_downsample.shape[-2]
            W_d = annotation_input_downsample.shape[-1]

            annotation_input = rgb2class(annotation_input_downsample, centroids)
            annotation_input = annotation_input.reshape(
                batch_size, num_frames, H_d, W_d)

            img_input = img_input.reshape(-1, num_channels, H, W).cuda()

            # The reshape requires the model's output spatial size to equal
            # the downsampled annotation size (H_d, W_d).
            features = model(img_input)
            feature_dim = features.shape[1]
            features = features.reshape(
                batch_size, num_frames, feature_dim, H_d, W_d)

            # All frames but the last are references; the last is the target.
            ref = features[:, 0:num_frames - 1, :, :, :]
            target = features[:, -1, :, :, :]
            ref_label = annotation_input[:, 0:num_frames - 1, :, :]
            target_label = annotation_input[:, -1, :, :]

            # One-hot encode the reference labels along a new class axis,
            # allocating directly on the GPU instead of copying from host.
            ref_label = torch.zeros(
                batch_size, num_frames - 1, num_classes, H_d, W_d,
                device='cuda',
            ).scatter_(2, ref_label.unsqueeze(2), 1)

            # Kept on the same iter_size scale as the training loss.
            loss = criterion(ref, target, ref_label, target_label) / args.iter_size

            losses.update(loss.item(), batch_size)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % 25 == 0:
                logger.info('Validate: [{0}/{1}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                                i, len(val_loader),
                                batch_time=batch_time, loss=losses))

    logger.info('Finished validation')
    return losses.avg
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch with gradient accumulation.

    Args:
        train_loader: Iterable yielding ``(img_input, annotation_input, _)``.
            ``img_input`` is unpacked as ``(batch, frames, channels, H, W)``
            and the annotations are reshaped to ``(-1, 3, H, W)``.
        model: Network applied to the frames flattened to
            ``(batch * frames, channels, H, W)``.
        criterion: Called as ``criterion(ref, target, ref_label,
            target_label)``; its result is divided by ``args.iter_size``.
        optimizer: Stepped once every ``args.iter_size`` batches.
        epoch: Epoch index, used for logging only.
        args: Must provide ``iter_size``.

    Returns:
        ``losses.avg`` -- the meter's average of the recorded losses, each
        of which has already been divided by ``args.iter_size``.

    NOTE(review): an earlier ``train`` definition taking an extra ``vis``
    argument appears in this file; if both live at module level this one
    takes precedence -- confirm which is intended.
    """
    logger.info('Starting training epoch {}'.format(epoch))

    centroids = np.load("./dataset/annotation_centroids.npy")
    centroids = torch.Tensor(centroids).float().cuda()
    num_classes = centroids.shape[0]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()

    # Backward passes whose gradients have not been applied by the optimizer.
    pending_backwards = 0

    end = time.time()
    for i, (img_input, annotation_input, _) in enumerate(train_loader):
        data_time.update(time.time() - end)

        (batch_size, num_frames, num_channels, H, W) = img_input.shape

        # Per-batch shape trace: kept for debugging, but at DEBUG level so
        # it no longer floods the INFO log on every iteration.
        logger.debug('Input shape --- Batch Size: {0}\t'
                     'Num frames: {1}\t'
                     'Num channels: {2}\t'
                     'H x W: {3}x{4}\t'.format(
                         batch_size, num_frames, num_channels, H, W))

        # Downsample the RGB annotations by SCALE, then turn them into
        # per-pixel class indices (they are used as scatter indices below).
        annotation_input = annotation_input.reshape(-1, 3, H, W).cuda()
        annotation_input_downsample = torch.nn.functional.interpolate(
            annotation_input, scale_factor=SCALE,
            mode='bilinear', align_corners=False)
        H_d = annotation_input_downsample.shape[-2]
        W_d = annotation_input_downsample.shape[-1]

        annotation_input = rgb2class(annotation_input_downsample, centroids)
        annotation_input = annotation_input.reshape(
            batch_size, num_frames, H_d, W_d)

        img_input = img_input.reshape(-1, num_channels, H, W).cuda()

        # The reshape requires the model's output spatial size to equal the
        # downsampled annotation size (H_d, W_d).
        features = model(img_input)
        feature_dim = features.shape[1]
        features = features.reshape(
            batch_size, num_frames, feature_dim, H_d, W_d)

        # All frames but the last are references; the last is the target.
        ref = features[:, 0:num_frames - 1, :, :, :]
        target = features[:, -1, :, :, :]
        ref_label = annotation_input[:, 0:num_frames - 1, :, :]
        target_label = annotation_input[:, -1, :, :]

        # One-hot encode the reference labels along a new class axis,
        # allocating directly on the GPU instead of copying from the host.
        ref_label = torch.zeros(
            batch_size, num_frames - 1, num_classes, H_d, W_d, device='cuda',
        ).scatter_(2, ref_label.unsqueeze(2), 1)

        # Scale so that gradients summed over iter_size batches average out.
        loss = criterion(ref, target, ref_label, target_label) / args.iter_size
        loss.backward()
        pending_backwards += 1

        losses.update(loss.item(), batch_size)

        if (i + 1) % args.iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()
            pending_backwards = 0

        batch_time.update(time.time() - end)
        end = time.time()

        if i % 25 == 0:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                            epoch, i, len(train_loader),
                            batch_time=batch_time, data_time=data_time,
                            loss=losses))

    # If the batch count is not a multiple of iter_size, apply the leftover
    # gradients now so they do not leak into the next epoch's first step.
    if pending_backwards:
        optimizer.step()
        optimizer.zero_grad()

    logger.info('Finished training epoch {}'.format(epoch))
    return losses.avg