# Example 1
def load(config):
    """Build the model pair, optimizers and semi-supervised method from *config*.

    Constructs a student model and its EMA (teacher) copy, the chosen
    semi-supervised algorithm and matching loss, the Adam optimizer and the
    weight-EMA updater, and optionally restores all trainable state from a
    checkpoint.

    Args:
        config: namespace-like object providing ``model``, ``dataset_classes``,
            ``semi_supervised``, ``device``, ``learning_rate``, ``ema_decay``,
            ``resume`` and ``checkpoint_path``.

    Returns:
        Tuple ``(model, ema_model, optimizer, ema_optimizer, semi_supervised,
        semi_supervised_loss)``.

    Raises:
        ValueError: if ``config.semi_supervised`` names an unknown method.
    """
    # Any model name other than 'wideresnet' falls back to CNN13
    # (original behavior, kept for backward compatibility).
    if config.model == 'wideresnet':
        model = WideResNet(num_classes=config.dataset_classes)
        ema_model = WideResNet(num_classes=config.dataset_classes)
    else:
        model = CNN13(num_classes=config.dataset_classes)
        ema_model = CNN13(num_classes=config.dataset_classes)

    if config.semi_supervised == 'mix_match':
        semi_supervised = MixMatch(config)
        semi_supervised_loss = mix_match_loss
    elif config.semi_supervised == 'pseudo_label':
        semi_supervised = PseudoLabel(config)
        semi_supervised_loss = pseudo_label_loss
    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError at the return statement below.
        raise ValueError(
            f"Unknown semi-supervised method: {config.semi_supervised!r}")

    model.to(config.device)
    ema_model.to(config.device)

    # Let cuDNN pick the fastest convolution algorithms for fixed input sizes.
    torch.backends.cudnn.benchmark = True

    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    ema_optimizer = WeightEMA(model, ema_model, alpha=config.ema_decay)

    if config.resume:
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)

        model.load_state_dict(checkpoint['model_state'])
        ema_model.load_state_dict(checkpoint['ema_model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])

        # Optimizer state tensors (e.g. Adam moment buffers) are restored on
        # the device they were saved from; move them to the target device.
        for optimizer_state in optimizer.state.values():
            for k, v in optimizer_state.items():
                if isinstance(v, torch.Tensor):
                    optimizer_state[k] = v.to(config.device)

    return model, ema_model, optimizer, ema_optimizer, semi_supervised, semi_supervised_loss
# Example 2
    # Labeled loader draws batches of B examples; the unlabeled loader draws
    # mu * B per batch (the FixMatch-style unlabeled ratio — train_fixmatch
    # consumes both zipped together below).
    lbl_loader = DataLoader(labeled_dataset,
                            batch_size=B,
                            collate_fn=labeled_collate,
                            num_workers=args.num_workers2,
                            pin_memory=True,
                            shuffle=True)
    ulbl_loader = DataLoader(unlabeled_dataset,
                             batch_size=mu * B,
                             collate_fn=unlabeled_collate,
                             num_workers=args.num_workers3,
                             pin_memory=True,
                             shuffle=True)

    #  Model Settings

    model.to(device)
    # Exponential moving average of the model weights with decay 0.999.
    # NOTE(review): register() presumably snapshots the current parameters as
    # the initial EMA state — confirm against the EMA helper's implementation.
    ema = EMA(model, decay=0.999)
    ema.register()

    # SGD with momentum / weight decay / optional Nesterov, all from the CLI.
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    # Cosine LR schedule over K steps with a warmup phase (project helper).
    scheduler = cosineLRreduce(optimizer, K, warmup=args.warmup_scheduler)

    # NOTE(review): zip() stops at the shorter of the two loaders, so one pass
    # yields min(len(lbl_loader), len(ulbl_loader)) batch pairs — confirm this
    # matches the intended epoch length inside train_fixmatch.
    train_fixmatch(model, ema, zip(lbl_loader, ulbl_loader), v_loader,
                   augmentation, optimizer, scheduler, device, K, tb_writer)
    tb_writer.close()

    # Save everything