def save_checkpoint(model, optimizer, epoch, global_step, args):
	# Save model/optimizer state for this epoch (only every args.save_freq epochs).
	save_dir = model_utils.make_joint_checkpoint_name(args, epoch)
	save_dir = os.path.join(args.savemodel, save_dir)
	if not os.path.exists(save_dir):
		os.makedirs(save_dir)

	model_path = os.path.join(save_dir, 'model_{:04d}.pth'.format(epoch))

	if epoch % args.save_freq == 0:
		torch.save({
			'state_dict': model.state_dict(),
			'optimizer': optimizer.state_dict(),
			'epoch': epoch,
			'global_step': global_step
			}, model_path)
		print('<=== checkpoint has been saved to {}.'.format(model_path))
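
# The snippets below reload checkpoints inline; a small companion loader is
# sketched here for reference. It is not part of the original scripts and only
# assumes the layout written by save_checkpoint above ('state_dict',
# 'optimizer', 'epoch') and that the model is wrapped the same way
# (e.g. nn.DataParallel) as when it was saved.
def load_checkpoint(model, optimizer, ckpt_path):
	"""Restore the state saved by save_checkpoint and return the next epoch to run."""
	ckpt = torch.load(ckpt_path, map_location='cpu')
	model.load_state_dict(ckpt['state_dict'])
	if optimizer is not None and 'optimizer' in ckpt:
		optimizer.load_state_dict(ckpt['optimizer'])
	return ckpt['epoch'] + 1
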
def main(args):
	train_loader, test_loader = make_data_loader(args)

	torch.manual_seed(args.seed)
	torch.cuda.manual_seed(args.seed)

	if args.resnet_arch is None:
		model = UNet()
	else:
		model = ResNetUNet(args.resnet_arch)
	# model = DataParallelWithCallback(model)
	model = nn.DataParallel(model).cuda()
	print('#parameters in warp disp model: {}'.format(sum([p.data.nelement() for p in model.parameters()])))

	optimizer = optim.Adam(model.parameters(), 
		lr=args.lr, 
		betas=(0.9, 0.999),
		eps=1e-08, 
		weight_decay=0.0004
	)

	if args.loadmodel is not None:
		state_dict = torch.load(args.loadmodel)['state_dict']
		# if LooseVersion(torch.__version__) >= LooseVersion('0.4.0'):
		# 	keys = list(state_dict.keys())
		# 	for k in keys:
		# 		if k.find('num_batches_tracked') >= 0:
		# 			state_dict.pop(k)
		model.load_state_dict(state_dict)
		print('==> A pre-trained checkpoint has been loaded: {}.'.format(args.loadmodel))
	start_epoch = 1

	if args.auto_resume:
		# search for the latest saved checkpoint, matching the layout written
		# by save_checkpoint (<checkpoint_dir>/model_<epoch>.pth)
		epoch_found = -1
		for epoch in range(args.epochs+1, 1, -1):
			ckpt_dir = model_utils.make_joint_checkpoint_name(args, epoch)
			ckpt_dir = os.path.join(args.savemodel, ckpt_dir)
			ckpt_path = os.path.join(ckpt_dir, 'model_{:04d}.pth'.format(epoch))
			if os.path.exists(ckpt_path):
				epoch_found = epoch
				break
		if epoch_found > 0:
			ckpt = torch.load(ckpt_path)
			assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
			start_epoch = ckpt['epoch'] + 1
			optimizer.load_state_dict(ckpt['optimizer'])
			model.load_state_dict(ckpt['state_dict'])
			print('==> Automatically resumed training from {}.'.format(ckpt_path))

	crit = multiscaleloss(
		downsample_factors=(16, 8, 4, 2, 1),
		weights=(1, 1, 2, 4, 8), 
		loss='l1',
		size_average=True
		).cuda()

	start_full_time = time.time()

	train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
						 '\t{:.6f}'
	test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}'\
						'\t{:.6f}'

	os.makedirs(os.path.join(args.savemodel, 'tensorboard'), exist_ok=True)
	writer = SummaryWriter(os.path.join(args.savemodel, 'tensorboard'))

	global_step = 0
	for epoch in range(start_epoch, args.epochs+1):
		total_err = 0
		total_test_err_pct = 0
		total_test_loss = 0
		lr = adjust_learning_rate(optimizer, epoch, len(train_loader))
			 
		## training ##
		start_time = time.time() 
		for batch_idx, data in enumerate(train_loader):
			end = time.time()
			loss, losses = train(
				model, crit, optimizer, data
			)
			global_step += 1
			writer.add_scalar('train/total_loss', loss * 20, global_step)
			if (batch_idx + 1) % args.print_freq == 0:
				print(train_print_format.format(
					'Train', global_step, epoch, batch_idx, len(train_loader),
					loss, 
					end - start_time, time.time() - start_time, lr
					))
				sys.stdout.flush()
			start_time = time.time()
		 
		## test ##
		start_time = time.time()

		for batch_idx, batch_data in enumerate(test_loader):
			err, err_pct, loss = test_disp(
				model, crit, batch_data, args.cmd
			)
			total_err += err
			total_test_err_pct += err_pct
			total_test_loss += loss

		writer.add_scalar('test/loss', total_test_loss / (len(test_loader) + 1e-30) * 20, epoch)
		writer.add_scalar('test/err', total_err / (len(test_loader) + 1e-30), epoch)
		writer.add_scalar('test/err_pct', total_test_err_pct / (len(test_loader) + 1e-30) * 100, epoch)
		print(test_print_format.format(
					'Test', global_step, epoch,
					total_err/(len(test_loader) + 1e-30), 
					total_test_err_pct/(len(test_loader) + 1e-30) * 100,
					total_test_loss / (len(test_loader) + 1e-30),
					time.time() - start_time, lr
					))
		sys.stdout.flush()
		
		save_checkpoint(model, optimizer, epoch, global_step, args)
	print('Elapsed time = %.2f HR' % ((time.time() - start_full_time)/3600))
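
# The multiscaleloss criterion instantiated in main() above comes from the
# project's loss module. The class below is only an illustrative sketch of a
# multi-scale L1 pyramid loss with a similar constructor; the downsampling and
# weighting details are assumptions, not the project's actual implementation.
class MultiScaleL1LossSketch(nn.Module):
	def __init__(self, downsample_factors=(16, 8, 4, 2, 1), weights=(1, 1, 2, 4, 8), size_average=True):
		super(MultiScaleL1LossSketch, self).__init__()
		assert len(downsample_factors) == len(weights)
		self.downsample_factors = downsample_factors
		self.weights = weights
		self.size_average = size_average

	def forward(self, preds, target):
		# preds: one prediction per pyramid level (coarse to fine);
		# target: full-resolution ground truth of shape (N, C, H, W).
		total = 0.0
		for pred, factor, weight in zip(preds, self.downsample_factors, self.weights):
			level_target = nn.functional.avg_pool2d(target, kernel_size=factor) if factor > 1 else target
			diff = (pred - level_target).abs()
			total = total + weight * (diff.mean() if self.size_average else diff.sum())
		return total
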
def main(args):
	train_loader, flow_test_loader, disp_test_loader = make_data_loader(args)

	torch.manual_seed(args.seed)
	torch.cuda.manual_seed(args.seed)
	np.random.seed(args.seed)
	random.seed(args.seed)

	model = model_utils.make_model(
		args, 
		do_flow=not args.no_flow,
		do_disp=not args.no_disp,
		do_seg=(args.do_seg or args.do_seg_distill)
	)
	print('Number of model parameters: {}'.format(
		sum([p.data.nelement() for p in model.parameters()]))
	)

	optimizer = optim.Adam(model.parameters(), 
		lr=args.lr, 
		betas=(0.9, 0.999),
		eps=1e-08, 
		weight_decay=0.0004
	)

	if args.loadmodel is not None:
		ckpt = torch.load(args.loadmodel)
		state_dict = ckpt['state_dict']
		model.load_state_dict(model_utils.patch_model_state_dict(state_dict))
		print('==> A pre-trained checkpoint has been loaded.')
	start_epoch = 1

	if args.auto_resume:
		# search for the latest saved checkpoint
		epoch_found = -1
		for epoch in range(args.epochs+1, 1, -1):
			ckpt_dir = model_utils.make_joint_checkpoint_name(args, epoch)
			ckpt_dir = os.path.join(args.savemodel, ckpt_dir)
			ckpt_path = os.path.join(ckpt_dir, 'model_{:04d}.pth'.format(epoch))
			if os.path.exists(ckpt_path):
				epoch_found = epoch
				break
		if epoch_found > 0:
			ckpt = torch.load(ckpt_path)
			assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
			start_epoch = ckpt['epoch'] + 1
			optimizer.load_state_dict(ckpt['optimizer'])
			model.load_state_dict(ckpt['state_dict'])
			print('==> Automatically resumed training from {}.'.format(ckpt_path))
	else:
		if args.resume is not None:
			ckpt = torch.load(args.resume)
			start_epoch = ckpt['epoch'] + 1
			optimizer.load_state_dict(ckpt['optimizer'])
			model.load_state_dict(ckpt['state_dict'])
			print('==> Manually resumed training from {}.'.format(args.resume))
	
	cudnn.benchmark = True

	(flow_crit, flow_occ_crit), flow_down_scales, flow_weights = model_utils.make_flow_criteria(args)
	(disp_crit, disp_occ_crit), disp_down_scales, disp_weights = model_utils.make_disp_criteria(args)

	hard_seg_crit = None
	soft_seg_crit = None
	self_supervised_crit = None
	criteria = (
		disp_crit, disp_occ_crit, 
		flow_crit, flow_occ_crit
	)

	min_loss = float('inf')
	min_epo = 0
	min_err_pct = float('inf')
	start_full_time = time.time()

	train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
		'\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.6f}'
	test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}\t{:.2f}\t{:.2f}'\
		'\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.6f}'

	global_step = 0
	for epoch in range(start_epoch, args.epochs+1):
		total_train_loss = 0
		total_err = 0
		total_test_err_pct = 0
		total_disp_occ_acc = 0
		total_epe = 0
		total_flow_occ_acc = 0
		total_seg_acc = 0
		lr = adjust_learning_rate(optimizer, epoch, len(train_loader))
			 
		## training ##
		start_time = time.time() 
		for batch_idx, batch_data in enumerate(train_loader):
			end = time.time()
			train_res = train(model, optimizer, batch_data, criteria, args)
			loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss = train_res
			global_step += 1
			if (batch_idx + 1) % args.print_freq == 0:
				print(train_print_format.format(
					'Train', global_step, epoch, batch_idx, len(train_loader),
					loss, 
					flow_loss, flow_occ_loss, 
					disp_loss, disp_occ_loss,
					end - start_time, time.time() - start_time, lr
				))
				sys.stdout.flush()
			start_time = time.time()
			total_train_loss += loss

		# should have used the validation set to select the best model
		start_time = time.time()
		for batch_idx, batch_data in enumerate(flow_test_loader):
			loss_data = test_flow(
				model, 
				batch_data,
				criteria, 
				args.cmd, 
				flow_down_scales[0]
			)
			epe, flow_occ_acc, loss, flow_loss, flow_occ_loss = loss_data
			total_epe += epe
			total_flow_occ_acc += flow_occ_acc

		for batch_idx, batch_data in enumerate(disp_test_loader):
			loss_data = test_disp(
				model, 
				batch_data, 
				criteria, 
				args.cmd
			)
			err, err_pct, disp_occ_acc, loss, disp_loss, disp_occ_loss = loss_data
			total_err += err
			total_test_err_pct += err_pct
			total_disp_occ_acc += disp_occ_acc

		if total_test_err_pct/len(disp_test_loader) * 100 < min_err_pct:
			min_loss = total_err/len(disp_test_loader)
			min_epo = epoch
			min_err_pct = total_test_err_pct/len(disp_test_loader) * 100

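		# Note: flow_loss/flow_occ_loss and disp_loss/disp_occ_loss below hold the
		# values from the last test batch of the corresponding loop above, not averages.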
		print(test_print_format.format(
			'Test', global_step, epoch,
			total_epe / len(flow_test_loader) * args.div_flow,
			total_flow_occ_acc / len(flow_test_loader) * 100,
			total_err/len(disp_test_loader), 
			total_test_err_pct/len(disp_test_loader) * 100,
			total_disp_occ_acc / len(disp_test_loader) * 100,
			flow_loss, flow_occ_loss,
			disp_loss * args.disp_loss_weight, 
			disp_occ_loss * args.disp_loss_weight,
			time.time() - start_time, lr
		))
		
		save_checkpoint(model, optimizer, epoch, global_step, args)
	print('Elapsed time = %.2f HR' %((time.time() - start_full_time)/3600))
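
# adjust_learning_rate(optimizer, epoch, ...) is called at the start of every
# epoch in the training loops above but is not shown in these snippets. The
# helper below is only an illustrative step-decay sketch that matches the call
# signature; the real schedule is defined elsewhere in the project.
def adjust_learning_rate_sketch(optimizer, epoch, iters_per_epoch, base_lr=1e-3, decay_every=20):
	"""Halve the learning rate every `decay_every` epochs and return the value in use."""
	# iters_per_epoch is accepted to mirror the call signature above; unused in this sketch.
	lr = base_lr * (0.5 ** ((epoch - 1) // decay_every))
	for param_group in optimizer.param_groups:
		param_group['lr'] = lr
	return lr
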

# Example 4

def main(args):
    train_loader, test_loader = make_data_loader(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    model = model_utils.make_model(args,
                                   do_flow=not args.no_flow,
                                   do_disp=not args.no_disp,
                                   do_seg=(args.do_seg or args.do_seg_distill))
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.0004)

    if args.loadmodel is not None:
        ckpt = torch.load(args.loadmodel)
        state_dict = ckpt['state_dict']
        missing_keys, unexpected_keys = model.load_state_dict(
            model_utils.patch_model_state_dict(state_dict))
        assert not unexpected_keys, 'Got unexpected keys: {}'.format(
            unexpected_keys)
        if missing_keys:
            for mk in missing_keys:
                assert mk.find(
                    'seg_decoder'
                ) >= 0, 'Only segmentation decoder can be initialized randomly.'
        print('==> A pre-trained model has been loaded.')
    start_epoch = 1

    if args.auto_resume:
        # search for the latest saved checkpoint
        epoch_found = -1
        for epoch in range(args.epochs + 1, 1, -1):
            ckpt_dir = model_utils.make_joint_checkpoint_name(args, epoch)
            ckpt_dir = os.path.join(args.savemodel, ckpt_dir)
            ckpt_path = os.path.join(ckpt_dir,
                                     'model_{:04d}.pth'.format(epoch))
            if os.path.exists(ckpt_path):
                epoch_found = epoch
                break
        if epoch_found > 0:
            ckpt = torch.load(ckpt_path)
            assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Automatically resumed training from {}.'.format(
                ckpt_path))
    else:
        if args.resume is not None:
            ckpt = torch.load(args.resume)
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Manually resumed training from {}.'.format(args.resume))

    cudnn.benchmark = True

    (flow_crit, flow_occ_crit
     ), flow_down_scales, flow_weights = model_utils.make_flow_criteria(args)
    (disp_crit, disp_occ_crit
     ), disp_down_scales, disp_weights = model_utils.make_disp_criteria(args)

    hard_seg_crit = model_utils.make_seg_criterion(args, hard_lab=True)
    soft_seg_crit = model_utils.make_seg_criterion(args, hard_lab=False)
    args.hard_seg_loss_weight *= float(disp_weights[0])
    args.soft_seg_loss_weight *= float(disp_weights[0])

    self_supervised_crit = make_self_supervised_loss(
        args,
        disp_downscales=disp_down_scales,
        disp_pyramid_weights=disp_weights,
        flow_downscales=flow_down_scales,
        flow_pyramid_weights=flow_weights).cuda()
    criteria = (disp_crit, disp_occ_crit, flow_crit, flow_occ_crit,
                hard_seg_crit, soft_seg_crit, self_supervised_crit)

    min_loss = float('inf')
    min_epo = 0
    min_err_pct = float('inf')
    start_full_time = time.time()

    train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
          '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.6f}'
    test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}\t{:.2f}\t{:.2f}'\
         '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.6f}'

    global_step = 0
    for epoch in range(start_epoch, args.epochs + 1):
        total_train_loss = 0
        total_err = 0
        total_test_err_pct = 0
        total_disp_occ_acc = 0
        total_epe = 0
        total_flow_occ_acc = 0
        total_seg_acc = 0
        lr = adjust_learning_rate(optimizer, epoch, len(train_loader))

        ## training ##
        start_time = time.time()
        for batch_idx, batch_data in enumerate(train_loader):
            end = time.time()
            # (cur_im, nxt_im), (flow, flow_occ), (left_im, right_im), (disp, disp_occ, seg_im) = data
            # if args.seg_root_dir is None:
            # 	seg_im = None
            train_res = train(model, optimizer, batch_data, criteria, args)
            loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss, seg_loss, seg_distill_loss, ss_loss, ss_losses = train_res
            global_step += 1
            if (batch_idx + 1) % args.print_freq == 0:
                print(
                    train_print_format.format('Train', global_step,
                                              epoch, batch_idx,
                                              len(train_loader), loss,
                                              flow_loss, flow_occ_loss,
                                              disp_loss, disp_occ_loss,
                                              seg_loss, seg_distill_loss,
                                              ss_loss, end - start_time,
                                              time.time() - start_time, lr))
                for k, v in ss_losses.items():
                    print('{: <10}\t{:.3f}'.format(k, v))
                sys.stdout.flush()
            start_time = time.time()
            total_train_loss += loss

        # should have had a validation set

        save_checkpoint(model, optimizer, epoch, global_step, args)
    print('Elapsed time = %.2f HR' % ((time.time() - start_full_time) / 3600))
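
# The train() helper called in the loops above is defined elsewhere in the
# project and returns one scalar per loss term. The function below is only a
# sketch of the usual shape of a single optimization step (zero_grad ->
# forward -> weighted loss -> backward -> step); the model's output structure
# and the loss weighting are assumptions, not the project's actual code.
def train_step_sketch(model, optimizer, inputs, target, criterion, loss_weight=1.0):
    model.train()
    optimizer.zero_grad()
    pred = model(inputs)
    loss = loss_weight * criterion(pred, target)
    loss.backward()
    optimizer.step()
    return loss.item()
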

# Example 5

def main(args):
    train_loader, test_loader = make_data_loader(args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    model = model_utils.make_model(args, do_seg=args.do_seg)
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=0.0004)

    if args.loadmodel is not None:
        ckpt = torch.load(args.loadmodel)
        state_dict = ckpt['state_dict']
        model.load_state_dict(model_utils.patch_model_state_dict(state_dict))
        print('==> A pre-trained checkpoint has been loaded: {}.'.format(
            args.loadmodel))
    start_epoch = 1

    if args.auto_resume:
        # search for the latest saved checkpoint
        epoch_found = -1
        for epoch in range(args.epochs + 1, 1, -1):
            ckpt_path = model_utils.make_joint_checkpoint_name(args, epoch)
            ckpt_path = os.path.join(args.savemodel, ckpt_path)
            if os.path.exists(ckpt_path):
                epoch_found = epoch
                break
        if epoch_found > 0:
            ckpt = torch.load(ckpt_path)
            assert ckpt['epoch'] == epoch_found, [ckpt['epoch'], epoch_found]
            start_epoch = ckpt['epoch'] + 1
            optimizer.load_state_dict(ckpt['optimizer'])
            model.load_state_dict(ckpt['state_dict'])
            print('==> Automatically resumed training from {}.'.format(
                ckpt_path))

    cudnn.benchmark = True

    (flow_crit, flow_occ_crit
     ), flow_down_scales, flow_weights = model_utils.make_flow_criteria(args)
    (disp_crit, disp_occ_crit
     ), disp_down_scales, disp_weights = model_utils.make_disp_criteria(args)

    criteria = (disp_crit, disp_occ_crit, flow_crit, flow_occ_crit)

    min_loss = float('inf')
    min_epo = 0
    min_err_pct = float('inf')
    start_full_time = time.time()

    train_print_format = '{}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}'\
          '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t\t{:.6f}'
    test_print_format = '{}\t{:d}\t{:d}\t{:.3f}\t{:.2f}\t{:.3f}\t{:.2f}\t{:.2f}\t{:.2f}'\
         '\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\t{:.6f}'

    global_step = 0
    for epoch in range(start_epoch, args.epochs + 1):
        total_train_loss = 0
        total_err = 0
        total_test_err_pct = 0
        total_disp_occ_acc = 0
        total_epe = 0
        total_flow_occ_acc = 0
        total_seg_acc = 0
        lr = adjust_learning_rate(optimizer, epoch, len(train_loader))

        ## training ##
        start_time = time.time()
        for batch_idx, batch_data in enumerate(train_loader):
            end = time.time()
            train_res = train(model, optimizer, batch_data, criteria, args)
            loss, flow_loss, flow_occ_loss, disp_loss, disp_occ_loss = train_res
            global_step += 1
            if (batch_idx + 1) % args.print_freq == 0:
                print(
                    train_print_format.format('Train', global_step,
                                              epoch, batch_idx,
                                              len(train_loader), loss,
                                              flow_loss, flow_occ_loss,
                                              disp_loss, disp_occ_loss,
                                              end - start_time,
                                              time.time() - start_time, lr))
                sys.stdout.flush()
            start_time = time.time()
            total_train_loss += loss

        save_checkpoint(model, optimizer, epoch, global_step, args)
    print('Elapsed time = %.2f HR' % ((time.time() - start_full_time) / 3600))
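
# None of the snippets above include the argument parser that builds `args`.
# The entry point below is a hypothetical sketch covering a subset of the flags
# referenced in the code (names inferred from the args.<attr> accesses; the
# defaults are assumptions, not the project's real ones).
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='joint flow/disparity training (sketch)')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--savemodel', type=str, default='./checkpoints')
    parser.add_argument('--loadmodel', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--auto-resume', dest='auto_resume', action='store_true')
    parser.add_argument('--save-freq', dest='save_freq', type=int, default=1)
    parser.add_argument('--print-freq', dest='print_freq', type=int, default=20)
    args = parser.parse_args()
    # The real scripts also define task flags such as --no-flow, --no-disp,
    # --do-seg, --cmd and --div-flow before calling main(args).
    main(args)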