def train(self, args, warmup_epochs=5, warmup_lr=0):
        """Run the full training schedule: warmup epochs plus regular epochs.

        For each epoch: train for one epoch, validate at every image
        resolution, track the best mean top-1 accuracy, and (on the root
        process only) write a log line and checkpoint the model.

        Args:
            args: training options forwarded to ``train_one_epoch``.
            warmup_epochs: number of warmup epochs prepended to the schedule.
            warmup_lr: learning rate used during warmup.
        """
        for epoch in range(self.start_epoch,
                           self.run_config.n_epochs + warmup_epochs):
            train_loss, (train_top1, train_top5) = self.train_one_epoch(
                args, epoch, warmup_epochs, warmup_lr)
            img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(
                epoch, is_test=False)

            # Hoist the mean: the original recomputed list_mean(val_top1)
            # three times.
            mean_top1 = list_mean(val_top1)
            is_best = mean_top1 > self.best_acc
            self.best_acc = max(self.best_acc, mean_top1)
            if self.is_root:
                # Positions {6}/{7} are the metric names (e.g. top1/top5).
                val_log = '[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t' \
                          'Train {6} {top1:.3f}\tloss {train_loss:.3f}\t'.format(
                              epoch + 1 - warmup_epochs, self.run_config.n_epochs,
                              list_mean(val_loss), mean_top1, self.best_acc,
                              list_mean(val_top5), *self.get_metric_names(),
                              top1=train_top1, train_loss=train_loss)
                # Append per-resolution accuracies, e.g. "(224, 76.300), ".
                for i_s, v_a in zip(img_size, val_top1):
                    val_log += '(%d, %.3f), ' % (i_s, v_a)
                self.write_log(val_log, prefix='valid', should_print=False)

                self.save_model(
                    {
                        'epoch': epoch,
                        'best_acc': self.best_acc,
                        'optimizer': self.optimizer.state_dict(),
                        'state_dict': self.net.state_dict(),
                    },
                    is_best=is_best)
# Example #2
	def build_acc_dataset(self, run_manager, ofa_network, n_arch=1000, image_size_list=None):
		"""Build an accuracy lookup table for randomly sampled sub-networks.

		Samples (or reloads from ``self.net_id_path``) ``n_arch`` distinct
		subnet configurations, validates each at every resolution in
		``image_size_list``, and writes the measured top-1 accuracy to
		``<acc_src_folder>/<image_size>.dict`` as JSON after every new
		measurement, so an interrupted run can be resumed.

		Args:
			run_manager: manager providing the validation loop and loaders.
			ofa_network: once-for-all supernet to sample subnets from.
			n_arch: number of distinct architectures to sample.
			image_size_list: resolutions to evaluate at; defaults to
				[128, 160, 192, 224].
		"""
		# Load net_id_list from disk; randomly sample (deduplicated via a
		# set) if it does not exist yet.  Files are opened via context
		# managers — the original leaked the handles from bare open() calls.
		if os.path.isfile(self.net_id_path):
			with open(self.net_id_path, 'r') as fin:
				net_id_list = json.load(fin)
		else:
			net_id_list = set()
			while len(net_id_list) < n_arch:
				net_setting = ofa_network.sample_active_subnet()
				net_id = net_setting2id(net_setting)
				net_id_list.add(net_id)
			net_id_list = sorted(net_id_list)
			with open(self.net_id_path, 'w') as fout:
				json.dump(net_id_list, fout, indent=4)

		image_size_list = [128, 160, 192, 224] if image_size_list is None else image_size_list

		with tqdm(total=len(net_id_list) * len(image_size_list), desc='Building Acc Dataset') as t:
			for image_size in image_size_list:
				# Pre-load the validation set into memory once per resolution
				# so every subnet evaluation skips the data pipeline.
				val_dataset = []
				run_manager.run_config.data_provider.assign_active_img_size(image_size)
				for images, labels in run_manager.run_config.valid_loader:
					val_dataset.append((images, labels))
				# save path
				os.makedirs(self.acc_src_folder, exist_ok=True)
				acc_save_path = os.path.join(self.acc_src_folder, '%d.dict' % image_size)
				acc_dict = {}
				# Reload accuracies measured by a previous (interrupted) run.
				if os.path.isfile(acc_save_path):
					with open(acc_save_path, 'r') as fin:
						existing_acc_dict = json.load(fin)
				else:
					existing_acc_dict = {}
				for net_id in net_id_list:
					net_setting = net_id2setting(net_id)
					key = net_setting2id({**net_setting, 'image_size': image_size})
					if key in existing_acc_dict:
						# Cache hit: reuse the stored accuracy, skip validation.
						acc_dict[key] = existing_acc_dict[key]
						t.set_postfix({
							'net_id': net_id,
							'image_size': image_size,
							'info_val': acc_dict[key],
							'status': 'loading',
						})
						t.update()
						continue
					ofa_network.set_active_subnet(**net_setting)
					# BN running statistics must be recalibrated per subnet.
					run_manager.reset_running_statistics(ofa_network)
					net_setting_str = ','.join(['%s_%s' % (
						key, '%.1f' % list_mean(val) if isinstance(val, list) else val
					) for key, val in net_setting.items()])
					loss, (top1, top5) = run_manager.validate(
						run_str=net_setting_str, net=ofa_network, data_loader=val_dataset, no_logs=True,
					)
					info_val = top1

					t.set_postfix({
						'net_id': net_id,
						'image_size': image_size,
						'info_val': info_val,
					})
					t.update()

					acc_dict.update({
						key: info_val
					})
					# Checkpoint after every measurement so progress survives
					# interruptions.
					with open(acc_save_path, 'w') as fout:
						json.dump(acc_dict, fout, indent=4)
