def main():
    """Run the architecture search: alternately train the shared CNN and
    the controller, validating and checkpointing after every epoch.

    Exits with status 1 when no CUDA device is available.

    Relies on module-level names: `args`, `CNN`, `Controller`,
    `get_loaders`, `train`, `train_controller`, `infer`, `utils`,
    and `ParamCalculation`.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed every RNG source up front for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    model = CNN(args)
    model.cuda()
    controller = Controller(args)
    controller.cuda()

    # Shared-weights optimizer for the child model.
    optimizer = torch.optim.SGD(
        model.parameters(),
        args.child_lr_max,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    # Controller optimizer; betas/eps follow the ENAS recipe.
    controller_optimizer = torch.optim.Adam(
        controller.parameters(),
        args.controller_lr,
        betas=(0.1, 0.999),
        eps=1e-3,
    )

    train_loader, reward_loader, valid_loader = get_loaders(args)
    scheduler = utils.LRScheduler(optimizer, args)
    # zychen
    param_calculator = ParamCalculation(args.param_target)

    for epoch in range(args.epochs):
        lr = scheduler.update(epoch)
        logging.info('epoch %d lr %e', epoch, lr)

        # training
        train_acc = train(train_loader, model, controller, optimizer)
        logging.info('train_acc %f', train_acc)
        train_controller(reward_loader, model, controller,
                         controller_optimizer, param_calculator)

        # validation
        valid_acc = infer(valid_loader, model, controller, param_calculator)
        logging.info('valid_acc %f', valid_acc)

        # Checkpoint the shared weights every epoch.
        utils.save(model, os.path.join(args.save, 'weights.pt'))
def main():
    """Run the supernet + LSTM-generator search loop.

    Builds the shared supernet and the separable-LSTM controller, trains
    them alternately, and checkpoints both every `args.report_freq`
    epochs. Exits with status 1 when no CUDA device is available.

    Relies on module-level names: `args`, `model_maker`,
    `separable_LSTM`, `conv2d_std`, `Mish`, `Ranger`, `amp`, `tqdm`,
    `get_loaders`, `train`, `train_controller`, `infer`, and `utils`.
    """
    start_time = time.time()
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed every RNG source up front for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = {}'.format(args.gpu))
    logging.info("args = {}".format(args))

    # supernet
    model = model_maker(cell_nums=args.child_num_cells,
                        out_filters=args.child_out_filters,
                        normal_block_repeat=[4, 4],
                        classes=args.num_class,
                        aux=args.child_use_aux_heads)
    # generator
    controller = separable_LSTM(2)

    def _make_stem():
        # One fresh stem: two stride-2 conv/BN/Mish stages, 3 -> out_filters.
        return nn.Sequential(
            conv2d_std(in_channels=3,
                       out_channels=args.child_out_filters,
                       kernel_size=3, stride=2),
            nn.BatchNorm2d(args.child_out_filters, track_running_stats=False),
            Mish(),
            conv2d_std(in_channels=args.child_out_filters,
                       out_channels=args.child_out_filters,
                       kernel_size=3, stride=2),
            nn.BatchNorm2d(args.child_out_filters, track_running_stats=False),
            Mish())

    # Two independent stem instances (identical structure, separate weights).
    model.start_conv1 = _make_stem()
    model.start_conv2 = _make_stem()

    logging.info('Total params: {:.6f}M'.format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.child_lr_max,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    # generator's optimizer
    controller_optimizer = Ranger(
        controller.parameters(),
        args.controller_lr,
    )

    controller.cuda()
    model.cuda()

    train_loader, reward_loader, valid_loader = get_loaders(args)

    model.apply(utils.initialize_weights)
    # NOTE(review): this result is never used — the SGD optimizer above was
    # built from model.parameters(). Should it consume these weight-decay
    # parameter groups instead? Confirm with the author before removing.
    parameters = utils.add_weight_decay(model, args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    model, optimizer = amp.initialize(model, optimizer, opt_level="O0")
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)

    lr = args.child_lr_max
    for epoch in tqdm(range(args.epochs)):
        training_start = time.time()
        logging.info('epoch {:0>3d} lr {:.6f}'.format(epoch, lr))

        # training: drop-path probability ramps linearly over the run
        drop_prob = args.droppath * epoch / args.epochs
        model.drop_path_prob(drop_prob)
        starter = epoch == 0
        train_acc = train(train_loader, model, controller, optimizer,
                          criterion, start=starter)
        scheduler.step()
        # get_last_lr() replaces the deprecated get_lr(), which emits a
        # warning and can report incorrect values when called outside
        # scheduler.step().
        lr = scheduler.get_last_lr()[-1]
        logging.info('train_acc {:.3f}'.format(train_acc))
        train_controller(reward_loader, model, controller,
                         controller_optimizer)

        # validation
        valid_acc = infer(valid_loader, model, controller, criterion)
        logging.info('valid_acc {:.3f}'.format(valid_acc))

        # Periodic checkpoint of both networks.
        if (epoch + 1) % args.report_freq == 0:
            utils.save(model, os.path.join(args.save, 'weights.pt'))
            utils.save(controller, os.path.join(args.save, 'controller.pt'))

        epoch_inter_time = int(time.time() - training_start)
        print('Trainging 1 Epoch ,Total time consumption {} /s '.format(
            epoch_inter_time))

    logging.info('Trainging Complete ,Total time consumption {} /s '.format(
        int(time.time() - start_time)))