Example #1
# parser is a module-level argparse.ArgumentParser; train_util, optim_util,
# lr_util, update_loss_util, and models are project-local modules.
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn


def main():
	args = parser.parse_args()
	for arg in vars(args):
		print(arg, " : ", getattr(args, arg))
	
	augment = not args.no_augment
	train_loader, val_loader = train_util.load_data(
		args.dataset, 
		args.batch_size, 
		dataset_path=args.data_dir,
		augment=augment)
	
	print("=> creating model '{}'".format(args.arch))
	model_args = {
		"num_classes": 10 if args.dataset == "cifar10" else 100	
	}
	model = models.__dict__[args.arch](**model_args)
	print("Device count", torch.cuda.device_count())
	if args.parallel:
		model = nn.DataParallel(model)

	print('Number of model parameters: {}'.format(
		sum([p.data.nelement() for p in model.parameters()])))

	model = model.cuda()

	cudnn.benchmark = True

	criterion = nn.CrossEntropyLoss().cuda()
	optim_hparams = {
		'base_lr' : args.lr, 
		'momentum' : args.momentum,
		'weight_decay' : args.weight_decay,
		'optim_type' : args.optim_type
	}

	lr_hparams = {
		'lr_sched' : args.lr_sched, 
		'use_iter': args.train_by_iters}

	# Default 391 = ceil(50000 / 128): CIFAR training iterations per epoch at batch size 128.
	lr_hparams['iters_per_epoch'] = args.iters_per_epoch if args.iters_per_epoch else 391

	inner_lr_hparams = {
		'lr_sched' : args.inner_anneal,
		'use_iter' : args.train_by_iters}

	# Same CIFAR iters-per-epoch default as above.
	inner_lr_hparams['iters_per_epoch'] = args.iters_per_epoch if args.iters_per_epoch else 391

	optimizer = optim_util.create_optimizer(
		model,
		optim_hparams)

	curr_iter = args.start_iter
	epoch = args.start_epoch

	best_val = 0

	inner_opt = optim_util.one_step_optim(
		model, args.inner_lr)
	while True:
		model.train()
		train_acc = train_util.AverageMeter()
		train_loss = train_util.AverageMeter()
		timer = train_util.AverageMeter()
		for i, (input_data, target) in enumerate(train_loader):
					
			lr = lr_util.adjust_lr(
				optimizer,
				epoch,
				curr_iter,
				args.lr,
				lr_hparams)

			inner_lr = lr_util.adjust_lr(
				inner_opt,
				epoch,
				curr_iter,
				args.inner_lr,
				inner_lr_hparams)

			target = target.cuda(non_blocking=True)
			input_data = input_data.cuda()

			update_hparams = {
				'update_type' : args.update_type.split('zero_switch_')[-1],
				'inner_lr' : inner_lr[0],
				'use_bn' : not args.no_bn,
				'label_noise' : 0,
				'use_norm_one' : args.use_norm_one
			}

			if args.label_noise > 0:
				label_noise = train_util.label_noise_sched(
					args.label_noise, 
					epoch, 
					curr_iter, 
					args.train_by_iters, 
					args.ln_sched, 
					iters_per_epoch=args.iters_per_epoch,
					ln_decay=args.ln_decay)
				if args.update_type != 'mean_zero_label_noise' or args.also_flip_labels:
					# if it is equal, we don't want to flip the labels
					target = train_util.apply_label_noise(
						target,
						label_noise,
						num_classes=10 if args.dataset == 'cifar10' else 100)

				update_hparams['label_noise'] = label_noise
				
			loss, output, time_taken = update_loss_util.update_step(
				criterion,
				optimizer,
				model,
				input_data,
				target,
				update_hparams)

			prec1 = accuracy(output.data, target, topk=(1,))[0]
			train_loss.update(loss, target.size(0))
			train_acc.update(prec1, target.size(0))
			timer.update(time_taken, 1)
			avg_loss = train_loss.avg
			avg_acc = train_acc.avg

			loss_str = 'Loss '
			loss_str += '{:.4f} (standard)\t'.format(avg_loss)


			if i % args.print_freq == 0:
				log_str = ('Epoch: [{0}][{1}/{2}]\t'
					'Time {3:.3f}\t{4}'
					'Prec@1 {5:.3f}').format(
					  epoch, i, len(train_loader), timer.avg, loss_str, avg_acc)
				print(log_str)

			curr_iter += 1

		print("Validating accuracy.")
		val_acc, val_loss = train_util.validate(
			val_loader,
			model,
			criterion,
			epoch,
			print_freq=args.print_freq)

		is_best = val_acc > best_val
		best_val = val_acc if is_best else best_val

		print('Best accuracy: ', best_val)

		epoch += 1
		if args.train_by_iters:
			if curr_iter > args.iters:
				break
		else:
			if epoch > args.epochs:
				break
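
Example #1 calls two helpers the snippet never defines: train_util.AverageMeter and accuracy. A minimal sketch of both, assuming they mirror the classic PyTorch ImageNet-example helpers that these call sites appear to follow:

class AverageMeter:
    """Track a running average of a scalar (loss, accuracy, time per step)."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k precision for a batch; returns one value per requested k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)   # [batch, maxk] indices
    pred = pred.t()                              # [maxk, batch]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res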
Example #2
# Assumes the same torch imports as Example #1, plus:
import os
from datetime import datetime


def main():
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))
    timestamp = datetime.utcnow().strftime("%H_%M_%S_%f-%d_%m_%y")
    save_str = "arch_%s_reg_%s_%s" % (args.arch, args.reg_type, timestamp)
    save_dir = os.path.join(args.save_dir, save_str)

    augment = not args.no_augment
    train_loader, val_loader = train_util.load_data(args.dataset,
                                                    args.batch_size,
                                                    dataset_path=args.data_dir,
                                                    augment=augment)

    print("=> creating model '{}'".format(args.arch))
    model_args = {"num_classes": 10 if args.dataset == "cifar10" else 100}
    if args.reg_type == 'dropout':
        print("Using dropout.")
        model_args['dropRate'] = args.dropout
    model = models.__dict__[args.arch](**model_args)

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    # Initialize before the (optional) resume so a checkpointed best_val is kept.
    best_val = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_val = checkpoint['best_val']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss().cuda()
    optim_hparams = {
        'base_lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }
    lr_hparams = {'lr_sched': args.lr_sched}
    optimizer = train_util.create_optimizer(model, optim_hparams)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    save_util.write_args(args, save_dir)
    scalar_summary_file = os.path.join(save_dir, "scalars.txt")
    scalar_dict = {}
    all_dict = {}

    for epoch in range(args.start_epoch, args.epochs):
        lr = train_util.adjust_lr(optimizer, epoch + 1, args.lr, lr_hparams)

        train_hparams = {
            "reg_type": args.reg_type,
            "noise_level": train_util.adjust_act_noise(
                args.act_noise_decay,
                args.act_noise_decay_rate,
                args.act_noise,
                epoch + 1)
        }

        train_acc, train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            train_hparams,
            print_freq=args.print_freq)

        print("Validating accuracy.")
        val_acc, val_loss = train_util.validate(val_loader,
                                                model,
                                                criterion,
                                                epoch,
                                                print_freq=args.print_freq)

        scalar_epoch = {
            "lr": lr,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc
        }

        scalar_dict[epoch + 1] = scalar_epoch

        save_util.log_scalar_file(scalar_epoch, epoch + 1, scalar_summary_file)

        is_best = val_acc > best_val
        best_val = max(val_acc, best_val)

        save_util.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_val': best_val,
            }, scalar_dict, is_best, save_dir)

        print('Best accuracy: ', best_val)
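
save_util.save_checkpoint is likewise external. A plausible sketch matching the call site above; the checkpoint.pth.tar filename agrees with how Example #3 resumes, while model_best.pth.tar and the 'scalars' key are assumptions:

import os
import shutil

import torch


def save_checkpoint(state, scalar_dict, is_best, save_dir,
                    filename='checkpoint.pth.tar'):
    """Persist the latest state to save_dir and keep a copy of the best one."""
    path = os.path.join(save_dir, filename)
    state['scalars'] = scalar_dict  # keep the training curves with the weights
    torch.save(state, path)
    if is_best:
        # Copy rather than re-serialize, so 'best' always matches a saved file.
        shutil.copyfile(path, os.path.join(save_dir, 'model_best.pth.tar'))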
Example #3
# Assumes the same imports as Example #2.
def main():
    args = parser.parse_args()
    for arg in vars(args):
        print(arg, " : ", getattr(args, arg))
    timestamp = datetime.utcnow().strftime("%H_%M_%S_%f-%d_%m_%y")
    save_str = "arch_%s_reg_%s_%s" % (args.arch, args.reg_type, timestamp)
    save_dir = os.path.join(args.save_dir, save_str)

    augment = not args.no_augment
    train_loader, val_loader = train_util.load_data(args.dataset,
                                                    args.batch_size,
                                                    dataset_path=args.data_dir,
                                                    augment=augment)

    print("=> creating model '{}'".format(args.arch))
    model_args = {"num_classes": 10 if args.dataset == "cifar10" else 100}
    model = models.__dict__[args.arch](**model_args)

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    if args.resume:
        model_file = os.path.join(args.resume, "checkpoint.pth.tar")
        model = save_util.load_model_sdict(model, model_file)

    model = model.cuda()

    cudnn.benchmark = True

    criterion = nn.CrossEntropyLoss().cuda()
    optim_hparams = {
        'base_lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }
    lr_hparams = {'lr_sched': args.lr_sched}
    optimizer = train_util.create_optimizer(model, optim_hparams)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    save_util.write_args(args, save_dir)
    scalar_summary_file = os.path.join(save_dir, "scalars.txt")
    scalar_dict = {}
    best_val = 0
    best_epoch = -1
    for epoch in range(args.start_epoch, args.epochs):
        lr = train_util.adjust_lr(optimizer, epoch + 1, args.lr, lr_hparams)

        train_hparams = {
            "reg_type": args.reg_type,
            "data_reg": args.data_reg,
            "j_thresh": args.j_thresh
        }

        train_acc, train_loss = train_util.train_loop(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            train_hparams,
            print_freq=args.print_freq)

        val_acc, val_loss = train_util.validate(val_loader,
                                                model,
                                                criterion,
                                                epoch,
                                                print_freq=args.print_freq)

        is_best = val_acc > best_val
        best_val = max(val_acc, best_val)
        if is_best:
            best_epoch = epoch + 1

        scalar_epoch = {
            "lr": lr,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "best_val": best_val,
            "best_epoch": best_epoch
        }
        scalar_dict[epoch + 1] = scalar_epoch

        save_util.log_scalar_file(scalar_epoch, epoch + 1, scalar_summary_file)

        save_util.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_val': best_val,
            }, scalar_dict, is_best, save_dir)

        print('Best accuracy: ', best_val)
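
train_util.adjust_lr is only visible through its call sites (optimizer, epoch + 1, args.lr, lr_hparams). A step-decay sketch, under the assumption that lr_sched is a comma-separated list of milestone epochs; the project's real schedule may differ:

def adjust_lr(optimizer, epoch, base_lr, lr_hparams):
    """Divide the learning rate by 10 at each milestone epoch reached."""
    milestones = [int(m) for m in lr_hparams['lr_sched'].split(',')]
    lr = base_lr * (0.1 ** sum(epoch >= m for m in milestones))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr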
Example #4
# The original snippet begins mid-call, so the parser setup below is
# reconstructed; --data_dir and --model_name are assumed from the later
# uses of args.data_dir and args.model_name.
import argparse

import train_util  # project-local module

ap = argparse.ArgumentParser()
ap.add_argument('--data_dir', dest='data_dir', action='store')
ap.add_argument('--model_name', dest='model_name', action='store')
ap.add_argument('--hidden_input',
                dest='hidden_input',
                action='store',
                default=150,
                type=int)
ap.add_argument('--learning_rate',
                dest='learning_rate',
                action='store',
                default=0.001,
                type=float)
ap.add_argument('--epochs', dest='epochs', action='store', default=5, type=int)
ap.add_argument('--gpu', dest='mode', action='store', default="gpu")
ap.add_argument('--save_path',
                dest='save_path',
                action='store',
                default='checkpoint.pth')

args = ap.parse_args()

# Load Data
train_data, val_data, test_data, train_loader, val_loader = train_util.load_data(
    args.data_dir)

# Set the network
model = train_util.set_network(args.model_name, args.hidden_input)

# Train the model
train_util.train_network(model, train_loader, val_loader, args.learning_rate,
                         args.epochs, args.mode)

# Save the model checkpoint
train_util.save_checkpoint(train_data, model, args)
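
train_util.save_checkpoint receives train_data alongside the model, presumably so it can store the class-to-index mapping for later inference (torchvision's ImageFolder exposes class_to_idx). A hedged sketch; every key name here is an assumption:

import torch


def save_checkpoint(train_data, model, args):
    """Save everything needed to rebuild the model and decode its outputs."""
    checkpoint = {
        'model_name': args.model_name,
        'hidden_input': args.hidden_input,
        'state_dict': model.state_dict(),
        # Maps class labels to output indices, needed to decode predictions.
        'class_to_idx': train_data.class_to_idx,
    }
    torch.save(checkpoint, args.save_path)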