def train(args):
    """Train the model selected by *args* and write logs/checkpoints.

    Builds the training/validation datasets and loaders, moves the model
    to the global ``device``, splits parameters so weight decay skips
    batch-norm parameters, and delegates the loop to ``Trainer``.
    """
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    train_split, val_split, num_classes = create_datasets(dataset_dir)

    train_data = Dataset(train_split,
                         transform_for_training(model_class.IMAGE_SHAPE))
    val_data = Dataset(val_split,
                       transform_for_infer(model_class.IMAGE_SHAPE))

    loader_kwargs = dict(batch_size=args.batch_size, num_workers=6)
    train_loader = torch.utils.data.DataLoader(
        train_data, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(
        val_data, shuffle=False, **loader_kwargs)

    model = model_class(num_classes).to(device)

    # Weight decay is applied only to non-batch-norm parameters; "bn" in
    # the parameter name marks batch-norm weights/biases.
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (undecayed if "bn" in name else decayed).append(param)

    optimizer = torch.optim.SGD(
        [
            {"params": decayed, "weight_decay": 0.0001},
            {"params": undecayed},
        ],
        lr=args.lr,
        momentum=0.9,
    )

    Trainer(
        optimizer,
        model,
        train_loader,
        val_loader,
        max_epoch=args.epochs,
        resume=args.resume,
        log_dir=log_dir,
    ).train()
def train(args):
    """Train the model chosen by *args* on the dataset under its dataset dir."""
    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    # Split the raw dataset into train/validation lists plus the class count.
    training_set, validation_set, num_classes = create_datasets(dataset_dir)

    training_dataset = Dataset(training_set,
                               transform_for_training(model_class.IMAGE_SHAPE))
    validation_dataset = Dataset(validation_set,
                                 transform_for_infer(model_class.IMAGE_SHAPE))

    training_dataloader = torch.utils.data.DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=True)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=6,
        shuffle=False)

    # NOTE(review): `device` is presumably a module-level torch.device —
    # confirm where it is defined.
    model = model_class(num_classes).to(device)

    # Split trainable parameters so weight decay is applied only to
    # non-batch-norm parameters ("bn" in the name marks batch norm).
    trainables_wo_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' not in name
    ]
    trainables_only_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' in name
    ]

    optimizer = torch.optim.SGD([{
        'params': trainables_wo_bn,
        'weight_decay': 0.0001
    }, {
        'params': trainables_only_bn
    }], lr=args.lr, momentum=0.9)

    trainer = Trainer(optimizer,
                      model,
                      training_dataloader,
                      validation_dataloader,
                      max_epoch=args.epochs,
                      resume=args.resume,
                      log_dir=log_dir)
    trainer.train()
def train(args):
    """Train on a combination of up to four datasets (W, SA, As, AF).

    For each dataset whose size argument (``args.w`` / ``args.sa`` /
    ``args.ai`` / ``args.af``) is non-zero, load it, truncate the
    train/validation splits to roughly half the requested size (mirroring
    the original slicing), and merge it into the combined splits before
    training as usual.

    Fixes vs. the previous revision:
      * ``t_x = + y`` was a unary plus — it discarded the accumulator and
        raises TypeError on lists; the intended ``t_x += y`` is used.
      * the AF branch sliced with ``args.ai`` instead of ``args.af``
        (copy-paste bug).
      * the AF branch used ``.append`` (nesting a whole list as one
        element) while the other branches extend; all branches now extend.
      * ``t_num_classes`` starts at 0 so integer class counts accumulate.
    """
    t_training_set = []
    t_validation_set = []
    t_num_classes = 0

    dataset_dir = get_dataset_dir(args)
    log_dir = get_log_dir(args)
    model_class = get_model_class(args)

    if args.w != 0:
        w_training_set, w_validation_set, num_classes_w = create_datasetsW(
            dataset_dir)
        t_training_set += w_training_set[0:int(args.w / 2)]
        t_validation_set += w_validation_set[0:int(args.w / 2)]
        t_num_classes += num_classes_w

    if args.sa != 0:
        sa_training_set, sa_validation_set, num_classes_sa = create_datasetsSA(
            dataset_dir)
        t_training_set += sa_training_set[0:int(args.sa / 2)]
        t_validation_set += sa_validation_set[0:int(args.sa / 2)]
        t_num_classes += num_classes_sa

    if args.ai != 0:
        as_training_set, as_validation_set, num_classes_as = create_datasetsAs(
            dataset_dir)
        t_training_set += as_training_set[0:int(args.ai / 2) + 1]
        t_validation_set += as_validation_set[0:int(args.ai / 2) + 1]
        t_num_classes += num_classes_as

    if args.af != 0:
        af_training_set, af_validation_set, classes_af = create_datasetsAF(
            dataset_dir)
        # Was sliced by args.ai in the previous revision; args.af is intended.
        t_training_set += af_training_set[0:int(args.af / 2) + 1]
        t_validation_set += af_validation_set[0:int(args.af / 2) + 1]
        t_num_classes += classes_af

    training_set = t_training_set
    validation_set = t_validation_set
    num_classes = t_num_classes

    training_dataset = Dataset(training_set,
                               transform_for_training(model_class.IMAGE_SHAPE))
    validation_dataset = Dataset(validation_set,
                                 transform_for_infer(model_class.IMAGE_SHAPE))

    training_dataloader = torch.utils.data.DataLoader(
        training_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False)

    # NOTE(review): `device` is presumably a module-level torch.device —
    # confirm where it is defined.
    model = model_class(num_classes).to(device)

    # Split trainable parameters so weight decay is applied only to
    # non-batch-norm parameters ("bn" in the name marks batch norm).
    trainables_wo_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' not in name
    ]
    trainables_only_bn = [
        param for name, param in model.named_parameters()
        if param.requires_grad and 'bn' in name
    ]

    optimizer = torch.optim.SGD([{
        'params': trainables_wo_bn,
        'weight_decay': 0.0001
    }, {
        'params': trainables_only_bn
    }], lr=args.lr, momentum=0.9)

    trainer = Trainer(optimizer,
                      model,
                      training_dataloader,
                      validation_dataloader,
                      max_epoch=args.epochs,
                      resume=args.resume,
                      log_dir=log_dir)
    trainer.train()