def validate(run_manager,
             epoch=0,
             is_test=False,
             image_size_list=None,
             ks_list=None,
             expand_ratio_list=None,
             depth_list=None,
             width_mult_list=None,
             additional_setting=None):
    """Evaluate a grid of sub-networks of the dynamic supernet.

    Builds the cartesian product of depth / expand-ratio / kernel-size /
    width-mult / image-size choices (falling back to the supernet's own
    candidate lists when an argument is None), validates each resulting
    subnet, and averages the metrics.

    Returns:
        (mean_loss, mean_top1, mean_top5, log_string) over all subnets.
    """
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    # Resolve default candidate lists from the supernet / run config.
    if image_size_list is None:
        image_size_list = val2list(
            run_manager.run_config.data_provider.image_size, 1)
    ks_list = dynamic_net.ks_list if ks_list is None else ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    depth_list = dynamic_net.depth_list if depth_list is None else depth_list
    if width_mult_list is None:
        # Instance-level attribute only; otherwise a single width index.
        if 'width_mult_list' in dynamic_net.__dict__:
            width_mult_list = list(range(len(dynamic_net.width_mult_list)))
        else:
            width_mult_list = [0]

    # Full cartesian product, each paired with a human-readable name.
    subnet_settings = [
        [{
            'image_size': img_size,
            'd': d,
            'e': e,
            'ks': k,
            'w': w,
        }, 'R%s-D%s-E%s-K%s-W%s' % (img_size, d, e, k, w)]
        for d in depth_list
        for e in expand_ratio_list
        for k in ks_list
        for w in width_mult_list
        for img_size in image_size_list
    ]
    if additional_setting is not None:
        subnet_settings += additional_setting

    subnet_losses, subnet_top1s, subnet_top5s = [], [], []
    log_parts = []

    for setting, name in subnet_settings:
        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30,
                              'train',
                              should_print=False)
        run_manager.run_config.data_provider.assign_active_img_size(
            setting.pop('image_size'))
        dynamic_net.set_active_subnet(**setting)
        run_manager.write_log(dynamic_net.module_str,
                              'train',
                              should_print=False)

        # Recalibrate BN statistics before measuring this subnet.
        run_manager.reset_running_statistics(dynamic_net)
        loss, (top1, top5) = run_manager.validate(epoch=epoch,
                                                  is_test=is_test,
                                                  run_str=name,
                                                  net=dynamic_net)
        subnet_losses.append(loss)
        subnet_top1s.append(top1)
        subnet_top5s.append(top5)
        log_parts.append('%s (%.3f), ' % (name, top1))

    valid_log = ''.join(log_parts)
    return list_mean(subnet_losses), list_mean(subnet_top1s), list_mean(
        subnet_top5s), valid_log
