# one un-shuffled validation DataLoader per evaluation brand
val_dataloader_brand_list[eval_brand_txt] = DataLoader(val_data_brand_list[eval_brand_txt], batch_size=batch_size, shuffle=False, num_workers=num_workers)


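# Model selection: build the backbone requested by model_type;
# the commented-out branches are alternative architectures.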
if model_type == "s2":
    model = shufflenetv2(num_classes=num_classes, model_width=model_width, image_h=args.image_h, image_w=args.image_w)
#elif model_type == "cnn":
#    model = cnn.cnn(num_classes=num_classes, model_width=model_width, image_h=args.image_h, image_w=args.image_w)
#elif model_type == "ef":
#    model = efficientnet.efficientnet_b0(num_classes=num_classes)
#elif model_type == "eff_caff":
#    model = eff_caff.get_from_name("efficientnet-b0")
#elif model_type == "gfeff":
#    model = EfficientNet.from_name('efficientnet-b0')
#    feature = model._fc.in_features
#    model._fc = nn.Linear(in_features=feature, out_features=2, bias=True)
elif model_type == "s2_test":
    model = shuf(num_classes=num_classes, model_width=model_width, image_h=args.image_h, image_w=args.image_w)
else:
    # fail fast instead of leaving `model` undefined for an unrecognized model_type
    raise ValueError("Unsupported model_type: {}".format(model_type))
# Optimizer: plain SGD over the trainable (requires_grad) parameters only
trainable_vars = [param for param in model.parameters() if param.requires_grad]
print("Training with SGD")
params.optimizer = torch.optim.SGD(trainable_vars, lr=init_lr,
                                   momentum=momentum,
                                   weight_decay=weight_decay,
                                   nesterov=nesterov)

# Train
params.lr_scheduler = ReduceLROnPlateau(params.optimizer, 'min', factor=lr_decay, patience=10, cooldown=10, verbose=True)
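# NOTE (assumed): Trainer is expected to call params.lr_scheduler.step(val_loss)
# once per epoch, since ReduceLROnPlateau only lowers the LR based on that metric.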
#print('init lr is {}'.format(params.optimizer.param_groups[0]['lr']))
trainer = Trainer(model, args.model_name, params, train_dataloader, val_dataloader, val_dataloader_brand_list, train_data.get_data_len(), num_classes=args.num_classes)
trainer.train()