Beispiel #1
0
# Restore the controller, the prediction model and the VAE (each together with
# its optimizer) from fixed checkpoint files. Every checkpoint is read as a
# dict with 'state_dict' and 'optimizer' entries; the VAE checkpoint must also
# provide 'epoch'.
# (An earlier, commented-out variant reloaded '<vae_dir>/best.pkl' unless
# args.noreload was set; that dead code has been dropped.)
CHECKPOINT_DIR = '/home/ld/gym-car/log/vae'
CHECKPOINT_EPOCH = 52


def _restore_checkpoint(path, net, opt):
    """Load the checkpoint stored at *path* into *net* and its optimizer *opt*.

    Returns the raw checkpoint dict so the caller can read extra entries.
    """
    checkpoint = torch.load(path)
    net.load_state_dict(checkpoint['state_dict'])
    opt.load_state_dict(checkpoint['optimizer'])
    return checkpoint


# NOTE(review): 'contorl' presumably matches the filename on disk; only change
# it together with the file itself.
state = _restore_checkpoint(
    '{}/contorl_checkpoint_{}.pkl'.format(CHECKPOINT_DIR, CHECKPOINT_EPOCH),
    controller, optimizer_a)
print('controller load success')

state = _restore_checkpoint(
    '{}/pre_checkpoint_{}.pkl'.format(CHECKPOINT_DIR, CHECKPOINT_EPOCH),
    model_p, optimizer_p)
print('prediction load success')

state = _restore_checkpoint(
    '{}/vae_checkpoint_{}.pkl'.format(CHECKPOINT_DIR, CHECKPOINT_EPOCH),
    model, optimizer)
vae_epoch = state['epoch']  # epoch recorded in the VAE checkpoint
print('vae load success')

# The epoch counter restarts from 0 even though weights and optimizer state
# are resumed. NOTE(review): confirm this is intended; use
# `trained = vae_epoch` to continue counting from the checkpoint instead.
trained = 0
cur_best = None
# NOTE(review): presumably total vs. sampled dataset sizes — confirm at use site.
all_data = 6000
sample_data = 1000
Beispiel #2
0
def _resume_from_checkpoint(path, shared_cnn, controller,
                            shared_cnn_optimizer, controller_optimizer):
    """Restore network and optimizer state from the checkpoint file at *path*.

    Returns:
        The 'epoch' value stored in the checkpoint (used as the start epoch).

    Raises:
        ValueError: if *path* is not an existing file.
    """
    if not os.path.isfile(path):
        raise ValueError("No checkpoint found at '{}'".format(path))
    print("Loading checkpoint '{}'".format(path))
    checkpoint = torch.load(path)
    start_epoch = checkpoint['epoch']
    shared_cnn.load_state_dict(checkpoint['shared_cnn_state_dict'])
    controller.load_state_dict(checkpoint['controller_state_dict'])
    # Optimizer.load_state_dict updates the optimizer in place, so any
    # scheduler already wrapping it keeps referring to the restored object.
    shared_cnn_optimizer.load_state_dict(checkpoint['shared_cnn_optimizer'])
    controller_optimizer.load_state_dict(checkpoint['controller_optimizer'])
    print("Loaded checkpoint '{}' (epoch {})".format(path, start_epoch))
    return start_epoch


def main():
    """Seed RNGs, build the ENAS controller and shared CNN, then train.

    All configuration is read from the module-level ``args``. If
    ``args.resume`` is set, state is restored from that checkpoint first.
    With ``args.fixed_arc`` set, ``train_fixed`` is run; otherwise the
    architecture search ``train_enas`` is run.

    Raises:
        ValueError: if ``args.fixed_arc`` is set without ``args.resume``, or
            if ``args.resume`` does not name an existing file.
    """
    global args

    np.random.seed(args.seed)
    # Also seed torch's CPU generator: both networks are constructed on the
    # CPU before being moved with .cuda() (and presumably initialise their
    # parameters in their constructors), so the CUDA seed alone does not make
    # that step reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    log_suffix = '_fixed.log' if args.fixed_arc else '.log'
    sys.stdout = Logger(filename='logs/' + args.output_filename + log_suffix)

    print(args)

    # Validate the fixed-architecture precondition before the expensive
    # dataset/model setup. An explicit raise (rather than assert) survives
    # `python -O`, and `not` also rejects resume=None.
    if args.fixed_arc and not args.resume:
        raise ValueError('A pretrained model should be used when training a fixed architecture.')

    data_loaders = load_datasets()

    controller = Controller(search_for=args.search_for,
                            search_whole_channels=True,
                            num_layers=args.child_num_layers,
                            num_branches=args.child_num_branches,
                            out_filters=args.child_out_filters,
                            lstm_size=args.controller_lstm_size,
                            lstm_num_layers=args.controller_lstm_num_layers,
                            tanh_constant=args.controller_tanh_constant,
                            temperature=None,
                            skip_target=args.controller_skip_target,
                            skip_weight=args.controller_skip_weight)
    controller = controller.cuda()

    shared_cnn = SharedCNN(num_layers=args.child_num_layers,
                           num_branches=args.child_num_branches,
                           out_filters=args.child_out_filters,
                           keep_prob=args.child_keep_prob)
    shared_cnn = shared_cnn.cuda()

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L218
    controller_optimizer = torch.optim.Adam(params=controller.parameters(),
                                            lr=args.controller_lr,
                                            betas=(0.0, 0.999),
                                            eps=1e-3)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L213
    shared_cnn_optimizer = torch.optim.SGD(params=shared_cnn.parameters(),
                                           lr=args.child_lr_max,
                                           momentum=0.9,
                                           nesterov=True,
                                           weight_decay=args.child_l2_reg)

    # https://github.com/melodyguan/enas/blob/master/src/utils.py#L154
    shared_cnn_scheduler = CosineAnnealingLR(optimizer=shared_cnn_optimizer,
                                             T_max=args.child_lr_T,
                                             eta_min=args.child_lr_min)

    if args.resume:
        start_epoch = _resume_from_checkpoint(args.resume,
                                              shared_cnn,
                                              controller,
                                              shared_cnn_optimizer,
                                              controller_optimizer)
        # NOTE(review): the scheduler's own state (e.g. last_epoch) is not
        # restored, so the cosine schedule restarts on resume unless
        # train_enas accounts for start_epoch — confirm.
    else:
        start_epoch = 0

    if args.fixed_arc:
        train_fixed(start_epoch,
                    controller,
                    shared_cnn,
                    data_loaders)
    else:
        train_enas(start_epoch,
                   controller,
                   shared_cnn,
                   data_loaders,
                   shared_cnn_optimizer,
                   controller_optimizer,
                   shared_cnn_scheduler)