def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    """Train the dynamic supernet for one epoch.

    Each batch is forwarded through ``args.dynamic_batch_size`` randomly
    sampled subnets; every subnet's loss is backpropagated so gradients
    accumulate, then a single optimizer step is taken.  When
    ``args.kd_ratio > 0``, a distillation loss against
    ``args.teacher_model`` is added to the hard-label loss.

    Args:
        run_manager: owns the network, optimizer, data loaders and metrics.
        args: reads kd_ratio, kd_type, teacher_model, dynamic_batch_size.
        epoch: current epoch index (0-based).
        warmup_epochs: number of epochs using the warmup LR schedule.
        warmup_lr: base learning rate during warmup.

    Returns:
        (mean_loss, metric_values) for the epoch.
    """
    dynamic_net = run_manager.network
    distributed = isinstance(run_manager, DistributedRunManager)

    # switch to train mode
    dynamic_net.train()
    if distributed:
        # Re-seed the distributed sampler so shuffling differs per epoch.
        run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    # Global state presumably consumed by the resize transform — confirm.
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric('train_loss') if distributed else AverageMeter()
    metric_dict = run_manager.get_metric_dict()

    with tqdm(total=nBatch,
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=distributed and not run_manager.is_root) as t:
        end = time.time()
        for i, (images,
                labels) in enumerate(run_manager.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)
            # Warmup schedule for the first warmup_epochs, regular after.
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer,
                    warmup_epochs * nBatch,
                    nBatch,
                    epoch,
                    i,
                    warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch)

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # soft target
            if args.kd_ratio > 0:
                # NOTE(review): the teacher is put in train() mode (not
                # eval()) before inference — confirm this is intentional.
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clean gradients
            dynamic_net.zero_grad()

            loss_of_subnets = []
            # compute output
            subnet_str = ''
            for _ in range(args.dynamic_batch_size):
                # set random seed before sampling
                # (deterministic per (batch, sub-iteration) pair, so runs
                # sample the same subnet sequence)
                subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, _, 0))
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()
                subnet_str += '%d: ' % _ + ','.join([
                    '%s_%s' %
                    (key, '%.1f' %
                     subset_mean(val, 0) if isinstance(val, list) else val)
                    for key, val in subnet_settings.items()
                ]) + ' || '

                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    # Distillation term: soft-target CE or MSE on logits.
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(
                            output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(
                        output, labels)
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio,
                                                    args.kd_type)

                # measure accuracy and record loss
                loss_of_subnets.append(loss)
                run_manager.update_metric(metric_dict, output, target)

                # Gradients accumulate across the sampled subnets; a single
                # optimizer step is taken after the inner loop.
                loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))

            t.set_postfix({
                'loss':
                losses.avg.item(),
                **run_manager.get_metric_vals(metric_dict, return_dict=True),
                'R':
                images.size(2),
                'lr':
                new_lr,
                'loss_type':
                loss_type,
                'seed':
                str(subnet_seed),
                'str':
                subnet_str,
                'data_time':
                data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), run_manager.get_metric_vals(metric_dict)
# Example #5
            'Group %d: %d params with wd %f' %
            (i + 1, len(param_group['params']), param_group['weight_decay']),
            'grad_params', True, 'a')
    for name, param in net.named_parameters():
        if param.requires_grad:
            run_manager.write_log('%s: %s' % (name, list(param.data.size())),
                                  'grad_params', False, 'a')

    run_manager.save_config()
    if args.resume:
        run_manager.load_model()
    else:
        init_path = '%s/init' % args.path
        if os.path.isfile(init_path):
            checkpoint = torch.load(init_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            run_manager.network.load_state_dict(checkpoint)

    # train
    args.teacher_model = None
    run_manager.train(args)
    # test
    img_size, loss, acc1, acc5 = run_manager.validate_all_resolution(
        is_test=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f\t' % (
        list_mean(loss), list_mean(acc1), list_mean(acc5))
    for i_s, v_a in zip(img_size, acc1):
        log += '(%d, %.3f), ' % (i_s, v_a)
    run_manager.write_log(log, prefix='test')