Example No. 1
def train(cfg):
	
	use_gpu = cfg.device == 'cuda'
	# 1、make dataloader
	train_loader, val_loader, test_loader, num_query, num_class =  darts_make_data_loader(cfg)
	# print(num_query)

	# 2、make model
	if cfg.model_name == 'ssnet':
		model = SSNetwork(num_class, cfg, use_gpu)
	
	elif cfg.model_name == 'fsnet':
		model = FSNetwork(num_class, cfg.in_planes, cfg.init_size, cfg.layers, use_gpu, cfg.pretrained) 


	# 3、make optimizer
	optimizer = darts_make_optimizer(cfg, model)
	# print(optimizer)

	# 4、make lr scheduler
	lr_scheduler = darts_make_lr_scheduler(cfg, optimizer)
	# print(lr_scheduler)

	# 5、make loss 
	loss_func = darts_make_loss(cfg)
	model._set_loss(loss_func, compute_loss_acc)
	
	# 6、make architect
	architect = Architect(model, cfg)
	
	# get parameters
	log_period = cfg.log_period
	ckpt_period = cfg.ckpt_period
	eval_period = cfg.eval_period
	output_dir =  cfg.output_dir
	device = cfg.device 
	epochs = cfg.max_epochs
	ckpt_save_path = output_dir + cfg.ckpt_dir 

	use_gpu = device == 'cuda'
	batch_size = cfg.batch_size
	batch_num = len(train_loader)
	log_iters = batch_num // log_period 
	pretrained = cfg.pretrained is not None
	parallel = False
	use_neck = cfg.use_neck 

	if not os.path.exists(ckpt_save_path):
		os.makedirs(ckpt_save_path)

	logger = logging.getLogger("DARTS.train")
	size = count_parameters(model)
	logger.info("the param number of the model is {:.2f} M".format(size))

	logger.info("Start training")
	
	
	if pretrained:
		start_epoch = model.start_epoch 
	if parallel:
		model = nn.DataParallel(model)
	if use_gpu:
		model = model.to(device)

	best_mAP, best_r1 = 0., 0.
	is_best = False
	avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
	avg_time = AverageMeter()
	# num = 3 -> epoch = 2
	for epoch in range(epochs):
		lr_scheduler.step()
		lr = lr_scheduler.get_lr()[0]
		# architect lr.step
		architect.lr_scheduler.step()
		
		if pretrained and epoch < model.start_epoch :
			continue

		model.train()
		avg_loss.reset()
		avg_acc.reset()
		avg_time.reset()

		for i, batch in enumerate(train_loader):
			
			t0 = time.time()
			imgs, labels = batch
			val_imgs, val_labels = next(iter(val_loader))
			
			if use_gpu:
				imgs = imgs.to(device)
				labels = labels.to(device)
				val_imgs = val_imgs.to(device)
				val_labels = val_labels.to(device)

			# 1、update alpha
			architect.step(imgs, labels, val_imgs, val_labels, lr, optimizer, unrolled = cfg.unrolled)

			optimizer.zero_grad()
			res = model(imgs)
			# loss = loss_func(score, feats, labels)
			loss, acc = compute_loss_acc(use_neck, res, labels, loss_func)
			# print("loss:",loss.item())

			loss.backward()
			nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
			
			# 2、update weights
			optimizer.step()

			# acc = (score.max(1)[1] == labels).float().mean()
			# print("acc:", acc)
			t1 = time.time()
			avg_time.update((t1 - t0) / batch_size)
			avg_loss.update(loss)
			avg_acc.update(acc)
			

			# log info
			if (i+1) % log_iters == 0:
				logger.info("epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".format(
					epoch+1, i+1, batch_num, avg_loss.avg, avg_acc.avg))

		logger.info("end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".format(epoch+1, epochs, lr, avg_time.avg * 1000))

		
		# test the model
		if (epoch + 1) % eval_period == 0:
			
			model.eval()
			metrics = R1_mAP(num_query, use_gpu = use_gpu)

			with torch.no_grad():

				for vi, batch in enumerate(test_loader):

					imgs, labels, camids = batch

					if use_gpu:
						imgs = imgs.to(device)

					feats = model(imgs)
					metrics.update((feats, labels, camids))

				# compute cmc and mAP
				cmc, mAP = metrics.compute()
				logger.info("validation results at epoch {}".format(epoch + 1))
				logger.info("mAP: {:.2%}".format(mAP))
				for r in [1,5,10]:
					logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r-1]))

				# determine whether current model is the best
				if mAP > best_mAP:
					is_best = True
					best_mAP = mAP
					logger.info("Get a new best mAP")
				if cmc[0] > best_r1:
					is_best = True
					best_r1 = cmc[0]
					logger.info("Get a new best r1")

		# whether to save the model
		if (epoch + 1) % ckpt_period == 0 or is_best:

			if parallel:
				torch.save(model.module.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
				model.module._parse_genotype(file = ckpt_save_path + "genotype_{}.json".format(epoch + 1))
			else:
				torch.save(model.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
				model._parse_genotype(file = ckpt_save_path + "genotype_{}.json".format(epoch + 1))
			
			logger.info("checkpoint {} was saved".format(epoch + 1))

			if is_best:
				if parallel:
					torch.save(model.module.state_dict(), ckpt_save_path + "best_ckpt.pth")
					model.module._parse_genotype(file = ckpt_save_path + "best_genotype.json")
				else:
					torch.save(model.state_dict(), ckpt_save_path + "best_ckpt.pth")
					model._parse_genotype(file = ckpt_save_path + "best_genotype.json")

				logger.info("best_checkpoint was saved")
				is_best = False
		

	logger.info("training is finished")
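In the loop above, architect.step updates the architecture parameters (the DARTS alphas) on a validation batch before the ordinary weight update runs on the training batch. The Architect class itself is repository-specific and not shown here; the sketch below is only an illustration of what such a step typically does in the first-order setting, assuming the model exposes its alphas through an arch_parameters() method and ignoring the unrolled second-order variant.

import torch

class FirstOrderArchitect:
    """Hypothetical stand-in for the Architect used above (first-order DARTS)."""

    def __init__(self, model, loss_fn, arch_lr=3e-4, arch_weight_decay=1e-3):
        self.model = model
        self.loss_fn = loss_fn
        # A separate optimizer that only touches the architecture parameters (alphas).
        self.optimizer = torch.optim.Adam(model.arch_parameters(),
                                          lr=arch_lr, betas=(0.5, 0.999),
                                          weight_decay=arch_weight_decay)

    def step(self, val_imgs, val_labels):
        # Minimize the validation loss with respect to the alphas only;
        # the network weights are updated separately on the training batch.
        self.optimizer.zero_grad()
        loss = self.loss_fn(self.model(val_imgs), val_labels)
        loss.backward()
        self.optimizer.step()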
Example No. 2
        shuffle=True)

n_student_layer = len(student_encoder.bert.encoder.layer)
student_encoder = load_model_wonbon(student_encoder,
                                    args.encoder_checkpoint,
                                    args,
                                    'student',
                                    verbose=True)
logger.info('*' * 77)
student_classifier = load_model(student_classifier,
                                args.cls_checkpoint,
                                args,
                                'classifier',
                                verbose=True)

n_param_student = count_parameters(student_encoder) + count_parameters(
    student_classifier)
logger.info('number of layers in student model = %d' % n_student_layer)
logger.info(
    'num parameters in student model are %d and %d' %
    (count_parameters(student_encoder), count_parameters(student_classifier)))

#########################################################################
# Prepare optimizer
#########################################################################
if args.do_train:
    param_optimizer = list(student_encoder.named_parameters()) + list(
        student_classifier.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
Example No. 3
                inter = self.inters[ind](inter)
        return outs


get_hourglass = \
    {'large_hourglass':
         exkp(n=5, nstack=2, dims=[256, 256, 384, 384, 384, 512], modules=[2, 2, 2, 2, 2, 4]),
     'small_hourglass':
         exkp(n=5, nstack=1, dims=[256, 256, 384, 384, 384, 512], modules=[2, 2, 2, 2, 2, 4])}

if __name__ == '__main__':
    from collections import OrderedDict
    from utils.utils import count_parameters, count_flops, load_model

    def hook(self, input, output):
        print(output.data.cpu().numpy().shape)
        # pass

    net = get_hourglass['large_hourglass']
    load_model(net, '../ckpt/pretrain/checkpoint.t7')
    count_parameters(net)
    count_flops(net, input_size=512)

    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.register_forward_hook(hook)

    with torch.no_grad():
        y = net(torch.randn(2, 3, 512, 512).cuda())
    # print(y.size())
Example No. 4
import config as cfg
from model import EfficientDet
from utils.utils import count_parameters
""" Quick test on parameters number """

model = EfficientDet.from_pretrained('efficientdet-d0').to('cpu')

model.eval()

print('Model: {}, params: {:.6f}M, params in paper: {}'.format(
    cfg.MODEL_NAME,
    count_parameters(model) / 1e6, cfg.PARAMS))
print('   Backbone: {:.6f}M'.format(count_parameters(model.backbone) / 1e6))
print('   Adjuster: {:.6f}M'.format(count_parameters(model.adjuster) / 1e6))
print('      BiFPN: {:.6f}M'.format(count_parameters(model.bifpn) / 1e6))
print('       Head: {:.6f}M'.format(
    (count_parameters(model.classifier) + count_parameters(model.regresser)) /
    1e6))
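Every snippet on this page imports count_parameters from its repository's own utils module, and the exact definition differs from project to project (here the raw count is divided by 1e6 to report millions; elsewhere the helper formats the value itself). A minimal sketch of the usual pattern, assuming it simply sums the element counts of the model's parameter tensors; the optional trainable flag mirrors the call in Example No. 5 below and is an assumption for illustration, not the actual signature of any of these repositories.

import torch.nn as nn

def count_parameters(model: nn.Module, trainable: bool = False) -> int:
    # Total number of parameter elements; optionally only those that require gradients.
    params = model.parameters()
    if trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)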
Example No. 5
def main(opt):
    torch.manual_seed(opt.seed)
    torch.backends.cudnn.benchmark = not opt.not_cuda_benchmark and not opt.test
    Dataset = get_dataset(opt.dataset, opt.task)
    val_dataset = Dataset(opt, 'val')
    opt = Opts().update_dataset_info_and_set_heads(opt, val_dataset)
    print(opt)

    logger = Logger(opt)

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')

    print('Creating model...')
    model = create_model(opt.arch, opt.heads, opt.head_conv)
    optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    start_epoch = 0
    if opt.load_model != '':
        model, optimizer, start_epoch = load_model(
            model, opt.load_model, optimizer, opt.resume, opt.lr, opt.lr_step)

    print("Model:")
    print(model)
    print("Total number of parameters: {}".format(count_parameters(model)))
    print("Trainable parameters: {}".format(count_parameters(model, trainable=True)))


    Trainer = train_factory[opt.task]
    trainer = Trainer(opt, model, optimizer)
    trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)

    print('Setting up data...')
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )

    if opt.test:
        _, preds = trainer.val(0, val_loader)
        val_loader.dataset.run_eval(preds, opt.save_dir)
        print("{} images failed!".format(len(val_loader.dataset.failed_images)))
        print(val_loader.dataset.failed_images)
        return

    train_loader = torch.utils.data.DataLoader(
        Dataset(opt, 'train'),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True
    )

    # Report failed images
    failed_images = list(train_loader.dataset.failed_images | val_loader.dataset.failed_images)
    if len(failed_images):
        print("{} images failed!".format(len(failed_images)))
        dump_path = os.path.join(opt.save_dir, 'failed_images.json')
        with open(dump_path, 'w') as f:
            json.dump(failed_images, f, sort_keys=True, indent=4, separators=(',', ': '))
        print("Failed image paths saved in: {}".format(dump_path))


    print('Starting training...')
    best = 1e10
    for epoch in range(start_epoch + 1, opt.num_epochs + 1):
        mark = epoch if opt.save_all else 'last'
        log_dict_train, _ = trainer.train(epoch, train_loader)
        logger.write('epoch: {} |'.format(epoch))
        for k, v in log_dict_train.items():
            logger.scalar_summary('train_{}'.format(k), v, epoch)
            logger.write('{} {:8f} | '.format(k, v))
        if opt.val_intervals > 0 and epoch % opt.val_intervals == 0:
            save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(mark)),
                       epoch, model, optimizer)
            with torch.no_grad():
                log_dict_val, preds = trainer.val(epoch, val_loader)
            for k, v in log_dict_val.items():
                logger.scalar_summary('val_{}'.format(k), v, epoch)
                logger.write('{} {:8f} | '.format(k, v))
            if log_dict_val[opt.metric] < best:
                best = log_dict_val[opt.metric]
                save_model(os.path.join(opt.save_dir, 'model_best.pth'),
                           epoch, model)
        else:
            save_model(os.path.join(opt.save_dir, 'model_last.pth'),
                       epoch, model, optimizer)
        logger.write('\n')
        if epoch in opt.lr_step:
            save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)),
                       epoch, model, optimizer)
            lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
            print('Drop LR to', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
    logger.close()
Example No. 6
def main():
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_loader, val_loader = getDataLoader(args, logger)

    # Network
    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        useFeatureAtt=args.useFeatureAtt,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    # logger.info('%s' % aanet) if local_master else None
    if local_master:
        structure_of_net = os.path.join(args.checkpoint_dir,
                                        'structure_of_net.txt')
        with open(structure_of_net, 'w') as f:
            f.write('%s' % aanet)

    if args.pretrained_aanet is not None:
        logger.info('=> Loading pretrained AANet: %s' % args.pretrained_aanet)
        # Enable training from a partially pretrained model
        utils.load_pretrained_net(aanet,
                                  args.pretrained_aanet,
                                  no_strict=(not args.strict))

    aanet.to(device)
    logger.info('=> Use %d GPUs' %
                torch.cuda.device_count()) if local_master else None
    # if torch.cuda.device_count() > 1:
    if args.distributed:
        # aanet = torch.nn.DataParallel(aanet)
        # try distributed training
        aanet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(aanet)
        aanet = torch.nn.parallel.DistributedDataParallel(
            aanet, device_ids=[local_rank], output_device=local_rank)
        synchronize()

    # Save parameters
    num_params = utils.count_parameters(aanet)
    logger.info('=> Number of trainable parameters: %d' % num_params)
    save_name = '%d_parameters' % num_params
    if local_master:
        # empty marker file: its name records how many trainable parameters the model has
        open(os.path.join(args.checkpoint_dir, save_name), 'a').close()

    # Optimizer
    # Learning rate for offset learning is set 0.1 times those of existing layers
    specific_params = list(
        filter(utils.filter_specific_params, aanet.named_parameters()))
    base_params = list(
        filter(utils.filter_base_params, aanet.named_parameters()))

    specific_params = [kv[1]
                       for kv in specific_params]  # kv is a tuple (key, value)
    base_params = [kv[1] for kv in base_params]

    specific_lr = args.learning_rate * 0.1
    params_group = [
        {
            'params': base_params,
            'lr': args.learning_rate
        },
        {
            'params': specific_params,
            'lr': specific_lr
        },
    ]

    optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)

    # Resume training
    if args.resume:
        # 1. resume AANet
        start_epoch, start_iter, best_epe, best_epoch = utils.resume_latest_ckpt(
            args.checkpoint_dir, aanet, 'aanet')
        # 2. resume Optimizer
        utils.resume_latest_ckpt(args.checkpoint_dir, optimizer, 'optimizer')
    else:
        start_epoch = 0
        start_iter = 0
        best_epe = None
        best_epoch = None

    # LR scheduler
    if args.lr_scheduler_type is not None:
        last_epoch = start_epoch if args.resume else start_epoch - 1
        if args.lr_scheduler_type == 'MultiStepLR':
            milestones = [int(step) for step in args.milestones.split(',')]
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_decay_gamma,
                last_epoch=last_epoch
            )  # this last_epoch argument matters: when resuming, the learning rate is automatically adjusted to match last_epoch
        else:
            raise NotImplementedError
    # model.Model(object) is a further wrapper around AANet.
    train_model = model.Model(args,
                              logger,
                              optimizer,
                              aanet,
                              device,
                              start_iter,
                              start_epoch,
                              best_epe=best_epe,
                              best_epoch=best_epoch)

    logger.info('=> Start training...')

    trainLoss_dict, trainLossKey, valLoss_dict, valLossKey = getLossRecord(
        netName="AANet")

    if args.evaluate_only:
        assert args.val_batch_size == 1
        train_model.validate(
            val_loader, local_master, valLoss_dict,
            valLossKey)  # test mode: --evaluate_only should be set and --mode should be 'test'
        # save the losses for analysis
        save_loss_for_matlab(trainLoss_dict, valLoss_dict)
    else:
        for epoch in range(start_epoch, args.max_epoch):  # main training loop over epochs
            if not args.evaluate_only:
                # ensure distribute worker sample different data,
                # set different random seed by passing epoch to sampler
                if args.distributed:
                    train_loader.sampler.set_epoch(epoch)
                    logger.info(
                        'train_loader.sampler.set_epoch({})'.format(epoch))
                train_model.train(train_loader, local_master, trainLoss_dict,
                                  trainLossKey)
            if not args.no_validate:
                train_model.validate(val_loader, local_master, valLoss_dict,
                                     valLossKey)  # training mode: validate while training
            if args.lr_scheduler_type is not None:
                lr_scheduler.step()  # adjust the learning rate

            # save the losses for analysis after every epoch, overwriting the previous save, so nothing is lost if training stops early
            save_loss_for_matlab(trainLoss_dict, valLoss_dict)

        logger.info('=> End training\n\n')
Example No. 7
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch ImageNet Training with sparse masks')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--lr_decay',
                        default=[30, 60, 90],
                        nargs='+',
                        type=int,
                        help='learning rate decay epochs')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay',
                        default=1e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('--batchsize', default=64, type=int, help='batch size')
    parser.add_argument('--epochs',
                        default=100,
                        type=int,
                        help='number of epochs')
    parser.add_argument('--model',
                        type=str,
                        default='resnet101',
                        help='network model name')

    parser.add_argument(
        '--budget',
        default=-1,
        type=float,
        help='computational budget (between 0 and 1) (-1 for no sparsity)')
    parser.add_argument('-s',
                        '--save_dir',
                        type=str,
                        default='',
                        help='directory to save model')
    parser.add_argument('-r',
                        '--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--dataset-root',
                        default='/esat/visicsrodata/datasets/ilsvrc2012/',
                        type=str,
                        metavar='PATH',
                        help='ImageNet dataset root')
    parser.add_argument('-e',
                        '--evaluate',
                        action='store_true',
                        help='evaluation mode')
    parser.add_argument('--plot_ponder',
                        action='store_true',
                        help='plot ponder cost')
    parser.add_argument('--pretrained',
                        action='store_true',
                        help='start from pretrained model')
    parser.add_argument('--workers',
                        default=8,
                        type=int,
                        help='number of dataloader workers')
    args = parser.parse_args()
    print('Args:', args)

    res = 224

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(res),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize(int(res / 0.875)),
        transforms.CenterCrop(res),
        transforms.ToTensor(),
        normalize,
    ])

    ## DATA
    trainset = dataloader.imagenet.IN1K(root=args.dataset_root,
                                        split='train',
                                        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False)

    valset = dataloader.imagenet.IN1K(root=args.dataset_root,
                                      split='val',
                                      transform=transform_val)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.batchsize,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=False)

    ## MODEL
    net_module = models.__dict__[args.model]
    model = net_module(sparse=args.budget >= 0,
                       pretrained=args.pretrained).to(device=device)

    ## CRITERION
    class Loss(nn.Module):
        def __init__(self):
            super(Loss, self).__init__()
            self.task_loss = nn.CrossEntropyLoss().to(device=device)
            self.sparsity_loss = dynconv.SparsityCriterion(
                args.budget, args.epochs) if args.budget >= 0 else None

        def forward(self, output, target, meta):
            l = self.task_loss(output, target)
            logger.add('loss_task', l.item())
            if self.sparsity_loss is not None:
                l += 10 * self.sparsity_loss(meta)
            return l

    criterion = Loss()

    ## OPTIMIZER
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    ## CHECKPOINT
    start_epoch = -1
    best_prec1 = 0

    if not args.evaluate and len(args.save_dir) > 0:
        if not os.path.exists(os.path.join(args.save_dir)):
            os.makedirs(os.path.join(args.save_dir))

    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume)
            # print('check', checkpoint)
            start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']}, best prec1 {checkpoint['best_prec1']})"
            )
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
            if args.evaluate:
                raise ValueError(msg)
            else:
                print(msg)

    try:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_decay, last_epoch=start_epoch)
    except:
        print('Warning: Could not reload learning rate scheduler')
    start_epoch += 1

    ## Count number of params
    print("* Number of trainable parameters:", utils.count_parameters(model))

    ## EVALUATION
    if args.evaluate:
        # evaluate on validation set
        print(f"########## Evaluation ##########")
        prec1 = validate(args, val_loader, model, criterion, start_epoch)
        return

    ## TRAINING
    for epoch in range(start_epoch, args.epochs):
        print(f"########## Epoch {epoch} ##########")

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(args, train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        utils.save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'best_prec1': best_prec1,
            },
            folder=args.save_dir,
            is_best=is_best)

        print(f" * Best prec1: {best_prec1}")
Example No. 8
    for up_transform in tmp['up_transform']
]

meshdata = MeshData(args.data_fp,
                    template_fp,
                    split=args.split,
                    test_exp=args.test_exp)
train_loader = DataLoader(meshdata.train_dataset,
                          batch_size=args.batch_size,
                          shuffle=True)
test_loader = DataLoader(meshdata.test_dataset, batch_size=args.batch_size)

# generate/load transform matrices

model = AE(args.in_channels, args.out_channels, args.latent_channels,
           spiral_indices_list, down_transform_list,
           up_transform_list).to(device)
print('Number of parameters: {}'.format(utils.count_parameters(model)))
print(model)

optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            args.decay_step,
                                            gamma=args.lr_decay)

run(model, train_loader, test_loader, args.epochs, optimizer, scheduler,
    writer, device)
eval_error(model, test_loader, device, meshdata, args.out_dir)
Example No. 9
def main(opt):
	#========= Loading Dataset =========#
	data = torch.load(opt.data)
	vocab_size = len(data['dict']['tgt'])
	
	global_labels = None
	for i in range(len(data['train']['src'])):
		labels = torch.tensor(data['train']['tgt'][i]).unsqueeze(0)
		labels = utils.get_gold_binary_full(labels,vocab_size)
		if global_labels is None:
			global_labels = labels
		else:
			global_labels+=labels

	for i in range(len(data['valid']['src'])):
		labels = torch.tensor(data['valid']['tgt'][i]).unsqueeze(0)
		labels = utils.get_gold_binary_full(labels,vocab_size)
		global_labels+=labels
		
	for i in range(len(data['test']['src'])):
		labels = torch.tensor(data['test']['tgt'][i]).unsqueeze(0)
		labels = utils.get_gold_binary_full(labels,vocab_size)
		global_labels+=labels

	global_labels = global_labels[0][0:-4]

	ranked_labels,ranked_idx = torch.sort(global_labels)

	indices = ranked_idx[2:24].long()
	label_count = ranked_labels[2:24]


	train_data,valid_data,test_data,label_adj_matrix,opt = process_data(data,opt)
	print(opt)

	#========= Preparing Model =========#
	model = LAMP(
        opt,
		opt.src_vocab_size,
		opt.tgt_vocab_size,
		opt.max_token_seq_len_e,
		opt.max_token_seq_len_d,
		proj_share_weight=opt.proj_share_weight,
		embs_share_weight=opt.embs_share_weight,
		d_k=opt.d_k,
		d_v=opt.d_v,
		d_model=opt.d_model,
		d_word_vec=opt.d_word_vec,
		d_inner_hid=opt.d_inner_hid,
		n_layers_enc=opt.n_layers_enc,
		n_layers_dec=opt.n_layers_dec,
		n_head=opt.n_head,
		n_head2=opt.n_head2,
		dropout=opt.dropout,
		dec_dropout=opt.dec_dropout,
		dec_dropout2=opt.dec_dropout2,
		encoder=opt.encoder,
		decoder=opt.decoder,
		enc_transform=opt.enc_transform,
		onehot=opt.onehot,
		no_enc_pos_embedding=opt.no_enc_pos_embedding,
		no_dec_self_att=opt.no_dec_self_att,
		loss=opt.loss,
		label_adj_matrix=label_adj_matrix,
		attn_type=opt.attn_type,
		label_mask=opt.label_mask,
		matching_mlp=opt.matching_mlp,
		graph_conv=opt.graph_conv,
		int_preds=opt.int_preds)

	print(model)
	print(opt.model_name)


	opt.total_num_parameters = int(utils.count_parameters(model))

	if opt.load_emb:
		model = utils.load_embeddings(model,'../../Data/word_embedding_dict.pth')
 
	optimizer = torch.optim.Adam(model.get_trainable_parameters(),betas=(0.9, 0.98),lr=opt.lr)
	scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt.lr_step_size, gamma=opt.lr_decay, last_epoch=-1)

	adv_optimizer = None
	
	crit = utils.get_criterion(opt)

	if torch.cuda.device_count() > 1 and opt.multi_gpu:
		print("Using", torch.cuda.device_count(), "GPUs!")
		model = nn.DataParallel(model)

	if torch.cuda.is_available() and opt.cuda:
		model = model.cuda()
	
		crit = crit.cuda()
		if opt.gpu_id != -1:
			torch.cuda.set_device(opt.gpu_id)

	if opt.load_pretrained:		
		checkpoint = torch.load(opt.model_name+'/model.chkpt')
		model.load_state_dict(checkpoint['model'])

	try:
		run_model(model,train_data,valid_data,test_data,crit,optimizer, adv_optimizer,scheduler,opt,data['dict'])
	except KeyboardInterrupt:
		print('-' * 89+'\nManual Exit')
		exit()
Example No. 10
def main():
    parser = argparse.ArgumentParser(
        description='PyTorch CIFAR10 Training with sparse masks')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--lr_decay',
                        default=[150, 250],
                        nargs='+',
                        type=int,
                        help='learning rate decay epochs')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('--batchsize',
                        default=256,
                        type=int,
                        help='batch size')
    parser.add_argument('--epochs',
                        default=350,
                        type=int,
                        help='number of epochs')
    parser.add_argument('--model',
                        type=str,
                        default='resnet32',
                        help='network model name')

    # parser.add_argument('--resnet_n', default=5, type=int, help='number of layers per resnet stage (5 for Resnet-32)')
    parser.add_argument(
        '--budget',
        default=-1,
        type=float,
        help='computational budget (between 0 and 1) (-1 for no sparsity)')
    parser.add_argument('-s',
                        '--save_dir',
                        type=str,
                        default='',
                        help='directory to save model')
    parser.add_argument('-r',
                        '--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e',
                        '--evaluate',
                        action='store_true',
                        help='evaluation mode')
    parser.add_argument('--plot_ponder',
                        action='store_true',
                        help='plot ponder cost')
    parser.add_argument('--workers',
                        default=8,
                        type=int,
                        help='number of dataloader workers')
    parser.add_argument('--pretrained',
                        action='store_true',
                        help='initialize with pretrained model')
    args = parser.parse_args()
    print('Args:', args)

    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    ## DATA
    trainset = datasets.CIFAR10(root='../data',
                                train=True,
                                download=True,
                                transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.batchsize,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=False)

    valset = datasets.CIFAR10(root='../data',
                              train=False,
                              download=True,
                              transform=transform_test)
    val_loader = torch.utils.data.DataLoader(valset,
                                             batch_size=args.batchsize,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=False)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    ## MODEL
    net_module = models.__dict__[args.model]
    model = net_module(sparse=args.budget >= 0,
                       pretrained=args.pretrained).to(device=device)

    ## CRITERION
    class Loss(nn.Module):
        def __init__(self):
            super(Loss, self).__init__()
            self.task_loss = nn.CrossEntropyLoss().to(device=device)
            self.sparsity_loss = dynconv.SparsityCriterion(
                args.budget, args.epochs) if args.budget >= 0 else None

        def forward(self, output, target, meta):
            l = self.task_loss(output, target)
            logger.add('loss_task', l.item())
            if self.sparsity_loss is not None:
                l += 10 * self.sparsity_loss(meta)
            return l

    criterion = Loss()

    ## OPTIMIZER
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    ## CHECKPOINT
    start_epoch = -1
    best_prec1 = 0

    if not args.evaluate and len(args.save_dir) > 0:
        if not os.path.exists(os.path.join(args.save_dir)):
            os.makedirs(os.path.join(args.save_dir))

    if args.resume:
        resume_path = args.resume
        if not os.path.isfile(resume_path):
            resume_path = os.path.join(resume_path, 'checkpoint.pth')
        if os.path.isfile(resume_path):
            print(f"=> loading checkpoint '{resume_path}'")
            checkpoint = torch.load(resume_path)
            start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(
                f"=> loaded checkpoint '{resume_path}' (epoch {checkpoint['epoch']}, best prec1 {checkpoint['best_prec1']})"
            )
        else:
            msg = "=> no checkpoint found at '{}'".format(resume_path)
            if args.evaluate:
                raise ValueError(msg)
            else:
                print(msg)

    try:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=args.lr_decay, last_epoch=start_epoch)
    except:
        print('Warning: Could not reload learning rate scheduler')
    start_epoch += 1

    ## Count number of params
    print("* Number of trainable parameters:", utils.count_parameters(model))

    ## EVALUATION
    if args.evaluate:
        print(f"########## Evaluation ##########")
        prec1 = validate(args, val_loader, model, criterion, start_epoch)
        return

    ## TRAINING
    for epoch in range(start_epoch, args.epochs):
        print(f"########## Epoch {epoch} ##########")

        # train for one epoch
        print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
        train(args, train_loader, model, criterion, optimizer, epoch)
        lr_scheduler.step()

        # evaluate on validation set
        prec1 = validate(args, val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        utils.save_checkpoint(
            {
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch + 1,
                'best_prec1': best_prec1,
            },
            folder=args.save_dir,
            is_best=is_best)

        print(f" * Best prec1: {best_prec1}")
Example No. 11
        exec(
            "a_%s = torch.tensor(1, dtype=torch.float32, requires_grad=True)" %
            num)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    optimizer_fe = torch.optim.Adam(model_fe.parameters(), lr=args.lr)
    optimizer_para = torch.optim.Adam([
        a_00, a_10, a_20, a_30, a_01, a_11, a_21, a_31, a_02, a_22, a_12, a_32,
        a_b01, a_b11, a_b21, a_b31, a_b02, a_b22, a_b12, a_b32
    ],
                                      lr=args.lr)
    optimizer_x_t = torch.optim.Adam(
        [rand_x_index, rand_x_index2, rand_t_index], lr=args.lr)

    print(('Number of parameters: {}'.format(count_parameters(model))))

    start_time = time.time()

    for epoch in range(args.nepochs):
        # -------------------------------- Train Dataset --------------------------------
        for itr, (data) in enumerate(train_loader):
            # update x_t pairs
            rand_x_t_pairs = torch.cat(
                [rand_x_index, rand_x_index2, rand_t_index], dim=0).to(device)

            # learning rate scheduling
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_fn(itr + epoch * batches_per_epoch)

            x, y = data
Example No. 12
def start(window, args, env):
    alg = DeepTamer(window, args, env)
    print("Number of trainable parameters:", utils.count_parameters(alg.reward_net))
    alg.train()
    env.close()
Example No. 13
def main():
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    test_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                         dataset_name=args.dataset_name,
                                         mode=args.mode,
                                         save_filename=True,
                                         transform=test_transform)
    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True,
                             drop_last=False)

    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    # print(aanet)

    if os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    # Save parameters
    num_params = utils.count_parameters(aanet)
    print('=> Number of trainable parameters: %d' % num_params)

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    inference_time = 0
    num_imgs = 0

    num_samples = len(test_loader)
    print('=> %d samples found in the test set' % num_samples)

    for i, sample in enumerate(test_loader):
        if args.count_time and i == args.num_images:  # testing time only
            break

        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_samples))

        left = sample['left'].to(device)  # [B, 3, H, W]
        right = sample['right'].to(device)

        # Pad
        ori_height, ori_width = left.size()[2:]
        if ori_height < args.img_height or ori_width < args.img_width:
            top_pad = args.img_height - ori_height
            right_pad = args.img_width - ori_width

            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        # Warmup
        if i == 0 and args.count_time:
            with torch.no_grad():
                for _ in range(10):
                    aanet(left, right)

        num_imgs += left.size(0)

        with torch.no_grad():
            time_start = time.perf_counter()
            pred_disp = aanet(left, right)[-1]  # [B, H, W]
            inference_time += time.perf_counter() - time_start

        if pred_disp.size(-1) < left.size(-1):
            pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
            pred_disp = F.interpolate(
                pred_disp, (left.size(-2), left.size(-1)),
                mode='bilinear') * (left.size(-1) / pred_disp.size(-1))
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop
        if ori_height < args.img_height or ori_width < args.img_width:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        for b in range(pred_disp.size(0)):
            disp = pred_disp[b].detach().cpu().numpy()  # [H, W]
            save_name = sample['left_name'][b]
            save_name = os.path.join(args.output_dir, save_name)
            utils.check_path(os.path.dirname(save_name))
            if not args.count_time:
                if args.save_type == 'pfm':
                    if args.visualize:
                        skimage.io.imsave(save_name,
                                          (disp * 256.).astype(np.uint16))

                    save_name = save_name[:-3] + 'pfm'
                    write_pfm(save_name, disp)
                elif args.save_type == 'npy':
                    save_name = save_name[:-3] + 'npy'
                    np.save(save_name, disp)
                else:
                    skimage.io.imsave(save_name,
                                      (disp * 256.).astype(np.uint16))

    print('=> Mean inference time for %d images: %.3fs' %
          (num_imgs, inference_time / num_imgs))
Example No. 14
def train():

    use_gpu = cfg.MODEL.DEVICE == "cuda"
    # 1、make dataloader
    train_loader, val_loader, test_loader, num_query, num_class = darts_make_data_loader(
        cfg)
    # print(num_query, num_class)

    # 2、make model
    model = CNetwork(num_class, cfg)
    # tensor = torch.randn(2, 3, 256, 128)
    # res = model(tensor)
    # print(res[0].size()) [2, 751]

    # 3、make optimizer
    optimizer = make_optimizer(cfg, model)
    # make architecture optimizer
    arch_optimizer = torch.optim.Adam(
        model._arch_parameters(),
        lr=cfg.SOLVER.ARCH_LR,
        betas=(0.5, 0.999),
        weight_decay=cfg.SOLVER.ARCH_WEIGHT_DECAY)

    # 4、make lr scheduler
    lr_scheduler = make_lr_scheduler(cfg, optimizer)
    # make lr scheduler
    arch_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        arch_optimizer, [80, 160], 0.1)

    # 5、make loss
    loss_fn = darts_make_loss(cfg)

    # get parameters
    device = cfg.MODEL.DEVICE
    use_gpu = device == "cuda"
    pretrained = cfg.MODEL.PRETRAINED != ""

    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CKPT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.DIRS
    ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

    epochs = cfg.SOLVER.MAX_EPOCHS
    batch_size = cfg.SOLVER.BATCH_SIZE
    grad_clip = cfg.SOLVER.GRAD_CLIP

    batch_num = len(train_loader)
    log_iters = batch_num // log_period

    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx
    # save the result for analyze
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.DIRS + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10',
        'loss', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss', 'acc', 'mAP', 'r1',
        'r5', 'r10', 'loss'
    ]
    sheet.append(titles)
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger("CNet_Search.train")
    size = count_parameters(model)
    values.append(format(size, '.2f'))
    values.append(model.final_planes)

    logger.info("the param number of the model is {:.2f} M".format(size))

    logger.info("Starting Search CNetwork")

    best_mAP, best_r1 = 0., 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()

    if use_gpu:
        model = model.to(device)

    if pretrained:
        logger.info("load self-pretrained checkpoint to init")
        model.load_pretrained_model(cfg.MODEL.PRETRAINED)
    else:
        logger.info("use kaiming init to init the model")
        model.kaiming_init_()

    for epoch in range(epochs):

        lr_scheduler.step()
        lr = lr_scheduler.get_lr()[0]
        # architect lr.step
        arch_lr_scheduler.step()

        # if save epoch_num k, then run k+1 epoch next
        if pretrained and epoch < model.start_epoch:
            continue

        # print(epoch)
        # exit(1)
        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()

        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch
            val_imgs, val_labels = next(iter(val_loader))

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)
                val_imgs = val_imgs.to(device)
                val_labels = val_labels.to(device)

            # 1、 update the weights
            optimizer.zero_grad()
            res = model(imgs)

            # loss = loss_fn(scores, feats, labels)
            loss, acc = compute_loss_acc(res, labels, loss_fn)
            loss.backward()

            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            # 2、update the alpha
            arch_optimizer.zero_grad()
            res = model(val_imgs)

            val_loss, val_acc = compute_loss_acc(res, val_labels, loss_fn)
            val_loss.backward()
            arch_optimizer.step()

            # compute the acc
            # acc = (scores.max(1)[1] == labels).float().mean()

            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            # log info
            if (i + 1) % log_iters == 0:
                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # test the model
        if (epoch + 1) % eval_period == 0:

            model.eval()
            metrics = R1_mAP(num_query, use_gpu=use_gpu)

            with torch.no_grad():
                for vi, batch in enumerate(test_loader):
                    # break
                    # print(len(batch))
                    imgs, labels, camids = batch
                    if use_gpu:
                        imgs = imgs.to(device)

                    feats = model(imgs)
                    metrics.update((feats, labels, camids))

                #compute cmc and mAP
                cmc, mAP = metrics.compute()
                logger.info("validation results at epoch {}".format(epoch + 1))
                logger.info("mAP: {:.2%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(
                        r, cmc[r - 1]))

                # determine whether current model is the best
                if mAP > best_mAP:
                    is_best = True
                    best_mAP = mAP
                    logger.info("Get a new best mAP")
                if cmc[0] > best_r1:
                    is_best = True
                    best_r1 = cmc[0]
                    logger.info("Get a new best r1")

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
                    change = [format(v * 100, '.2f') for v in val]
                    change.append(format(avg_loss.avg, '.3f'))
                    values.extend(change)

        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:
            torch.save(model.state_dict(),
                       ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            model._parse_genotype(file=ckpt_save_path +
                                  "genotype_{}.json".format(epoch + 1))
            logger.info("checkpoint {} was saved".format(epoch + 1))

            if is_best:
                torch.save(model.state_dict(),
                           ckpt_save_path + "best_ckpt.pth")
                model._parse_genotype(file=ckpt_save_path +
                                      "best_genotype.json")
                logger.info("best_checkpoint was saved")
                is_best = False
        # exit(1)

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)

    logger.info("Ending Search CNetwork")
Example No. 15
def train():

	# 1、make dataloader
	# prepare the train/val img_info lists (each element is a tuple)
	train_loader, val_loader, num_query, num_class = make_data_loader(cfg)
	
	# 2、make model
	model = build_model(cfg, num_class)

	# 3、 make optimizer
	optimizer = make_optimizer(cfg, model)

	# 4、 make lr_scheduler
	scheduler = make_lr_scheduler(cfg, optimizer)

	# 5、make loss 
	loss_fn = make_loss(cfg, num_class)

	# get parameters 
	device = cfg.MODEL.DEVICE
	use_gpu = device == "cuda"
	pretrained = cfg.MODEL.PRETRAIN_PATH != ""
	parallel = cfg.MODEL.PARALLEL

	log_period = cfg.OUTPUT.LOG_PERIOD
	ckpt_period = cfg.OUTPUT.CKPT_PERIOD
	eval_period = cfg.OUTPUT.EVAL_PERIOD
	output_dir = cfg.OUTPUT.DIRS
	ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS
	
	epochs = cfg.SOLVER.MAX_EPOCHS
	batch_size = cfg.SOLVER.BATCH_SIZE
	grad_clip = cfg.SOLVER.GRAD_CLIP

	batch_num = len(train_loader)
	log_iters = batch_num // log_period 

	if not os.path.exists(ckpt_save_path):
		os.makedirs(ckpt_save_path)

	# create *_result.xlsx
	# save the result for analyze
	name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
	result_path = cfg.OUTPUT.DIRS + name

	wb = xl.Workbook()
	sheet = wb.worksheets[0]
	titles = ['size/M','speed/ms','final_planes', 'acc', 'mAP', 'r1', 'r5', 'r10', 'loss',
			  'acc', 'mAP', 'r1', 'r5', 'r10', 'loss','acc', 'mAP', 'r1', 'r5', 'r10', 'loss']
	sheet.append(titles)
	check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
	values = []

	logger = logging.getLogger("CDNet.train")
	size = count_parameters(model)
	values.append(format(size, '.2f'))
	values.append(model.final_planes)
	
	logger.info("the param number of the model is {:.2f} M".format(size))
	infer_size = infer_count_parameters(model)
	logger.info("the infer param number of the model is {:.2f}M".format(infer_size))

	shape = [1, 3]
	shape.extend(cfg.DATA.IMAGE_SIZE)
	
	# if cfg.MODEL.NAME == 'cdnet' :
	# 	infer_model = CDNetwork(num_class, cfg)
	# elif cfg.MODEL.NAME == 'cnet':
	# 	infer_model = CNetwork(num_class, cfg)
	# else:
	# 	infer_model = model 

	# for scaling experiment
	flops, _ = get_model_infos(model, shape)
	logger.info("the total flops number of the model is {:.2f} M".format(flops))
	
	logger.info("Starting Training CDNetwork")
	
	best_mAP, best_r1 = 0., 0.
	is_best = False
	avg_loss, avg_acc = RunningAverageMeter(),RunningAverageMeter()
	avg_time, global_avg_time = AverageMeter(), AverageMeter()

	if parallel:
		model = nn.DataParallel(model)
		
	if use_gpu:
		model = model.to(device)

	for epoch in range(epochs):
		
		scheduler.step()
		lr = scheduler.get_lr()[0]
		# if save epoch_num k, then run k+1 epoch next
		if pretrained and epoch < model.start_epoch:
			continue

		# rest the record
		model.train()
		avg_loss.reset()
		avg_acc.reset()
		avg_time.reset()

		for i, batch in enumerate(train_loader):

			t0 = time.time()
			imgs, labels = batch 

			if use_gpu:
				imgs = imgs.to(device)
				labels = labels.to(device)

			res = model(imgs)
		
			loss, acc = compute_loss_acc(res, labels, loss_fn)
			loss.backward()

			if grad_clip != 0:
				nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

			optimizer.step()
			optimizer.zero_grad()

			t1 = time.time()
			avg_time.update((t1 - t0) / batch_size)
			avg_loss.update(loss)
			avg_acc.update(acc)

			# log info
			if (i+1) % log_iters == 0:
				logger.info("epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".format(
					epoch+1, i+1, batch_num, avg_loss.avg, avg_acc.avg))

		logger.info("end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".format(epoch+1, epochs, lr, avg_time.avg * 1000))
		global_avg_time.update(avg_time.avg)

		# test the model
		if (epoch + 1) % eval_period == 0 or (epoch + 1) in check_epochs:

			model.eval()
			metrics = R1_mAP(num_query, use_gpu = use_gpu)

			with torch.no_grad():
				for vi, batch in enumerate(val_loader):
					
					imgs, labels, camids = batch
					if use_gpu:
						imgs = imgs.to(device)

					feats = model(imgs)
					metrics.update((feats, labels, camids))

				#compute cmc and mAP
				cmc, mAP = metrics.compute()
				logger.info("validation results at epoch {}".format(epoch + 1))
				logger.info("mAP: {:.2%}".format(mAP))
				for r in [1,5,10]:
					logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r-1]))

				# determine whether current model is the best
				if mAP > best_mAP:
					is_best = True
					best_mAP = mAP
					logger.info("Get a new best mAP")
				if cmc[0] > best_r1:
					is_best = True
					best_r1 = cmc[0]
					logger.info("Get a new best r1")

				# add the result to sheet
				if (epoch + 1) in check_epochs:
					val = [avg_acc.avg, mAP, cmc[0], cmc[4], cmc[9]]
					change = [format(v * 100, '.2f') for v in val]
					change.append(format(avg_loss.avg, '.3f'))
					values.extend(change)
					
		# whether to save the model
		if (epoch + 1) % ckpt_period == 0 or is_best:
			torch.save(model.state_dict(), ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
			logger.info("checkpoint {} was saved".format(epoch + 1))

			if is_best:
				torch.save(model.state_dict(), ckpt_save_path + "best_ckpt.pth")
				logger.info("best_checkpoint was saved")
				is_best = False
		

	values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
	values.append(format(infer_size, '.2f'))
	sheet.append(values)
	wb.save(result_path)
	logger.info("best_mAP:{:.2%}, best_r1:{:.2%}".format(best_mAP, best_r1))
	logger.info("Ending training CDNetwork")
Exemplo n.º 16
0
    def __init__(self,
                 model_type,
                 elem_prop,
                 mat_prop,
                 input_dims,
                 hidden_dims,
                 output_dims,
                 out_dir,
                 edm=False,
                 batch_size=1,
                 random_seed=None,
                 save_network_info=True):
        super(NeuralNetWrapper, self).__init__()

        self.model_type = model_type
        self.elem_prop = elem_prop
        self.mat_prop = mat_prop

        self.input_dims = input_dims
        self.hidden_dims = hidden_dims
        self.output_dims = output_dims

        self.out_dir = out_dir
        self.edm = edm
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_type = torch.float

        self.save_network_info = save_network_info

        # Default to using the last GPU available
        self.CUDA_count = torch.cuda.device_count()
        self.compute_device = get_compute_device()

        print(f'Creating Model of type {self.model_type}')
        if self.model_type == 'CrabNet':
            self.model = CrabNet(self.compute_device,
                                 input_dims=self.input_dims,
                                 d_model=201,
                                 nhead=3,
                                 num_layers=3,
                                 dim_feedforward=64,
                                 dropout=0.1,
                                 edm=self.edm)
        elif self.model_type == 'DenseNet':
            self.model = DenseNet(self.compute_device,
                                  input_dims=self.input_dims,
                                  hidden_dims=self.hidden_dims,
                                  output_dims=self.output_dims,
                                  dropout=0.1,
                                  edm=self.edm)

        self.model.to(self.compute_device,
                      dtype=self.data_type,
                      non_blocking=True)

        self.num_network_params = count_parameters(self.model)
        print(f'number of network params: {self.num_network_params}')

        # self.criterion = nn.MSELoss()
        self.criterion = nn.L1Loss()
        # self.optim_lr = 1e-3
        # self.optimizer = optim.Adam(self.model.parameters(), lr=self.optim_lr)
        self.optim_lr = 5e-4
        # self.optimizer = optim.AdamW(self.model.parameters(), lr=1e-3)
        self.optimizer = optim.AdamW(self.model.parameters(),
                                     lr=self.optim_lr,
                                     weight_decay=1e-6)

        # Logging
        self.start_time = datetime.now()
        self.start_datetime = self.start_time.strftime('%Y-%m-%d-%H%M%S.%f')
        self.log_filename = (f'{self.start_datetime}-{self.model_type}-'
                             f'{self.elem_prop}-{self.mat_prop}.log')
        self.sub_dir = (f'{self.start_datetime}-{xstrh(self.random_seed)}'
                        f'{self.model_type}-'
                        f'{self.elem_prop}-{self.mat_prop}')
        self.log_dir = os.path.join(self.out_dir, self.sub_dir)

        if 'CUSTOM' in self.mat_prop:
            os.makedirs(self.log_dir, exist_ok=True)

        if self.save_network_info:
            os.makedirs(self.log_dir, exist_ok=True)
            self.log_file = os.path.join(self.out_dir, self.sub_dir,
                                         self.log_filename)

            print(56 * '*')
            print(f'creating and writing to log file {self.log_file}')
            print(56 * '*')
            with open(self.log_file, 'a') as f:
                try:
                    f.write('Start time: ')
                    f.write(f'{self.start_datetime}\n')
                    f.write(f'random seed: {self.random_seed}\n')
                    f.write('Model type: ')
                    f.write(f'{self.model_type}\n')
                    f.write('Material property: ')
                    f.write(f'{self.mat_prop}\n')
                    f.write('Element property: ')
                    f.write(f'{self.elem_prop}\n')
                    f.write(f'EDM input: {self.edm}\n')
                    f.write('Network architecture:\n')
                    f.write(f'{self.model}\n')
                    f.write(f'Number of params: ')
                    f.write(f'{self.num_network_params}\n')
                    f.write(f'CUDA count: {self.CUDA_count}\n')
                    f.write(f'Compute device: {self.compute_device}\n')
                    f.write('Criterion and Optimizer:\n')
                    f.write(f'{self.criterion}\n')
                    f.write(f'{self.optimizer}\n')
                    f.write(56 * '*' + '\n')
                except:
                    pass
Exemplo n.º 17
0
def train():

    # 1、make dataloader
    train_loader, val_loader, num_class = imagenet_make_data_loader(cfg)

    # 2、make model
    model = build_model(cfg, num_class)

    # 3、 make optimizer
    optimizer = make_optimizer(cfg, model)

    # 4、 make lr_scheduler
    scheduler = make_lr_scheduler(cfg, optimizer)

    # 5、make loss: default use softmax loss
    loss_fn = make_loss(cfg, num_class)

    # get parameters
    device = cfg.MODEL.DEVICE
    use_gpu = device == "cuda"
    pretrained = cfg.MODEL.PRETRAIN_PATH != ""
    parallel = cfg.MODEL.PARALLEL

    log_period = cfg.OUTPUT.LOG_PERIOD
    ckpt_period = cfg.OUTPUT.CKPT_PERIOD
    eval_period = cfg.OUTPUT.EVAL_PERIOD
    output_dir = cfg.OUTPUT.DIRS
    ckpt_save_path = output_dir + cfg.OUTPUT.CKPT_DIRS

    epochs = cfg.SOLVER.MAX_EPOCHS
    batch_size = cfg.SOLVER.BATCH_SIZE
    grad_clip = cfg.SOLVER.GRAD_CLIP

    batch_num = len(train_loader)
    log_iters = batch_num // log_period

    if not os.path.exists(ckpt_save_path):
        os.makedirs(ckpt_save_path)

    # create *_result.xlsx
    # save the results for later analysis
    name = (cfg.OUTPUT.LOG_NAME).split(".")[0] + ".xlsx"
    result_path = cfg.OUTPUT.DIRS + name

    wb = xl.Workbook()
    sheet = wb.worksheets[0]
    titles = [
        'size/M', 'speed/ms', 'final_planes', 'acc', 'loss', 'acc', 'loss',
        'acc', 'loss'
    ]
    sheet.append(titles)
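    # epochs at which the current results are also appended to the spreadsheet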
    check_epochs = [40, 80, 120, 160, 200, 240, 280, 320, 360, epochs]
    values = []

    logger = logging.getLogger("CDNet.train")
    size = count_parameters(model)
    values.append(format(size, '.2f'))
    values.append(model.final_planes)
    logger.info("the total number of parameters is: {:.2f} M".format(size))

    logger.info("Starting Training CDNetwork")
    best_acc = 0.
    is_best = False
    avg_loss, avg_acc = RunningAverageMeter(), RunningAverageMeter()
    avg_time, global_avg_time = AverageMeter(), AverageMeter()

    if pretrained:
        start_epoch = model.start_epoch

    if parallel:
        model = nn.DataParallel(model)

    if use_gpu:
        model = model.to(device)

    for epoch in range(epochs):

        scheduler.step()
        lr = scheduler.get_lr()[0]

        # if the checkpoint was saved at epoch k, resume training from epoch k+1
        if pretrained and epoch < start_epoch:
            continue

        # reset the running meters
        model.train()
        avg_loss.reset()
        avg_acc.reset()
        avg_time.reset()

        for i, batch in enumerate(train_loader):

            t0 = time.time()
            imgs, labels = batch

            if use_gpu:
                imgs = imgs.to(device)
                labels = labels.to(device)

            res = model(imgs)

            loss, acc = compute_loss_acc(res, labels, loss_fn)
            loss.backward()

            if grad_clip != 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            t1 = time.time()
            avg_time.update((t1 - t0) / batch_size)
            avg_loss.update(loss)
            avg_acc.update(acc)

            # log info
            if (i + 1) % log_iters == 0:
                logger.info(
                    "epoch {}: {}/{} with loss is {:.5f} and acc is {:.3f}".
                    format(epoch + 1, i + 1, batch_num, avg_loss.avg,
                           avg_acc.avg))

        logger.info(
            "end epochs {}/{} with lr: {:.5f} and avg_time is: {:.3f} ms".
            format(epoch + 1, epochs, lr, avg_time.avg * 1000))
        global_avg_time.update(avg_time.avg)

        # test the model
        if (epoch + 1) % eval_period == 0 or (epoch + 1) in check_epochs:

            model.eval()
            logger.info("begin eval the model")
            val_acc = RunningAverageMeter()
            with torch.no_grad():

                for vi, batch in enumerate(val_loader):

                    imgs, labels = batch

                    if use_gpu:
                        imgs = imgs.to(device)
                        labels = labels.to(device)

                    res = model(imgs)
                    # acc = (scores.max(1)[1] == labels).float().mean()
                    _, acc = compute_loss_acc(res, labels)
                    val_acc.update(acc)

                logger.info("validation results at epoch:{}".format(epoch + 1))
                logger.info("acc:{:.2%}".format(val_acc.avg))

                # determine whether current model is the best
                if val_acc.avg > best_acc:
                    logger.info("get a new best acc")
                    best_acc = val_acc.avg
                    is_best = True

                # add the result to sheet
                if (epoch + 1) in check_epochs:
                    val = [
                        format(val_acc.avg * 100, '.2f'),
                        format(avg_loss.avg, '.3f')
                    ]
                    values.extend(val)
        # whether to save the model
        if (epoch + 1) % ckpt_period == 0 or is_best:
            if parallel:
                torch.save(
                    model.module.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            else:
                torch.save(
                    model.state_dict(),
                    ckpt_save_path + "checkpoint_{}.pth".format(epoch + 1))
            logger.info("checkpoint {} was saved".format(epoch + 1))

            if is_best:
                if parallel:
                    torch.save(model.module.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                else:
                    torch.save(model.state_dict(),
                               ckpt_save_path + "best_ckpt.pth")
                logger.info("best_checkpoint was saved")
                is_best = False

    values.insert(1, format(global_avg_time.avg * 1000, '.2f'))
    sheet.append(values)
    wb.save(result_path)
    logger.info("best_acc:{:.2%}".format(best_acc))
    logger.info("Ending training CDNetwork on imagenet")
Exemplo n.º 18
0
def test(args):

    if device != 'cpu':
        cudnn.benchmark = True
    checkpoints_dir = os.path.join(args.save_dir, args.model_name)
    # make sure these arguments are kept from commandline and not from loaded args
    vars_to_replace = [
        'batch_size', 'eval_split', 'imsize', 'root', 'save_dir'
    ]
    store_dict = {}
    for var in vars_to_replace:
        store_dict[var] = getattr(args, var)
    args, model_dict, _ = load_checkpoint(checkpoints_dir, 'best', map_loc,
                                          store_dict)
    for var in vars_to_replace:
        setattr(args, var, store_dict[var])

    loader, dataset = get_loader(args.root,
                                 args.batch_size,
                                 args.resize,
                                 args.imsize,
                                 augment=False,
                                 split=args.eval_split,
                                 mode='test',
                                 drop_last=False)
    print("Extracting features for %d samples from the %s set..." %
          (len(dataset), args.eval_split))
    vocab_size = len(dataset.get_vocab())
    model = get_model(args, vocab_size)

    print("recipe encoder", count_parameters(model.text_encoder))
    print("image encoder", count_parameters(model.image_encoder))

    model.load_state_dict(model_dict, strict=False)

    if device != 'cpu' and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    model = model.to(device)
    model.eval()

    total_step = len(loader)
    loader = iter(loader)
    print("Loaded model from %s ..." % (checkpoints_dir))
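    # accumulate the two feature outputs (f1, f2) and their sample ids across batches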
    all_f1, all_f2 = None, None
    allids = []

    for _ in tqdm(range(total_step)):

        img, title, ingrs, instrs, ids = next(loader)

        img = img.to(device)
        title = title.to(device)
        ingrs = ingrs.to(device)
        instrs = instrs.to(device)
        with torch.no_grad():
            out = model(img, title, ingrs, instrs)
            f1, f2, _ = out
        allids.extend(ids)
        if all_f1 is not None:
            all_f1 = np.vstack((all_f1, f1.cpu().detach().numpy()))
            all_f2 = np.vstack((all_f2, f2.cpu().detach().numpy()))
        else:
            all_f1 = f1.cpu().detach().numpy()
            all_f2 = f2.cpu().detach().numpy()

    print("Done.")

    file_to_save = os.path.join(checkpoints_dir,
                                'feats_' + args.eval_split + '.pkl')
    print(np.shape(all_f1))

    with open(file_to_save, 'wb') as f:
        pickle.dump(all_f1, f)
        pickle.dump(all_f2, f)
        pickle.dump(allids, f)
    print("Saved features to disk.")
Exemplo n.º 19
0
def train_lstm(args):
    # gpus
    device = torch.device(
        'cuda' if args.cuda and torch.cuda.is_available() else 'cpu')

    # load vocabulary
    annfiles = [os.path.join(args.root_path, pth) for pth in args.annpaths]
    text_proc = build_vocab(annfiles, args.min_freq, args.max_seqlen)
    vocab_size = len(text_proc.vocab)

    # transforms
    sp = spt.Compose([spt.CornerCrop(size=args.imsize), spt.ToTensor()])
    tp = tpt.Compose([
        tpt.TemporalRandomCrop(args.clip_len),
        tpt.LoopPadding(args.clip_len)
    ])

    # dataloading
    train_dset = ActivityNetCaptions_Train(args.root_path,
                                           ann_path='train_fps.json',
                                           sample_duration=args.clip_len,
                                           spatial_transform=sp,
                                           temporal_transform=tp)
    trainloader = DataLoader(train_dset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.n_cpu,
                             drop_last=True,
                             timeout=100)
    max_train_it = int(len(train_dset) / args.batch_size)
    val_dset = ActivityNetCaptions_Val(
        args.root_path,
        ann_path=['val_1_fps.json', 'val_2_fps.json'],
        sample_duration=args.clip_len,
        spatial_transform=sp,
        temporal_transform=tp)
    valloader = DataLoader(val_dset,
                           batch_size=args.batch_size,
                           shuffle=True,
                           num_workers=args.n_cpu,
                           drop_last=True,
                           timeout=100)
    #max_val_it = int(len(val_dset) / args.batch_size)
    max_val_it = 10

    # models
    video_encoder = generate_3dcnn(args)
    caption_gen = generate_rnn(vocab_size, args)
    models = [video_encoder, caption_gen]

    # initialize pretrained embeddings
    if args.emb_init is not None:
        begin = time.time()
        print("initializing embeddings from {}...".format(args.emb_init))
        lookup = get_pretrained_from_txt(args.emb_init)
        first = next(iter(lookup.values()))
        try:
            assert len(first) == args.embedding_size
        except AssertionError:
            print("embedding size not compatible with pretrained embeddings.")
            print(
                "specified size {}, pretrained model includes size {}".format(
                    args.embedding_size, len(first)))
            sys.exit(1)
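        # start from random embeddings and overwrite the rows of tokens present in the pretrained lookup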
        matrix = torch.randn_like(caption_gen.emb.weight)
        for char, vec in lookup.items():
            if char in text_proc.vocab.stoi.keys():
                idx = text_proc.vocab.stoi[char]
                matrix[idx, :] = torch.tensor(vec)
        caption_gen.init_embedding(matrix)
        print("{} | embeddings from {} successfully initialized".format(
            sec2str(time.time() - begin), args.emb_init))

    # move models to device
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1 and args.dataparallel:
        video_encoder = nn.DataParallel(video_encoder)
        caption_gen = nn.DataParallel(caption_gen)
    else:
        n_gpu = 1
    print("using {} gpus...".format(n_gpu))

    # loss function
    criterion = nn.CrossEntropyLoss(ignore_index=text_proc.vocab.stoi['<pad>'])

    # optimizer, scheduler
    params = list(video_encoder.parameters()) + list(caption_gen.parameters())
    optimizer = optim.SGD([{
        "params": video_encoder.parameters(),
        "lr": args.lr_cnn,
        "momentum": args.momentum_cnn
    }, {
        "params": caption_gen.parameters(),
        "lr": args.lr_rnn,
        "momentum": args.momentum_rnn
    }],
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     factor=0.9,
                                                     patience=args.patience,
                                                     verbose=True)

    # count parameters
    num_params = sum(count_parameters(model) for model in models)
    print("# of params in model : {}".format(num_params))

    # joint training loop
    print("start training")
    begin = time.time()
    for ep in range(args.max_epochs):

        # train for epoch
        video_encoder, caption_gen, optimizer = train_epoch(
            trainloader,
            video_encoder,
            caption_gen,
            optimizer,
            criterion,
            device,
            text_proc,
            max_it=max_train_it,
            opt=args)

        # save models
        enc_save_dir = os.path.join(args.model_save_path, "encoder")
        enc_filename = "ep{:04d}.pth".format(ep + 1)
        if not os.path.exists(enc_save_dir):
            os.makedirs(enc_save_dir)
        enc_save_path = os.path.join(enc_save_dir, enc_filename)

        dec_save_dir = os.path.join(args.model_save_path, "decoder")
        dec_filename = "ep{:04d}.pth".format(ep + 1)
        dec_save_path = os.path.join(dec_save_dir, dec_filename)
        if not os.path.exists(dec_save_dir):
            os.makedirs(dec_save_dir)

        if n_gpu > 1 and args.dataparallel:
            torch.save(video_encoder.module.state_dict(), enc_save_path)
            torch.save(caption_gen.module.state_dict(), dec_save_path)
        else:
            torch.save(video_encoder.state_dict(), enc_save_path)
            torch.save(caption_gen.state_dict(), dec_save_path)

        print("saved encoder model to {}".format(enc_save_path))
        print("saved decoder model to {}".format(dec_save_path))

        # evaluate
        print("begin evaluation for epoch {} ...".format(ep + 1))
        nll, ppl, metrics = validate(valloader,
                                     video_encoder,
                                     caption_gen,
                                     criterion,
                                     device,
                                     text_proc,
                                     max_it=max_val_it,
                                     opt=args)
        if metrics is not None:
            scheduler.step(metrics["METEOR"])

        print(
            "training time {}, epoch {:04d}/{:04d} done, validation loss: {:.06f}, perplexity: {:.03f}"
            .format(sec2str(time.time() - begin), ep + 1, args.max_epochs, nll,
                    ppl))

    print("end training")
Exemplo n.º 20
0
                    writer = Writer(args, outdir=out_dir_ii)

                    # Defining model
                    model = SageModelWPathways(in_channels,
                                               n_classes,
                                               args.num_layers,
                                               args.hidden_gcn,
                                               args.hidden_fc,
                                               pathway_edge_index.to(device),
                                               n_cmt,
                                               edge_index,
                                               mode=args.feature_agg_mode,
                                               batchnorm=bn,
                                               do_layers=do_layers).to(device)

                    args.model_parameters = utils.count_parameters(model)
                    #     writer.save_args()
                    print('Number of parameters: {}'.format(
                        args.model_parameters))
                    print(model)

                    # OPTIMIZER
                    optimizer = torch.optim.Adam(model.parameters(),
                                                 lr=lr,
                                                 weight_decay=l2)
                    scheduler = torch.optim.lr_scheduler.StepLR(
                        optimizer, 1, gamma=args.lr_decay)

                    val_loss = train_eval.gcnmodel_run_es(
                        model, train_loader, val_loader, test_loader,
                        sample_weights.to(device), start_epoch, args.epochs,
Exemplo n.º 21
0
def main():
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    train_loader, val_loader = getDataLoader(args, logger)

    net = selectModel(args.model)

    # logger.info('%s' % net) if local_master else None

    # if args.pretrained_net is not None:
    #     logger.info('=> Loading pretrained Net: %s' % args.pretrained_net)
    #     # Enable training from a partially pretrained model
    #     utils.load_pretrained_net(net, args.pretrained_net, strict=args.strict, logger=logger)

    net.to(device)
    # if torch.cuda.device_count() > 1:
    if args.distributed:
        # aanet = torch.nn.DataParallel(aanet)
        # try distributed training
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
        net = torch.nn.parallel.DistributedDataParallel(
            net, device_ids=[local_rank], output_device=local_rank)
        synchronize()

    # Save parameters
    num_params = utils.count_parameters(net)
    logger.info('=> Number of trainable parameters: %d' % num_params)

    # the special parts of the network use a reduced learning rate: specific_lr = args.learning_rate * 0.1
    params_group = setInitLR(net, args)

    # Optimizer
    optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)

    # Resume training
    if args.resume:
        # 1. resume Net
        start_epoch, start_iter, best_epe, best_epoch = utils.resume_latest_ckpt(
            args.checkpoint_dir, net, 'net_latest', False, logger)
        # 2. resume Optimizer
        utils.resume_latest_ckpt(args.checkpoint_dir, optimizer,
                                 'optimizer_latest', True, logger)
    else:
        start_epoch = 0
        start_iter = 0
        best_epe = None
        best_epoch = None

    # LR scheduler
    if args.lr_scheduler_type is not None:
        last_epoch = start_epoch if args.resume else start_epoch - 1
        if args.lr_scheduler_type == 'MultiStepLR':
            milestones = [int(step) for step in args.milestones.split(',')]
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_decay_gamma,
                last_epoch=last_epoch
            )  # the last_epoch argument matters here: when resuming, the scheduler automatically adjusts the learning rate to match last_epoch
        else:
            raise NotImplementedError
    # model.Model(net) is a further wrapper around net.
    train_model = model.Model(args,
                              logger,
                              optimizer,
                              net,
                              device,
                              start_iter,
                              start_epoch,
                              best_epe=best_epe,
                              best_epoch=best_epoch)
    logger.info('=> Start training...')

    for epoch in range(start_epoch, args.max_epoch):  # main training loop over epochs
        # ensure distribute worker sample different data,
        # set different random seed by passing epoch to sampler
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            logger.info('train_loader.sampler.set_epoch({})'.format(epoch))

        train_model.train(train_loader, local_master)

        if args.do_validate:
            train_model.validate(val_loader, local_master)  # in training mode: validate while training

        if args.lr_scheduler_type is not None:
            lr_scheduler.step()  # adjust the learning rate

    logger.info('=> End training\n\n')
Exemplo n.º 22
0
    def train(self):
        self.logger.save_config({"args:": self.args})
        if self.args.her_k > 0:
            buffer = ReplayBuffer(
                obs_dim=self.obs_dim,
                act_dim=self.env.unwrapped.action_space.shape[0],
                size=self.args.replay_size)
        else:
            buffer = ReplayBuffer(
                obs_dim=self.obs_dim,
                act_dim=self.env.unwrapped.action_space.shape[0],
                size=self.args.replay_size)

        var_counts = tuple(
            utils.count_parameters(module) for module in [
                self.main_net.policy, self.main_net.q1, self.main_net.q2,
                self.main_net.value_function, self.main_net
            ])
        self.logger.log(
            '\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d, \t V: %d, \t total: %d\n'
            % var_counts)

        start_time = time.time()
        obs, reward, done, episode_ret, episode_len = self.env.reset(
        ), 0, False, 0, 0
        tot_steps = 0
        for epoch in range(0, self.args.epochs):
            episode_buffer = []
            # collect one episode of experience (up to max_episode_len steps)
            for t in range(self.args.max_episode_len):
                self.window.processEvents()
                if not self.window.isVisible():
                    return

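                # act randomly for the first start_steps environment steps, then use the learned policy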
                if tot_steps > self.args.start_steps:
                    action = self.get_action(obs)
                else:
                    action = self.env.action_space.sample()

                # Step in the env
                obs2, reward, done, _ = self.env.step(action)
                tot_steps += 1
                episode_len += 1
                episode_ret += reward

                done = False if episode_len == self.args.max_episode_len else done
                # Save and log
                episode_buffer.append((obs, action, reward, obs2, done))
                # buffer.store(obs, action, reward, obs2, done)
                obs = obs2

                if done or (episode_len == self.args.max_episode_len):
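                    # episode finished: store the collected transitions, with HER relabeling when her_k > 0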
                    if self.args.her_k > 0:
                        for trans_id, (obs, action, reward, obs2,
                                       done) in enumerate(episode_buffer):
                            buffer.store(
                                np.concatenate([obs, self.env_target], axis=0),
                                action, reward,
                                np.concatenate([obs2, self.env_target],
                                               axis=0), done)
                            for k in range(self.args.her_k):
                                future_exp = np.random.randint(
                                    trans_id, len(episode_buffer))
                                _, _, _, her_obs2, _ = episode_buffer[
                                    future_exp]
                                her_reward, her_done = (
                                    reward + 100, True) if np.allclose(
                                        self.her_goal_f(her_obs2),
                                        self.her_goal_f(obs2)) else (reward,
                                                                     done)
                                buffer.store(
                                    np.concatenate(
                                        [obs, self.her_goal_f(her_obs2)],
                                        axis=0), action, her_reward,
                                    np.concatenate(
                                        [obs2, self.her_goal_f(her_obs2)],
                                        axis=0), her_done)

                    else:
                        for obs, action, reward, obs2, done in episode_buffer:
                            buffer.store(obs, action, reward, obs2, done)

                    self.update_net(buffer, episode_len)
                    self.logger.store(EpRet=episode_ret, EpLen=episode_len)
                    obs, reward, done, episode_ret, episode_len = self.env.reset(
                    ), 0, False, 0, 0

                    if (epoch % self.args.save_freq
                            == 0) or (epoch == self.args.epochs - 1):
                        self.logger.save_state({'env': self.env},
                                               self.main_net, None)

                    if epoch % self.args.log_freq == 0:
                        self.test_agent(self.args.eval_epochs)
                        # Log info about epoch
                        self.logger.log_tabular(tot_steps, 'Epoch', epoch)
                        self.logger.log_tabular(tot_steps,
                                                'EpRet',
                                                with_min_and_max=True)
                        self.logger.log_tabular(tot_steps,
                                                'TestEpRet',
                                                with_min_and_max=True)
                        self.logger.log_tabular(tot_steps,
                                                'EpLen',
                                                average_only=True)
                        self.logger.log_tabular(tot_steps,
                                                'TestEpLen',
                                                average_only=True)
                        self.logger.log_tabular(tot_steps, 'TotalEnvInteracts',
                                                tot_steps)
                        self.logger.log_tabular(tot_steps,
                                                'Q1Vals',
                                                with_min_and_max=True)
                        self.logger.log_tabular(tot_steps,
                                                'Q2Vals',
                                                with_min_and_max=True)
                        self.logger.log_tabular(tot_steps,
                                                'VVals',
                                                with_min_and_max=True)
                        self.logger.log_tabular(tot_steps,
                                                'LogPi',
                                                with_min_and_max=True)
                        self.logger.log_tabular(tot_steps,
                                                'LossPi',
                                                average_only=True)
                        self.logger.log_tabular(tot_steps,
                                                'LossQ1',
                                                average_only=True)
                        self.logger.log_tabular(tot_steps,
                                                'LossQ2',
                                                average_only=True)
                        self.logger.log_tabular(tot_steps,
                                                'LossV',
                                                average_only=True)
                        self.logger.log_tabular(tot_steps, 'Time',
                                                time.time() - start_time)
                        self.logger.dump_tabular()
                    break
Exemplo n.º 23
0
def main():
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train loader
    train_transform_list = [
        transforms.RandomCrop(args.img_height, args.img_width),
        transforms.RandomColor(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ]
    train_transform = transforms.Compose(train_transform_list)

    train_data = dataloader.StereoDataset(
        data_dir=args.data_dir,
        dataset_name=args.dataset_name,
        mode='train' if args.mode != 'train_all' else 'train_all',
        load_pseudo_gt=args.load_pseudo_gt,
        transform=train_transform)

    logger.info('=> {} training samples found in the training set'.format(
        len(train_data)))

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True,
                              drop_last=True)

    # Validation loader
    val_transform = transforms.Compose([
        transforms.RandomCrop(args.val_img_height,
                              args.val_img_width,
                              validate=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    val_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                        dataset_name=args.dataset_name,
                                        mode=args.mode,
                                        transform=val_transform)

    val_loader = DataLoader(dataset=val_data,
                            batch_size=args.val_batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            drop_last=False)

    # Network
    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    logger.info('%s' % aanet)

    if args.pretrained_aanet is not None:
        logger.info('=> Loading pretrained AANet: %s' % args.pretrained_aanet)
        # Enable training from a partially pretrained model
        utils.load_pretrained_net(aanet,
                                  args.pretrained_aanet,
                                  no_strict=(not args.strict))

    if torch.cuda.device_count() > 1:
        logger.info('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Save parameters
    num_params = utils.count_parameters(aanet)
    logger.info('=> Number of trainable parameters: %d' % num_params)
    save_name = '%d_parameters' % num_params
    open(os.path.join(args.checkpoint_dir, save_name), 'a').close()

    # Optimizer
    # Learning rate for offset learning is set 0.1 times those of existing layers
    specific_params = list(
        filter(utils.filter_specific_params, aanet.named_parameters()))
    base_params = list(
        filter(utils.filter_base_params, aanet.named_parameters()))

    specific_params = [kv[1]
                       for kv in specific_params]  # kv is a tuple (key, value)
    base_params = [kv[1] for kv in base_params]

    specific_lr = args.learning_rate * 0.1
    params_group = [
        {
            'params': base_params,
            'lr': args.learning_rate
        },
        {
            'params': specific_params,
            'lr': specific_lr
        },
    ]

    optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)

    # Resume training
    if args.resume:
        # AANet
        start_epoch, start_iter, best_epe, best_epoch = utils.resume_latest_ckpt(
            args.checkpoint_dir, aanet, 'aanet')

        # Optimizer
        utils.resume_latest_ckpt(args.checkpoint_dir, optimizer, 'optimizer')
    else:
        start_epoch = 0
        start_iter = 0
        best_epe = None
        best_epoch = None

    # LR scheduler
    if args.lr_scheduler_type is not None:
        last_epoch = start_epoch if args.resume else start_epoch - 1
        if args.lr_scheduler_type == 'MultiStepLR':
            milestones = [int(step) for step in args.milestones.split(',')]
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_decay_gamma,
                last_epoch=last_epoch)
        else:
            raise NotImplementedError

    train_model = model.Model(args,
                              logger,
                              optimizer,
                              aanet,
                              device,
                              start_iter,
                              start_epoch,
                              best_epe=best_epe,
                              best_epoch=best_epoch)

    logger.info('=> Start training...')

    if args.evaluate_only:
        assert args.val_batch_size == 1
        train_model.validate(val_loader)
    else:
        for _ in range(start_epoch, args.max_epoch):
            if not args.evaluate_only:
                train_model.train(train_loader)
            if not args.no_validate:
                train_model.validate(val_loader)
            if args.lr_scheduler_type is not None:
                lr_scheduler.step()

        logger.info('=> End training\n\n')
Exemplo n.º 24
0
def main():
    logger = create_logger(cfg.local_rank, save_dir=cfg.log_dir)
    summary_writer = create_summary(cfg.local_rank, log_dir=cfg.log_dir)
    print = logger.info

    print(cfg)
    num_gpus = torch.cuda.device_count()
    if cfg.dist:
        device = torch.device(
            'cuda:%d' % cfg.local_rank) if cfg.dist else torch.device('cuda')
        torch.cuda.set_device(cfg.local_rank)
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=num_gpus,
                                rank=cfg.local_rank)
    else:
        device = torch.device('cuda')

    print('==> Preparing data..')
    cifar = 100 if 'cifar100' in cfg.log_name else 10

    train_dataset = CIFAR_split(
        cifar=cifar,
        root=cfg.data_dir,
        split='train',
        ratio=0.5,
        transform=cifar_search_transform(is_training=True))
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=num_gpus, rank=cfg.local_rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size,
        shuffle=not cfg.dist,
        num_workers=cfg.num_workers,
        sampler=train_sampler if cfg.dist else None)

    val_dataset = CIFAR_split(
        cifar=cifar,
        root=cfg.data_dir,
        split='val',
        ratio=0.5,
        transform=cifar_search_transform(is_training=False))
    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset, num_replicas=num_gpus, rank=cfg.local_rank)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.batch_size // num_gpus if cfg.dist else cfg.batch_size,
        shuffle=not cfg.dist,
        num_workers=cfg.num_workers,
        sampler=val_sampler if cfg.dist else None)

    print('==> Building model..')
    model = Network(C=cfg.init_ch,
                    num_cells=cfg.num_cells,
                    num_nodes=cfg.num_nodes,
                    multiplier=cfg.num_nodes,
                    num_classes=cifar)

    if not cfg.dist:
        model = nn.DataParallel(model).to(device)
    else:
        # model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.to(device)
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[
                cfg.local_rank,
            ], output_device=cfg.local_rank)

    # proxy_model is used for 2nd order update
    if cfg.order == '2nd':
        proxy_model = Network(cfg.init_ch, cfg.num_cells, cfg.num_nodes).cuda()

    count_parameters(model)

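    # separate the network weights from the architecture parameters (named 'alpha')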
    weights = [v for k, v in model.named_parameters() if 'alpha' not in k]
    alphas = [v for k, v in model.named_parameters() if 'alpha' in k]
    optimizer_w = optim.SGD(weights,
                            cfg.w_lr,
                            momentum=0.9,
                            weight_decay=cfg.w_wd)
    optimizer_a = optim.Adam(alphas,
                             lr=cfg.a_lr,
                             betas=(0.5, 0.999),
                             weight_decay=cfg.a_wd)
    criterion = nn.CrossEntropyLoss().cuda()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer_w,
                                                     cfg.max_epochs,
                                                     eta_min=cfg.w_min_lr)

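    # note: `alphas` is rebound here to collect per-epoch snapshots of the architecture
    # parameters for logging; optimizer_a already holds references to the alpha tensors above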
    alphas = []

    def train(epoch):
        model.train()
        print('\nEpoch: %d lr: %f' % (epoch, scheduler.get_lr()[0]))
        alphas.append([])
        start_time = time.time()

        for batch_idx, ((inputs_w, targets_w), (inputs_a, targets_a)) \
            in enumerate(zip(train_loader, val_loader)):

            inputs_w, targets_w = inputs_w.to(device), targets_w.to(
                device, non_blocking=True)
            inputs_a, targets_a = inputs_a.to(device), targets_a.to(
                device, non_blocking=True)

            # 1. update alpha
            if epoch > cfg.a_start:
                optimizer_a.zero_grad()

                if cfg.order == '1st':
                    # using 1st order update
                    outputs = model(inputs_a)
                    val_loss = criterion(outputs, targets_a)
                    val_loss.backward()
                else:
                    # using 2nd order update
                    val_loss = update(model, proxy_model, criterion,
                                      optimizer_w, inputs_a, targets_a,
                                      inputs_w, targets_w)

                optimizer_a.step()
            else:
                val_loss = torch.tensor([0]).cuda()

            # 2. update weights
            outputs = model(inputs_w)
            cls_loss = criterion(outputs, targets_w)

            optimizer_w.zero_grad()
            cls_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer_w.step()

            if batch_idx % cfg.log_interval == 0:
                step = len(train_loader) * epoch + batch_idx
                duration = time.time() - start_time

                print(
                    '[%d/%d - %d/%d] cls_loss: %5f val_loss: %5f (%d samples/sec)'
                    % (epoch, cfg.max_epochs, batch_idx, len(train_loader),
                       cls_loss.item(), val_loss.item(),
                       cfg.batch_size * cfg.log_interval / duration))

                start_time = time.time()
                summary_writer.add_scalar('cls_loss', cls_loss.item(), step)
                summary_writer.add_scalar('val_loss', val_loss.item(), step)
                summary_writer.add_scalar('learning rate',
                                          optimizer_w.param_groups[0]['lr'],
                                          step)

                alphas[-1].append(
                    model.module.alpha_normal.detach().cpu().numpy())
                alphas[-1].append(
                    model.module.alpha_reduce.detach().cpu().numpy())
        return

    def eval(epoch):
        model.eval()

        correct = 0
        total_loss = 0
        with torch.no_grad():
            for step, (inputs, targets) in enumerate(val_loader):
                inputs, targets = inputs.to(device), targets.to(
                    device, non_blocking=True)

                outputs = model(inputs)
                total_loss += criterion(outputs, targets).item()
                _, predicted = torch.max(outputs.data, 1)
                correct += predicted.eq(targets.data).cpu().sum().item()

            acc = 100. * correct / len(val_loader.dataset)
            total_loss = total_loss / len(val_loader)
            print('Val_loss==> %.5f Precision@1 ==> %.2f%% \n' %
                  (total_loss, acc))
            summary_writer.add_scalar('Precision@1', acc, global_step=epoch)
            summary_writer.add_scalar('val_loss_per_epoch',
                                      total_loss,
                                      global_step=epoch)
        return

    for epoch in range(cfg.max_epochs):
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)
        train(epoch)
        eval(epoch)
        scheduler.step(epoch)  # stepped after the optimizer updates, as required since PyTorch 1.1.0
        print(model.module.genotype())
        if cfg.local_rank == 0:
            torch.save(alphas, os.path.join(cfg.ckpt_dir, 'alphas.t7'))
            torch.save(model.state_dict(),
                       os.path.join(cfg.ckpt_dir, 'search_checkpoint.t7'))
            torch.save({'genotype': model.module.genotype()},
                       os.path.join(cfg.ckpt_dir, 'genotype.t7'))

    summary_writer.close()
Exemplo n.º 25
0
    def train(self):
        self.logger.save_config({"args:": self.args})
        buffer = VPGBuffer(obs_dim=self.env.unwrapped.observation_space.shape,
                           act_dim=self.env.unwrapped.action_space.shape,
                           size=(self.args.batch_size + 1) *
                           self.args.max_episode_len,
                           gamma=self.args.gae_gamma,
                           lam=self.args.gae_lambda)

        var_counts = tuple(
            utils.count_parameters(module) for module in
            [self.actor_critic.policy, self.actor_critic.value_function])
        self.logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' %
                        var_counts)

        start_time = time.time()
        tot_steps = 0
        obs, reward, done, episode_ret, episode_len = self.env.reset(
        ), 0, False, 0, 0
        for epoch in range(0, self.args.epochs):
            # Set the network in eval mode (e.g. Dropout, BatchNorm etc.)
            self.actor_critic.eval()
            for t in range(self.args.max_episode_len):
                action, _, logp_t, v_t = self.actor_critic(
                    torch.Tensor(obs).unsqueeze(dim=0).to(self.device))

                # Save and log
                buffer.store(obs,
                             action.detach().cpu().numpy(), reward, v_t.item(),
                             logp_t.detach().cpu().numpy())
                self.logger.store(VVals=v_t)

                obs, reward, done, _ = self.env.step(
                    action.detach().cpu().numpy()[0])
                tot_steps += 1
                episode_ret += reward
                episode_len += 1

                self.window.processEvents()
                if self.render_enabled and epoch % self.renderSpin.value(
                ) == 0:
                    self.window.render(self.env)
                    time.sleep(self.window.renderSpin.value())
                if not self.window.isVisible():
                    return

                terminal = done or (episode_len == self.args.max_episode_len)
                if terminal:
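                    # note: as written, the cutoff warning below can never fire because terminal is True in this branch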
                    if not terminal:
                        print(
                            'Warning: trajectory cut off by epoch at %d steps.'
                            % episode_len)
                    last_val = reward if done else self.actor_critic.value_function(
                        torch.Tensor(obs).to(self.device).unsqueeze(
                            dim=0)).item()
                    buffer.finish_path(last_val=last_val)

                    if epoch % self.args.batch_size == 0:
                        self.actor_critic.train(
                        )  # Switch module to training mode
                        self.update_net(buffer.get())
                        self.actor_critic.eval()
                    if terminal:
                        self.logger.store(EpRet=episode_ret, EpLen=episode_len)
                    obs, reward, done, episode_ret, episode_len = self.env.reset(
                    ), 0, False, 0, 0
                    break

            if (epoch % self.args.save_freq == 0) or (epoch
                                                      == self.args.epochs - 1):
                self.logger.save_state({'env': self.env.unwrapped},
                                       self.actor_critic, None)
                pass

            # Log info about epoch
            self.logger.log_tabular(tot_steps, 'Epoch', epoch)
            self.logger.log_tabular(tot_steps, 'EpRet', with_min_and_max=True)
            self.logger.log_tabular(tot_steps, 'EpLen', average_only=True)
            self.logger.log_tabular(tot_steps, 'VVals', with_min_and_max=True)
            self.logger.log_tabular(tot_steps, 'TotalEnvInteracts', tot_steps)
            if epoch % self.args.batch_size == 0:
                self.logger.log_tabular(tot_steps, 'LossPi', average_only=True)
                self.logger.log_tabular(tot_steps, 'LossV', average_only=True)
                self.logger.log_tabular(tot_steps,
                                        'DeltaLossPi',
                                        average_only=True)
                self.logger.log_tabular(tot_steps,
                                        'DeltaLossV',
                                        average_only=True)
                self.logger.log_tabular(tot_steps,
                                        'Entropy',
                                        average_only=True)
                self.logger.log_tabular(tot_steps, 'KL', average_only=True)
            self.logger.log_tabular(tot_steps, 'Time',
                                    time.time() - start_time)
            self.logger.dump_tabular()