예제 #1
0
def train_iter(args, loader, model, closs, optimizer, epoch, best_loss):
    """Run one training epoch over ``loader`` and return the updated best loss.

    The regime is selected by ``epoch``:

    * ``epoch < args.wepoch`` (warm-up): ``forward(..., warm_up=True)``; the
      objective is ``L1_loss`` between the estimate and ``frames[1]``, plus the
      affinity constraint loss once ``epoch >= 1`` and ``args.lc > 0``.
    * otherwise: ``forward(..., warm_up=False, patch_size=args.patch_size)``;
      the objective is ``L1_loss`` against the crop returned as ``output[-1]``
      plus the affinity constraint loss, plus a coordinate-switch L1 term when
      ``args.coord_switch_flag`` is set.

    A progress line is logged through ``logger`` on every iteration.
    ``save_vis`` is called every ``args.log_interval`` iterations. A
    checkpoint (``checkpoint_latest.pth.tar`` in ``args.savedir``) and a
    ``log_current`` entry are written every ``args.save_interval`` iterations.

    Args:
        args: Options object. Attributes read here: ``coord_switch_flag``,
            ``color_switch_flag``, ``wepoch``, ``lc``, ``patch_size``,
            ``coord_switch``, ``log_interval``, ``save_interval``,
            ``savepatch`` and ``savedir``.
        loader: Iterable of batches. ``frames[0]`` and ``frames[1]`` are
            moved to CUDA and used as the first and second frame.
        model: Network passed to ``forward``; switched to train mode.
        closs: Constraint loss applied to the affinity viewed as
            ``(b, 1, x, x)`` and summed.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch index (assumed 0-based, given the ``epoch < 1``
            check — confirm with the caller).
        best_loss: Best running-average loss seen so far.

    Returns:
        The best running-average loss, lowered if this epoch improved on it
        at a checkpoint step.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    c_losses = AverageMeter()
    model.train()
    end = time.time()
    if args.coord_switch_flag:
        coord_switch_loss = nn.L1Loss()
        sc_losses = AverageMeter()

    # The L1 threshold is disabled in the first epoch and in the two epochs
    # immediately after warm-up ends; otherwise it is fixed at 2.5.
    if epoch < 1 or (epoch >= args.wepoch and epoch < args.wepoch + 2):
        thr = None
    else:
        thr = 2.5

    for i, frames in enumerate(loader):
        frame1_var = frames[0].cuda()
        frame2_var = frames[1].cuda()

        if epoch < args.wepoch:
            output = forward(frame1_var, frame2_var, model, warm_up=True)
            color2_est = output[0]
            aff = output[1]
            b, x, _ = aff.size()
            color1_est = None
            if args.color_switch_flag:
                color1_est = output[2]
            loss_ = L1_loss(color2_est, frame2_var, 10, 10, thr=thr,
                            pred1=color1_est, frame1_var=frame1_var)

            # The constraint term is only switched on after the first epoch.
            if epoch >= 1 and args.lc > 0:
                constraint_loss = torch.sum(closs(aff.view(b, 1, x, x))) * args.lc
                c_losses.update(constraint_loss.item(), frame1_var.size(0))
                loss = loss_ + constraint_loss
            else:
                loss = loss_
            if i % args.log_interval == 0:
                save_vis(color2_est, frame2_var, frame1_var, frame2_var, args.savepatch)
        else:
            output = forward(frame1_var, frame2_var, model, warm_up=False,
                             patch_size=args.patch_size)
            # Fixed slots: [0] colour estimate, [1] affinity, [2] new_c for
            # save_vis, [3] coords (unused here), [-1] target crop. Optional
            # colour-/coord-switch outputs follow slot 3 and are located via
            # `count` below.
            color2_est = output[0]
            aff = output[1]
            new_c = output[2]
            Fcolor2_crop = output[-1]

            b, _, x = aff.size()
            color1_est = None
            count = 3

            constraint_loss = torch.sum(closs(aff.view(b, 1, x, x))) * args.lc
            c_losses.update(constraint_loss.item(), frame1_var.size(0))

            if args.color_switch_flag:
                count += 1
                color1_est = output[count]

            loss_color = L1_loss(color2_est, Fcolor2_crop, 10, 10, thr=thr,
                                 pred1=color1_est, frame1_var=frame1_var)
            loss_ = loss_color + constraint_loss

            if args.coord_switch_flag:
                count += 1
                grids = output[count]
                C11 = output[count + 1]
                loss_coord = args.coord_switch * coord_switch_loss(C11, grids)
                loss = loss_ + loss_coord
                sc_losses.update(loss_coord.item(), frame1_var.size(0))
            else:
                loss = loss_

            if i % args.log_interval == 0:
                save_vis(color2_est, Fcolor2_crop, frame1_var, frame2_var, args.savepatch, new_c)

        losses.update(loss.item(), frame1_var.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()

        if epoch >= args.wepoch and args.coord_switch_flag:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Color Loss {loss.val:.4f} ({loss.avg:.4f})\t '
                'Coord switch Loss {scloss.val:.4f} ({scloss.avg:.4f})\t '
                'Constraint Loss {c_loss.val:.4f} ({c_loss.avg:.4f})\t '.format(
                epoch, i+1, len(loader), batch_time=batch_time, loss=losses, scloss=sc_losses, c_loss= c_losses))
        else:
            logger.info('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Color Loss {loss.val:.4f} ({loss.avg:.4f})\t '
                'Constraint Loss {c_loss.val:.4f} ({c_loss.avg:.4f})\t '.format(
                epoch, i+1, len(loader), batch_time=batch_time, loss=losses, c_loss= c_losses))

        if (i + 1) % args.save_interval == 0:
            # "best" is judged on the epoch's running-average loss so far.
            is_best = losses.avg < best_loss
            best_loss = min(losses.avg, best_loss)
            checkpoint_path = os.path.join(args.savedir, 'checkpoint_latest.pth.tar')
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
            }, is_best, filename=checkpoint_path, savedir=args.savedir)
            log_current(epoch, losses.avg, best_loss, filename="log_current.txt", savedir=args.savedir)

    return best_loss
예제 #2
0
def train_iter(args, loader, model, closs, optimizer, epoch, best_loss):
    """Run one training epoch over ``loader`` and return the updated best loss.

    During warm-up (``epoch < args.wepoch``) the objective is ``L1_loss``
    between the estimate and ``frames[1]``. The affinity constraint loss is
    added once ``epoch >= 1`` and ``args.lc > 0``.

    After warm-up the objective is a mask loss. It consists of ``my_L1_loss``
    on the raw target mask ``output[5]`` plus ``my_creteria`` on a binarised
    target (``sigmoid(target) > 0.2``). When ``args.coord_switch_flag`` is
    set, a coordinate-switch L1 term is added. The affinity constraint loss
    is still computed and tracked in ``c_losses`` but is not part of this
    objective.

    A checkpoint named ``<epoch+1>checkpoint_latest.pth.tar`` in
    ``args.savedir`` and a ``log_current`` entry are written whenever the
    global step ``i + epoch * len(loader)`` is a multiple of
    ``args.save_interval``.

    Args:
        args: Options object. Attributes read here: ``coord_switch_flag``,
            ``color_switch_flag``, ``wepoch``, ``lc``, ``patch_size``,
            ``coord_switch``, ``save_interval`` and ``savedir``.
        loader: Iterable of batches ``item``. ``item[0][0]`` and
            ``item[0][1]`` are the two frames, ``item[1][0]`` and
            ``item[3][0]`` are extra first-frame inputs to ``forward``, and
            ``item[2]`` is a sequence of tensors stacked along dim 1 as
            ``segments``. This layout is inferred from the indexing below;
            confirm it against the dataset.
        model: Network passed to ``forward``; switched to train mode.
        closs: Constraint loss applied to the affinity viewed as
            ``(b, 1, x, x)`` and summed.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch index (assumed 0-based, given the ``epoch < 1``
            check — confirm with the caller).
        best_loss: Best running-average loss seen so far.

    Returns:
        The best running-average loss, lowered if this epoch improved on it
        at a checkpoint step.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    c_losses = AverageMeter()
    model.train()
    end = time.time()
    if args.coord_switch_flag:
        coord_switch_loss = nn.L1Loss()
        sc_losses = AverageMeter()

    # The L1 threshold (warm-up only) is disabled in the first epoch and in
    # the two epochs immediately after warm-up ends; otherwise it is 2.5.
    if epoch < 1 or (epoch >= args.wepoch and epoch < args.wepoch + 2):
        thr = None
    else:
        thr = 2.5
    train_len = len(loader)
    for i, item in enumerate(loader):
        frames = item[0]
        frames_pair = item[1]
        segments = torch.stack(item[2], dim=1)
        org_pair = item[3]
        frame1_var = frames[0].cuda()
        frame2_var = frames[1].cuda()
        frame1_org = frames_pair[0].cuda()
        frame1_sal = org_pair[0].cuda()
        if epoch < args.wepoch:
            # NOTE(review): this call passes (frame1, frame2, model), while the
            # post-warm-up call passes (frame1, frame1_org, frame1_sal,
            # frame2, model). Confirm that `forward` accepts both forms.
            output = forward(frame1_var, frame2_var, model, warm_up=True)
            color2_est = output[0]
            aff = output[1]
            b, x, _ = aff.size()
            color1_est = None
            if args.color_switch_flag:
                color1_est = output[2]
            loss_ = L1_loss(color2_est,
                            frame2_var,
                            10,
                            10,
                            thr=thr,
                            pred1=color1_est,
                            frame1_var=frame1_var)

            if epoch >= 1 and args.lc > 0:
                constraint_loss = torch.sum(closs(aff.view(b, 1, x, x))) * args.lc
                c_losses.update(constraint_loss.item(), frame1_var.size(0))
                loss = loss_ + constraint_loss
            else:
                loss = loss_
        else:
            output = forward(frame1_var,
                             frame1_org,
                             frame1_sal,
                             frame2_var,
                             model,
                             warm_up=False,
                             patch_size=args.patch_size,
                             segments=segments)
            # Slots used here: [1] affinity, [4] mask prediction, [5] target
            # mask. Optional colour-/coord-switch outputs follow slot 5 and
            # are located via `count` below.
            aff = output[1]
            pred_1 = output[4]
            gt_mask1 = output[5]

            b, _, x = aff.size()
            count = 5

            # Tracked for monitoring only; not added to the objective below.
            constraint_loss = torch.sum(closs(aff.view(b, 1, x, x))) * args.lc
            c_losses.update(constraint_loss.item(), frame1_var.size(0))

            if args.color_switch_flag:
                # Skip the colour-switch slot so the coord-switch outputs are
                # read from the right position; the current objective does
                # not use it.
                count += 1

            # Resize the prediction to the target's spatial size.
            # align_corners=False is the value the implicit default resolves to.
            pred_1 = F.interpolate(pred_1,
                                   size=gt_mask1.size()[2:],
                                   mode='bilinear',
                                   align_corners=False)
            # The L1 term uses the raw (un-binarised) target mask.
            my_l1_loss = my_L1_loss(pred_1, gt_mask1)
            print('output range:', torch.max(pred_1), torch.min(pred_1),
                  torch.max(gt_mask1), torch.min(gt_mask1))
            # Strictly binary target: sigmoid > 0.2 -> 1, otherwise 0. This
            # includes values exactly equal to 0.2.
            gt_binary = (torch.sigmoid(gt_mask1.detach()) > 0.2).to(gt_mask1.dtype)
            frame_loss = my_l1_loss + my_creteria(pred_1, gt_binary)
            # Only the mask loss is optimised. The colour-reconstruction and
            # constraint terms (previously weighted 0.01 each) are disabled.
            loss_ = frame_loss

            if args.coord_switch_flag:
                count += 1
                grids = output[count]
                C11 = output[count + 1]
                loss_coord = args.coord_switch * coord_switch_loss(C11, grids)
                loss = loss_ + loss_coord
                sc_losses.update(loss_coord.item(), frame1_var.size(0))
            else:
                loss = loss_

        losses.update(loss.item(), frame1_var.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()

        # Global-step schedule; note this also fires at step 0 of epoch 0.
        if ((i + epoch * train_len) % args.save_interval == 0):
            is_best = losses.avg < best_loss
            best_loss = min(losses.avg, best_loss)
            checkpoint_path = os.path.join(
                args.savedir,
                str(epoch + 1) + 'checkpoint_latest.pth.tar')
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                },
                is_best,
                filename=checkpoint_path,
                savedir=args.savedir)
            log_current(epoch,
                        losses.avg,
                        best_loss,
                        filename="log_current.txt",
                        savedir=args.savedir)

    return best_loss