def main(config):
    """Train a PseudoLabel model on CIFAR-10 or CIFAR-100.

    Builds the dataset loaders, model, optimizer and LR scheduler from
    ``config``, logs the model graph to TensorBoard, and runs the full
    training loop for ``config.epochs`` epochs.
    """
    run_tag = '_{}_{}'.format(config.arch, config.dataset)
    with SummaryWriter(comment=run_tag) as writer:
        # Pick the dataset configuration; anything that isn't cifar10
        # falls through to cifar100.
        if config.dataset == 'cifar10':
            ds_cfg = datasets.cifar10()
        else:
            ds_cfg = datasets.cifar100()
        n_classes = ds_cfg.pop('num_classes')
        train_loader, eval_loader = create_data_loaders(**ds_cfg,
                                                        config=config)

        # Trace the graph once with a dummy batch so it appears in TensorBoard.
        sample_batch = (torch.randn(10, 3, 32, 32), )
        model = arch[config.arch](n_classes)
        writer.add_graph(model, sample_batch)

        dev = 'cuda' if torch.cuda.is_available() else 'cpu'
        loss_fn = create_loss_fn(config)
        if config.is_parallel:
            model = torch.nn.DataParallel(model).to(dev)
        else:
            # Pin to the single GPU requested in the config.
            dev = 'cuda:{}'.format(
                config.gpu) if torch.cuda.is_available() else 'cpu'
            model = model.to(dev)
        opt = create_optim(model.parameters(), config)
        lr_sched = create_lr_scheduler(opt, config)

        pseudo_trainer = Trainer.PseudoLabel(model, opt, loss_fn, dev,
                                             config, writer)
        pseudo_trainer.loop(config.epochs,
                            train_loader,
                            eval_loader,
                            scheduler=lr_sched,
                            print_freq=config.print_freq)
# Beispiel #2 (second scraped example; stray vote-count "0" removed so the file parses)
def main(config):
    """Train or evaluate an FPN_ResNet18 model with the PseudoLabel trainer.

    When ``config.train`` is truthy, runs the full training loop for
    ``config.epochs`` epochs; otherwise restores weights from
    ``config.PATH`` and performs a single evaluation pass.
    """
    # SummaryWriter handles TensorBoard graph/metric logging.
    with SummaryWriter(
            comment='_{}_{}'.format(config.arch, config.dataset)) as writer:
        # Select the dataset configuration (FPN dataset, CIFAR-10 fallback).
        dataset_config = datasets.FPN(
            config) if config.dataset == 'FPN' else datasets.cifar10()
        train_loader, eval_loader = data_loaders(**dataset_config,
                                                 config=config)
        # Trace the model once with a dummy input so the graph shows
        # in TensorBoard. Assumes a single-channel 200x200 input —
        # TODO(review): confirm against the FPN dataset's sample shape.
        dummy_input = torch.randn(1, 1, 200, 200)
        net = FPN_ResNet18()
        writer.add_graph(net, dummy_input)

        device1 = 'cuda' if torch.cuda.is_available() else 'cpu'
        criterion_l = create_loss_fn(config)
        if config.is_parallel:
            net = torch.nn.DataParallel(net).to(device1)
        else:
            # Pin to the single GPU requested in the config.
            device1 = 'cuda:{}'.format(
                config.gpu) if torch.cuda.is_available() else 'cpu'
            net = net.to(device1)
        optimizer = create_optim(net.parameters(), config)

        if not config.train:
            # Evaluation only: restore checkpointed weights before use.
            checkpoint = torch.load(config.PATH)
            net.load_state_dict(checkpoint['weight'])

        # The trainer is configured identically for both paths, so it is
        # built once here instead of duplicating the call in each branch.
        trainer = Trainer.PseudoLabel(net,
                                      optimizer,
                                      criterion_l,
                                      device1,
                                      config,
                                      writer,
                                      save_dir='./model')
        if config.train:
            scheduler = create_lr_scheduler(optimizer, config)
            trainer.loop(config.epochs,
                         train_loader,
                         eval_loader,
                         scheduler=scheduler,
                         print_freq=config.print_freq)
        else:
            trainer.testonce(eval_loader, print_freq=config.print_freq)


# class dw(torch.nn.Module):
#     '''
#     '''
#
#     def __init__(self, model):
#         '''
#         '''
#
#         # initialize the module using super() constructor
#         super(dw, self).__init__()
#         # assign the architectures
#         self.model = model
#         # assign the weights for each task
#         self.weights = torch.nn.Parameter(torch.ones(2).float())
#
#     def forward(self, x):
#         out = self.model(x)
#         return out